#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun  7 10:41:33 2018

@author: helge
"""
import numpy as np
import matplotlib.pyplot as plt

# General settings: true intercept a and slope b of the model y = a + b*x,
# and M, the number of independent data sets simulated for each sample size.
a, b = 0, 1
M = 200

# Sample sizes N to study.  For the full sweep over N use instead
#   Ns = np.array([5, 7, 10, 15, 20, 25, 30, 40, 50, 70, 80, 100, 150, 200, 300, 500, 1000])
Ns = np.array([1000])

# Closed-form least-squares estimates of the intercept a and slope b
def ls_ab(yi, xi):
    """Return the ordinary least-squares fit (intercept, slope) of yi on xi.

    Uses the closed form
        slope     = sum((x - mean x) * (y - mean y)) / sum((x - mean x)**2)
        intercept = mean y - slope * mean x
    """
    x_bar = np.mean(xi)
    y_bar = np.mean(yi)
    dx = xi - x_bar
    slope = np.sum(np.multiply(dx, yi - y_bar)) / np.sum(dx ** 2)
    intercept = y_bar - slope * x_bar
    return (intercept, slope)

# Bisection search for a root of a function
def bisection(f, a, b, epsilon=1e-5, max_iter=1000):
    """Search for a root of ``f`` in the bracket [a, b] by bisection.

    Parameters
    ----------
    f : callable
        Function of one float.  It is assumed to change sign on [a, b]; this
        is not checked, and if it does not the search drifts towards an end
        point.
    a, b : float
        Bracket end points.
    epsilon : float
        Stop once half the bracket width drops below this value.
    max_iter : int
        Maximum number of midpoints evaluated (the original fixed limit was
        1000, kept as the default).

    Returns
    -------
    float
        The last midpoint evaluated.  It is returned early if ``f`` is
        exactly zero there.
    """
    # Only the sign of f at the left end is needed.  Cache it instead of
    # re-evaluating f(a) every step.  When a moves to c, it equals sign f(c).
    sign_a = np.sign(f(a))
    c = (a + b) / 2.0
    for _ in range(max_iter):
        c = (a + b) / 2.0
        fc = f(c)  # one evaluation of f per midpoint
        if fc == 0 or (b - a) / 2.0 < epsilon:
            break
        sign_c = np.sign(fc)
        if sign_c == sign_a:
            # Root lies in [c, b].
            a, sign_a = c, sign_c
        else:
            # Root lies in [a, c].
            b = c
    return c
    
# Error distribution used to simulate the data: "gaussian", "cauchy" or
# "laplace" (double exponential).  The robust fit below is the maximum
# likelihood estimator only for "laplace".
ERROR_DIST = "laplace"


def draw_errors(kind, shape):
    """Draw i.i.d. regression errors of the requested distribution.

    kind  : "gaussian", "cauchy" or "laplace".
    shape : array shape, e.g. (N, M).
    """
    if kind == "gaussian":
        return np.random.randn(*shape)
    if kind == "cauchy":
        # The ratio of two independent standard normals is standard Cauchy.
        return np.random.randn(*shape) / np.random.randn(*shape)
    if kind == "laplace":
        # Same distribution as -log(U1/U2) for independent uniforms U1, U2
        # (a difference of two Exp(1) variables), drawn vectorized.
        return np.random.laplace(0.0, 1.0, size=shape)
    raise ValueError("unknown error distribution: %r" % (kind,))


def fit_lad(yi, xi, a_start, b_start, tol=1e-8, max_iter=100, b_bracket=(-5, 5)):
    """Least-absolute-deviations fit of y = a + b*x (the MLE under Laplace errors).

    Alternates two updates, starting from (a_start, b_start):
      * intercept: a = median(y - b*x), the exact minimiser for fixed b;
      * slope: the root in ``b_bracket`` of sum(x * sign(y - b*x - a)), found
        by bisection.  This is the negative subgradient of sum|y - a - b*x|
        in b, so its sign change is the minimiser for fixed a.
    It stops once BOTH parameters change by at most ``tol`` in one sweep, or
    after ``max_iter`` sweeps (guards against the alternation cycling).

    The slope search is confined to ``b_bracket``.  A true slope outside it
    would silently yield a value near an end point.

    Returns (a, b, converged).
    """
    a_est = float(a_start)
    b_est = float(b_start)
    for _ in range(max_iter):
        a_new = np.median(yi - b_est * xi)

        # Bind a_new as a default so the closure does not late-bind.
        def score_b(b_now, a_now=a_new):
            return np.sum(np.multiply(xi, np.sign(yi - b_now * xi - a_now)))

        b_new = bisection(score_b, *b_bracket)
        converged = abs(a_new - a_est) <= tol and abs(b_new - b_est) <= tol
        a_est, b_est = a_new, b_new
        if converged:
            return a_est, b_est, True
    return a_est, b_est, False


# Per-N summary statistics, one row per entry of Ns.  These are NaN-initialised
# rather than np.empty, so an unfilled slot can never be mistaken for data.
var_a_LS = np.full([len(Ns), 1], np.nan)
var_b_LS = np.full([len(Ns), 1], np.nan)
var_a_MLE = np.full([len(Ns), 1], np.nan)
var_b_MLE = np.full([len(Ns), 1], np.nan)
eff_a = np.full([len(Ns), 1], np.nan)
eff_b = np.full([len(Ns), 1], np.nan)

for ins, N in enumerate(Ns):
    # M independent data sets of size N, one per column.
    xi = np.random.randn(N, M)
    ei = draw_errors(ERROR_DIST, (N, M))
    yi = a + b * xi + ei

    # Per-data-set estimates from least squares and the robust fit.
    a_ls = np.empty([M, 1], dtype=float)
    b_ls = np.empty([M, 1], dtype=float)
    a_mle = np.empty([M, 1], dtype=float)
    b_mle = np.empty([M, 1], dtype=float)
    n_unconverged = 0

    for im in range(M):
        y_col, x_col = yi[:, im], xi[:, im]
        a0, b0 = ls_ab(y_col, x_col)
        a_ls[im], b_ls[im] = a0, b0
        # The least-squares estimate is the starting point of the robust fit.
        a_hat, b_hat, converged = fit_lad(y_col, x_col, a0, b0)
        a_mle[im], b_mle[im] = a_hat, b_hat
        if not converged:
            n_unconverged += 1

    if n_unconverged:
        print("Warning: " + str(n_unconverged) + " of " + str(M)
              + " MLE fits hit the iteration cap for N = " + str(N))

    v_a_ls, v_b_ls = np.var(a_ls), np.var(b_ls)
    v_a_mle, v_b_mle = np.var(a_mle), np.var(b_mle)

    print("Mean ± Variance of LS: a " + str(np.round(np.mean(a_ls), 4))
          + " ± " + str(np.round(v_a_ls, 4))
          + " , b " + str(np.round(np.mean(b_ls), 4))
          + " ± " + str(np.round(v_b_ls, 4)))
    print("Mean ± Variance of MLE: a " + str(np.round(np.mean(a_mle), 4))
          + " ± " + str(np.round(v_a_mle, 4))
          + " , b " + str(np.round(np.mean(b_mle), 4))
          + " ± " + str(np.round(v_b_mle, 4)))

    # Efficiency as defined here: Var(MLE) / Var(LS), so values below 1
    # mean the robust fit has the smaller variance.
    eff_A = v_a_mle / v_a_ls
    eff_B = v_b_mle / v_b_ls

    print("Efficiency of intercept is " + str(eff_A) + " and for slope its " + str(eff_B))

    # Store for every N, including the first.
    var_a_LS[ins] = v_a_ls
    var_b_LS[ins] = v_b_ls
    var_a_MLE[ins] = v_a_mle
    var_b_MLE[ins] = v_b_mle
    eff_a[ins] = eff_A
    eff_b[ins] = eff_B


# PLOTTING
# Histograms of the estimates for the last N in Ns.
plt.figure(1)
plt.subplot(221)
plt.hist(a_ls, bins=30)
plt.title("Intercept of least-squares est")
plt.subplot(222)
plt.hist(b_ls, bins=30)
plt.title("Slope of least-squares est")
plt.subplot(223)
plt.hist(a_mle, bins=30)
plt.title("Intercept of max likelihood est")
plt.subplot(224)
plt.hist(b_mle, bins=30)
plt.title("Slope of max likelihood est")

# Variance and efficiency against N, only meaningful when sweeping over Ns.
if len(Ns) > 1:
    plt.figure(2)
    plt.subplot(121)
    plt.semilogy(Ns, var_a_LS[:, 0], 'o', label="LS")
    plt.semilogy(Ns, var_a_MLE[:, 0], 'o', label="MLE")
    plt.ylabel("Variance of Intercept")
    plt.xlabel("Number of data points")
    plt.legend()
    plt.subplot(122)
    plt.semilogy(Ns, var_b_LS[:, 0], 'o', label="LS")
    plt.semilogy(Ns, var_b_MLE[:, 0], 'o', label="MLE")
    plt.ylabel("Variance of Slope")
    plt.xlabel("Number of data points")
    plt.legend()
    plt.figure(3)
    plt.plot(Ns, eff_a[:, 0], 'o', label="Intercept")
    plt.plot(Ns, eff_b[:, 0], 'o', label="Slope")
    plt.ylabel("Efficiency (Var MLE / Var LS)")
    plt.xlabel("Number of data points")
    plt.legend()

# Display the figures.  Without this, a plain `python script.py` run exits
# without showing them.
plt.show()
