from EM_update import generate_theta, generate_data, EM_update, EM_update_easy,\
    generate_theta2d, generate_pi,\
    justify_params, calc_error
from typing import List, Tuple
import numpy as np
from numpy import sqrt, arccos, pi
from numpy import linspace, meshgrid, vectorize
from numpy import tanh, arctanh, exp, pi, inf, log, cosh
from numpy import array
from scipy.special import k0
from scipy import integrate
import matplotlib.pyplot as plt

def EM_population(list_theta: List, list_pi: List, sigma: float):
    """Apply one population-level EM update to every (theta, pi) pair.

    Each mixing weight ``pi0`` is mapped to ``v = arctanh(2*pi0 - 1)``, so that
    ``pi0 = (1 + tanh(v)) / 2``.  The scalar ``theta = ||theta0|| / sigma`` and
    ``v`` enter two one-dimensional integrals weighted by ``k0(|x|) / pi``.
    That weight is the density of the product of two independent standard
    normals, and it integrates to 1 over the real line:

        M(theta, v) = E[Z * tanh(v + theta * Z)]
        N(theta, v) = E[tanh(v + theta * Z)]

    Each integral is evaluated over [0, inf) with the integrand symmetrised in
    x, which equals the expectation over the whole real line.

    Args:
        list_theta: current regression vectors (array-like, one per trial).
        list_pi: current mixing weights, each in (0, 1).
        sigma: noise scale used to normalise ``theta``.

    Returns:
        ``(list_theta_updated, list_pi_updated)``.
        * ``theta' = sigma * M(theta, v) * theta0 / ||theta0||``.  When
          ``theta0`` is the zero vector, ``theta'`` is the zero vector, because
          ``M(0, v) == 0``.
        * ``pi' = (1 + N(theta, v)) / 2``, on the same probability scale as
          the input weights.
    """
    lower, upper = 0, inf  # integration range; integrands are symmetrised in x

    def func_M(x, theta, v):
        return (tanh(v + theta * x) - tanh(v - theta * x)) * x * k0(abs(x)) / pi

    def func_N(x, theta, v):
        return (tanh(v + theta * x) + tanh(v - theta * x)) * k0(abs(x)) / pi

    def M(theta, v):
        return integrate.quad(func_M, lower, upper, args=(theta, v))[0]

    def N(theta, v):
        return integrate.quad(func_N, lower, upper, args=(theta, v))[0]

    list_theta_updated, list_pi_updated = [], []
    for theta0, pi0 in zip(list_theta, list_pi):
        theta0 = np.asarray(theta0, dtype=float)
        norm_theta = np.linalg.norm(theta0)
        # A zero vector has no direction; the update is exactly zero there, so
        # any finite direction works.  Avoid the 0/0 NaN.
        direction = theta0 / norm_theta if norm_theta > 0 else np.zeros_like(theta0)
        theta = norm_theta / sigma
        v = arctanh(2 * pi0 - 1)
        list_theta_updated.append(M(theta, v) * sigma * direction)
        # N(theta, v) == 2*pi' - 1.  Invert the map used for ``v`` above; the
        # previous arctanh(N) returned the log-odds parameter v', not pi'.
        list_pi_updated.append((1.0 + N(theta, v)) / 2.0)
    return list_theta_updated, list_pi_updated

def draw_trajectory_theoretical(list_theta0, sigma):
    """Draw a dashed red segment from each initial theta to the origin.

    Both coordinates of every initial value are divided by ``sigma`` before
    plotting.  Only the first segment carries a legend label, so the legend
    shows a single entry.
    """
    scaled_starts = [(theta[0] / sigma, theta[1] / sigma) for theta in list_theta0]
    for idx, (start_x, start_y) in enumerate(scaled_starts):
        if idx == 0:
            legend_entry = r'theoretical trajectory of $\theta^t$'
        else:
            legend_entry = None
        plt.plot([start_x, 0], [start_y, 0],
                 c='#FF0000', linestyle='dashed', zorder=1, label=legend_entry)
    
def draw_trajectory_empirical(list_theta: List, sigma: float,
                              with_label: bool):
    """Plot one empirical EM trajectory on the current matplotlib axes.

    The first iterate is drawn as a hollow blue circle.  Consecutive iterates
    are joined by light-blue line segments.  Both coordinates are divided by
    ``sigma``.

    Args:
        list_theta: sequence of iterates; only the first two coordinates of
            each iterate are plotted.
        sigma: divisor applied to both plotted coordinates.
        with_label: attach legend labels to the initial marker and to the
            first segment.  The ``__main__`` block passes True only for the
            first trajectory, so legend entries are not repeated.

    An empty ``list_theta`` draws nothing.
    """
    points = [(t[0] / sigma, t[1] / sigma) for t in list_theta]
    if not points:
        return
    label_init = r'initial value $\theta^0$' if with_label else None
    x_init, y_init = points[0]
    plt.scatter(x_init, y_init, facecolors='none', edgecolors='#3399FF',
                marker='o', s=30, label=label_init)
    # Walk consecutive pairs (p_i, p_{i+1}) instead of indexing by position.
    for i, ((x0, y0), (x1, y1)) in enumerate(zip(points, points[1:])):
        label = r'empirical trajectory of $\theta^t$' if (i == 0 and with_label) else None
        plt.plot([x0, x1], [y0, y1], c='#99ccff', zorder=1, label=label)

def save_plot(filename: str = 'EM_trajectory.png', *, dpi: int = 300,
              show: bool = True):
    """Decorate the current figure, save it, and optionally display it.

    Adds the ground-truth marker, axis labels, title, grid and legend.  Fixes
    both axes to [-0.75, 0.75] with equal scaling.

    Args:
        filename: output path passed to ``plt.savefig``.
        dpi: resolution passed to ``plt.savefig``.
        show: call ``plt.show()`` after saving (the original behaviour).
    """
    # NOTE: the ground truth is drawn at the origin.  This matches theta_star
    # only when it is the zero vector (SNR = 0 in the __main__ block).
    plt.scatter(0, 0, facecolors='none', edgecolors='#FF0000', marker='*',
                s=60, zorder=2, label=r"ground truth $\theta^*$")
    plt.xlabel(r'horizontal component of $\theta$')
    plt.ylabel(r'vertical component of $\theta$')
    plt.title(r'Trajectories of $\theta^t$ and $\theta^*$')
    plt.grid(color='gray', linestyle='dashed')
    plt.xlim([-0.75, 0.75])
    plt.ylim([-0.75, 0.75])
    # plt.axis('equal') uses adjustable='datalim' and may override the limits
    # set above.  Resizing the axes box instead keeps both the explicit limits
    # and the 1:1 aspect ratio.
    plt.gca().set_aspect('equal', adjustable='box')
    plt.legend(loc='upper right')
    plt.savefig(filename, bbox_inches='tight', dpi=dpi)
    if show:
        plt.show()


if __name__ == "__main__":
    # ---- synthetic data ---------------------------------------------------
    num_sample, dimension = 2000, 2
    sigma = 1.
    pi_star = 0.7
    theta_star = np.asarray([1., 0])
    SNR = 0
    norm_star = np.linalg.norm(theta_star)
    # Rescale theta_star in place so that its norm becomes SNR * sigma.
    theta_star *= (SNR / norm_star) * sigma
    seed_data = seed_latent = seed_noise = 0
    X, Y = generate_data(num_sample, sigma,
                         seed_data, seed_latent, seed_noise,
                         theta_star, pi_star)

    # ---- EM runs from several random initialisations -----------------------
    num_init, T = 10, 15
    bound = 2  # theta^0 sampled from [-bound, +bound]^2
    list_theta0 = []
    for trial in range(num_init):
        seed = 123 + trial
        theta0 = generate_theta2d(seed, dimension=dimension, bound=bound)
        pi0 = generate_pi(seed)
        list_theta0.append(theta0)
        # EM_update_easy was left here as an alternative, called with the
        # same arguments as EM_update.
        list_theta, list_pi = EM_update(X, Y, sigma, T, theta0, pi0)
        draw_trajectory_empirical(list_theta, sigma, trial == 0)
    draw_trajectory_theoretical(list_theta0, sigma)
    save_plot()
