#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu May  3 16:59:36 2018

@author: mirjam
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import least_squares
from scipy.stats import chi2
from matplotlib import cm

# Standard deviation of the Gaussian measurement noise; used both to simulate the data and to scale the residuals
sigma = 1 / 2
#Function that executes the model
def myModel(p,t):     
    y_out = [p[0]*(1-np.exp(-p[1]*t_i)) for t_i in t]
    return y_out

# function definition for data generation (i)
def mydata(p,t,t_full):
    y = np.empty(np.size(t))
    y = myModel(p,t)
    ydata = y + np.random.normal(scale = sigma, size = np.size(t)) 

    ysim = myModel(p,t_full)
    plotData(ydata,t,ysim,t_full)   
    
    return ydata

def plotData(ydata,tdata,ysim,tsim):
    plt.figure()
    plt.plot(tdata,ydata,'o')
    plt.plot(tsim,ysim,'r', label = 'true')
    plt.title('simulated data')
    plt.ylabel('y')
    plt.xlabel('time')
    plt.show()
    
A = 5
k = 0.5

#parameter vector
p = (A,k)

# Full time points
t_full = np.linspace(0,20,1001)

# now generate the four datasets
t1 = [0.5,1]
y1 = mydata(p, t1,t_full)

t2 = [17,20]
y2 = mydata(p, t2,t_full)

t3 = [0,1,3,10,15,20]
y3 = mydata(p, t3,t_full)

t4 = range(1,20)
y4 = mydata(p, t4,t_full)


# Exercise ii

#Function to choose between the different data realizations
def get_rightData(data_number):
    choice_y = {1: y1, 2: y2, 3: y3, 4: y4}
    which_y = choice_y.get(data_number,y1)
    choice_t = {1: t1, 2: t2, 3: t3, 4: t4}
    which_t = choice_t.get(data_number,t1)
    return (which_y,which_t)

#Function that returns list of residuals of the function, 
#whereby y_true is generated for current parameter set
def leastsquares(p,ys,ts,fix_par=0,par_value = 0):
        
    if fix_par == 1:
        p = [par_value,p]
    elif fix_par == 2:
        p = [p,par_value]
    
    res = np.empty(np.size(ys))
    y_true = myModel(p,ts)
    for iy in range(len(ys)):
        res[iy] = (chosen_y[iy]-y_true[iy])/sigma
    return res


#Optimization
# !!! Specify which data realization to take !!!
data_number = 1
(chosen_y,chosen_t) = get_rightData(data_number)

#For multiple fits, specify nr_fits>1
nr_fits = 1
#This sets boundaries of the two parameters (l = lower, u = upper)
lb = np.array([0,0])
ub = np.array([10,10])

chi2_vec = np.empty(nr_fits)
par_vec = np.empty([nr_fits,len(p)])

for iOpt in range(nr_fits):
    x_0 = np.multiply(np.random.rand(2),ub)
    opt_tmp = least_squares(leastsquares,x_0,args=(chosen_y,chosen_t))
    chi2_vec[iOpt] = opt_tmp.cost
    par_vec[iOpt,] = opt_tmp.x

print("A=%g"%(par_vec[0,0])+" , k=%g"%(par_vec[0,1]))

y_true = myModel(p,t_full)    
y_fit = myModel(par_vec[0,:],t_full)    
plt.figure('residuals')
plt.plot(chosen_t,chosen_y,'o', label = "data")
plt.plot(t_full,y_true,'r', label = 'true')
plt.plot(t_full,y_fit,'b', label = 'fit')
plt.title('fit: data set %d'%(data_number,))
leg = plt.legend(loc='best',ncol = 1, shadow=False, fancybox=True)
plt.show()

#Part 2, profile likelihood
    
# !!! Here, you can set a different data realization and which parameter to scan for profile likelihood !!!
#data_number = 1
fix_par = 1
(chosen_y,chosen_t) = get_rightData(data_number)

#List of parameter values for the parameter that is scanned
#Set how many steps are taken for profile
nr_profile = 101
par_screen = np.linspace(lb[fix_par-1],ub[fix_par-1],nr_profile)

chi2_profile = np.empty(len(par_screen))
model_traj = np.empty([nr_profile,len(t_full)])
for ip in range(nr_profile):
    #Set optimal value of other parameter as starting point
    x0 = par_vec[0,fix_par-1]   
    opt_tmp = least_squares(leastsquares,x0,args=(chosen_y,chosen_t,fix_par,par_screen[ip]))
    chi2_profile[ip]= opt_tmp.cost
    p_tmp = opt_tmp.x
    if fix_par == 1:
        p_tmp = [par_screen[ip],p_tmp]
    elif fix_par == 2:
        p_tmp = [p_tmp,par_screen[ip]]
    model_traj[ip,] = myModel(p_tmp,t_full)


threshold = chi2.ppf(0.95,1)

#Get the parameter name for the figures
choice_par = {1: 'A', 2: 'k'}
which_par = choice_par.get(fix_par,'')

fig_prof = plt.figure()
#plt.semilogy(par_screen,chi2_profile)
plt.plot(par_screen,chi2_profile)
plt.title('Profile likelihood for parameter ' + which_par)
plt.plot([np.min(par_screen), np.max(par_screen)],[min(chi2_profile)+threshold,min(chi2_profile)+threshold],'r--')
plt.ylim(ymax=min(chi2_profile)+5,ymin=min(chi2_profile)-0.1)
plt.ylabel('$\chi^2$ value')
plt.xlabel('parameter value')

#This part gets a color map to nicely render the trajectories
rgb = cm.get_cmap('RdBu')
color_code = plt.Normalize(par_screen)

fig_traj = plt.figure()
for ic in range(nr_profile):
    plt.plot(t_full,np.transpose(model_traj[ic,]),color=rgb(ic/nr_profile))
plt.plot(chosen_t,chosen_y,'o')
plt.title('Model trajectories for different values of par ' + which_par)
plt.ylabel('y')
plt.xlabel('time')


