"""
@author: joep
"""

import warnings

import numpy as np
from scipy.integrate import odeint
from scipy import interpolate
import matplotlib.pyplot as plt

def gillespie(n_s0, N, fun, pars, maxtime, maxsteps):
    """Perform a single stochastic simulation with Gillespie's direct method.

    Parameters
    ----------
    n_s0 : array_like, shape (n_states,)
        Initial particle count of every species.
    N : ndarray, shape (n_rxns, n_states)
        Stoichiometry matrix; row ``r`` is the change in particle counts
        caused by one firing of reaction ``r``.
    fun : callable
        ``fun(state, pars)`` returns the non-normalized propensity of every
        reaction as a sequence of length ``n_rxns``.
    pars :
        Parameters passed unchanged to ``fun``.
    maxtime : float
        The simulation stops once the time reaches or exceeds this value.
    maxsteps : int
        Maximum number of stored time points, including the initial state.

    Returns
    -------
    i : int
        Number of stored time points.
    n_S : ndarray, shape (n_states, i)
        Particle counts at every stored time point.
    times : ndarray, shape (i,)
        Time of every stored point; ``times[0] == 0``. If no reaction can
        fire any more, a final point at ``maxtime`` holding the last state is
        appended.

    Raises
    ------
    ValueError
        If ``maxsteps < 1`` or ``fun`` returns a negative or non-finite
        propensity.

    Warns
    -----
    RuntimeWarning
        If ``maxsteps`` is exhausted before ``maxtime`` is reached, i.e. the
        returned trajectory stops short of ``maxtime``.
    """
    if maxsteps < 1:
        raise ValueError("maxsteps must be at least 1")

    n_rxns, n_states = N.shape

    # Over-allocate for speed; trimmed to the points actually simulated below.
    n_S = np.zeros([n_states, maxsteps])
    times = np.zeros(maxsteps)

    # Apply initial condition
    n_current = np.asarray(n_s0, dtype=float)
    n_S[:, 0] = n_current
    t = 0.0
    i = 1
    absorbed = False

    while t < maxtime and i < maxsteps:
        # Non-normalized propensities
        w = np.asarray(fun(n_current, pars), dtype=float)
        if not np.all(np.isfinite(w) & (w >= 0)):
            raise ValueError("propensities must be finite and non-negative")

        p = np.cumsum(w)
        psum = p[-1]

        # No reaction can occur any more: hold the state until maxtime.
        if psum == 0:
            times[i] = maxtime
            n_S[:, i] = n_current
            i += 1
            absorbed = True
            break

        # Exponentially distributed waiting time with total rate psum.
        t += -np.log(np.random.random_sample()) / psum

        # Pick reaction r with probability w[r] / psum. After normalization
        # p[-1] == psum / psum == 1.0 exactly and rnd < 1, so an index with
        # p > rnd always exists. Because the comparison is strict, a
        # zero-propensity reaction (which does not raise the cumulative sum)
        # is never chosen.
        p /= psum
        rx = np.searchsorted(p, np.random.random_sample(), side='right')

        # Apply the state change of the selected reaction
        n_current = n_current + N[rx, :]

        n_S[:, i] = n_current
        times[i] = t
        i += 1

    if not absorbed and t < maxtime:
        warnings.warn(
            "gillespie: maxsteps=%d reached at t=%g before maxtime=%g; "
            "trajectory is truncated" % (maxsteps, t, maxtime),
            RuntimeWarning, stacklevel=2)

    # Return only the points actually simulated
    times = times[0:i]
    n_S = n_S[:, 0:i]

    return i, n_S, times


def averagedGillespie(M, initial, N, fun, pars, t, maxsteps):
    """Average M independent Gillespie trajectories on a common time grid.

    Each trajectory is simulated with ``gillespie`` up to ``t[-1]`` and
    sampled at the times ``t`` with a zero-order hold (the state is constant
    between reaction events), so trajectories can be averaged point-wise.

    Parameters
    ----------
    M : int
        Number of independent realizations to average (must be >= 1).
    initial, N, fun, pars, maxsteps :
        Passed unchanged to ``gillespie``; ``N`` has shape
        (n_rxns, n_states).
    t : array_like, shape (nsteps,)
        Sample times; ``t[-1]`` is used as the simulation end time.

    Returns
    -------
    means : ndarray, shape (n_states, nsteps)
        Sample mean of every species at every time.
    sds : ndarray, shape (n_states, nsteps)
        Population standard deviation sqrt(E[X^2] - E[X]^2).
        Both arrays contain NaN at times past the end of any trajectory that
        was truncated by ``maxsteps``.

    Raises
    ------
    ValueError
        If ``M < 1``.
    """
    if M < 1:
        raise ValueError("M must be at least 1")

    t = np.asarray(t)
    n_states = N.shape[1]
    sums = np.zeros([n_states, t.size])
    sums_sq = np.zeros([n_states, t.size])

    # Exactly M realizations, matching the division by M below.
    for _ in range(M):
        _, n_S, tsim = gillespie(initial, N, fun, pars, t[-1], maxsteps)

        # Interpolate all species onto t at once; bounds_error=False fills
        # NaN outside [tsim[0], tsim[-1]] instead of raising.
        sim = interpolate.interp1d(tsim, n_S, kind='zero', axis=1,
                                   bounds_error=False)(t)
        sums += sim
        sums_sq += sim * sim

    # Variance is E(X^2) - E(X)^2; clip round-off negatives before the sqrt
    # (np.maximum keeps NaN entries as NaN).
    means = sums / M
    sds = np.sqrt(np.maximum(sums_sq / M - means * means, 0.0))

    return means, sds

# Conversion factor from particle counts to concentrations:
#     concentration = count * N_to_conc
# i.e. divide by the volume (1e-12 L), by 1e-9 and by Avogadro's number
# (6.022e23, applied here as 6.022 followed by 1e23).
N_to_conc = 1.0 / 1.0e-12 / 1.0e-9 / 6.022 / 1.0e23

##############################
# MODEL 1
##############################
def model1():
    """Build the specification of model 1.

    Species (column order): S1, S2, S3, S4.
    Reactions (row order):
        0: S1 + S2 -> S3        propensity p[0] * S1 * S2
        1: S3      -> S2 + S4   propensity p[1] * S3

    Returns
    -------
    tuple
        (n_states, N, propensityFunction, initial, parsGillespie, parsODE,
        maxsteps, maxtime): N is the (reactions x species) stoichiometry
        matrix, propensityFunction(x, p) returns the list of propensities,
        parsGillespie / parsODE are the rate constants in particle-count and
        concentration units, and maxsteps / maxtime bound the simulation.
    """
    n_states = 4

    # Starting particle counts for S1..S4.
    initial = np.array([10.0, 10.0, 0.0, 0.0])

    # One row per reaction: negative entries are consumed species,
    # positive entries are produced species.
    N = np.array([
        [-1.0, -1.0,  1.0, 0.0],
        [ 0.0,  1.0, -1.0, 1.0],
    ])

    def propensityFunction(x, p):
        return [p[0] * x[0] * x[1], p[1] * x[2]]

    # Bimolecular rate in per-particle units: k1 = 0.5 divided by
    # (1e-9 * 1e-12 * 6.022e23).
    k1_stochastic = 0.5 / (1e-9 * 1e-12 * 6.022e23)
    k2 = 0.02
    parsGillespie = [k1_stochastic, k2]

    # k1 has units of inverse concentration, so the ODE value is converted
    # back with N_to_conc; the first-order k2 is used unchanged.
    parsODE = [k1_stochastic / N_to_conc, k2]

    # Simulation limits
    maxsteps = 20000
    maxtime = 1000

    return n_states, N, propensityFunction, initial, parsGillespie, parsODE, maxsteps, maxtime


##############################
# MODEL 2
##############################
def model2():
    """Build the specification of model 2 (hunter/prey system).

    Species (column order): x = hunters, y = prey.
    Reactions (row order):
        0: y     -> 2y    prey give birth                      p[0] * y
        1: y + x -> 2x    hunter eats prey, produces offspring  p[1] * x * y
        2: x     -> 0     hunter dies                           p[2] * x
        3: 0     -> x     hunter migrates into population       p[3]

    Returns
    -------
    tuple
        (n_states, N, propensityFunction, initial, parsGillespie, parsODE,
        maxsteps, maxtime): N is the (reactions x species) stoichiometry
        matrix, propensityFunction(x, p) returns the list of propensities,
        parsGillespie / parsODE are the rate constants in particle-count and
        concentration units, and maxsteps / maxtime bound the simulation.
    """
    n_states = 2

    # Starting particle counts (hunters, prey).
    initial = np.array([100, 50])

    # One row per reaction: negative entries are consumed species,
    # positive entries are produced species.
    N = np.array([
        [ 0.0,  1.0],   # y -> 2y
        [ 1.0, -1.0],   # y + x -> 2x
        [-1.0,  0.0],   # x -> 0
        [ 1.0,  0.0],   # 0 -> x
    ])

    def propensityFunction(x, p):
        return [p[0] * x[1], p[1] * x[0] * x[1], p[2] * x[0], p[3]]

    # For the concentration-based ODE, the second-order rate is divided by
    # N_to_conc and the zero-order inflow multiplied by it; first-order rates
    # are the same in both unit systems.
    parsGillespie = [1.0, .005, .5, .3]
    parsODE = [1.0, .005 / N_to_conc, .5, .3 * N_to_conc]

    # Simulation limits
    maxsteps = 2000000
    maxtime = 100

    return n_states, N, propensityFunction, initial, parsGillespie, parsODE, maxsteps, maxtime

def main():
    """Compare deterministic and stochastic simulations of one reaction model.

    Draws three panels:
    - the ODE solution, converted back to particle counts;
    - a single Gillespie trajectory;
    - the mean (optionally +/- one standard deviation) of M Gillespie runs.
    """
    # Set to True to also draw mean +/- one standard deviation in panel 3.
    plotVariance = False

    # Model 1 specification
    #   S1 + S2 -> S3        k1
    #   S3      -> S2 + S4   k2
    #
    #   k1 = 0.5
    #   k2 = 0.02  (value used by model1(); NOTE(review): this comment
    #              previously read 0.2 -- confirm which is intended)
    #   V  = 1pL
    #   S1(0) = S2(0) = 10
    #   S3(0) = S4(0) = 0
    names = ['S1', 'S2', 'S3', 'S4']
    colors = ['r', 'g', 'b', 'k']

    # Model to simulate: set to model1 or model2.
    model = model1
    (n_states, N, propensityFunction, initial,
     parsGillespie, parsODE, maxsteps, maxtime) = model()

    # Common time axis for all panels
    t = np.linspace(0, maxtime, 1500)

    # Single realization of the Gillespie simulation
    _, n_S, tgillespie = gillespie(initial, N, propensityFunction,
                                   parsGillespie, maxtime, maxsteps)

    # Deterministic rate equations: dy/dt = sum_r w_r(y) * N[r, :], i.e. the
    # propensity vector times the stoichiometry matrix.
    def rhs(y, _t, N, fun, pars):
        return np.dot(fun(y, pars), N)

    # The ODE is solved in concentration units, hence the N_to_conc factors.
    sol = odeint(rhs, initial * N_to_conc, t,
                 args=(N, propensityFunction, parsODE),
                 hmax=10, rtol=1e-14, atol=1e-14)

    plt.subplot(131)
    plt.title("Deterministic")
    for js in range(n_states):
        plt.plot(t, sol[:, js] / N_to_conc, colors[js], label=names[js])
    # Colors are shared by all panels, so a single legend suffices.
    plt.legend()

    plt.subplot(132)
    plt.title("Single trajectory")
    # Zero-order hold between reaction events. bounds_error=False yields NaN
    # (a gap in the plot) instead of raising when the trajectory stops before
    # maxtime because maxsteps was exhausted.
    single = interpolate.interp1d(tgillespie, n_S, kind='zero', axis=1,
                                  bounds_error=False)(t)
    for js in range(n_states):
        plt.plot(t, single[js, :], colors[js], label=names[js])

    # Averaged Gillespie simulation
    M = 20  # Number of samples to average
    mn, sds = averagedGillespie(M, initial, N, propensityFunction,
                                parsGillespie, t, maxsteps)

    plt.subplot(133)
    plt.title("Mean " + str(M) + " realizations")
    for js in range(n_states):
        plt.plot(t, mn[js, :], colors[js], label=names[js])
        if plotVariance:
            # Left unlabeled so each species appears only once in a legend.
            plt.plot(t, mn[js, :] + sds[js, :], colors[js] + '--')
            plt.plot(t, mn[js, :] - sds[js, :], colors[js] + '--')

    plt.show()


if __name__ == "__main__":
    main()