import numpy as np
from numpy import tanh, arctanh, exp, pi, inf, log, cosh, sqrt, square
from numpy import array
from scipy.special import k0
from scipy import integrate
from typing import List
import matplotlib.pyplot as plt

bound = [0, inf]  # lower/upper limits of integration used by M and N


def func_M(x, theta, v):
    """Integrand of ``M`` on the half line ``[0, inf)``.

    The two tanh terms pair the points ``x`` and ``-x``. Because
    ``x * k0(|x|)`` is odd in ``x``, integrating this over ``[0, inf)``
    equals integrating ``tanh(v + theta*x) * x * k0(|x|) / pi`` over the
    whole real line. (``k0(|x|) / pi`` is the density of a product of two
    independent standard normals.)
    """
    return (tanh(v + theta * x) - tanh(v - theta * x)) * x * k0(abs(x)) / pi


def M(theta, v):
    """Return the quadrature of ``func_M`` over ``bound`` for ``(theta, v)``.

    ``EM_population`` uses this value as the next ``theta``. The error
    estimate returned by ``scipy.integrate.quad`` is discarded.
    """
    return integrate.quad(func_M, bound[0], bound[1], args=(theta, v))[0]


def func_N(x, theta, v):
    """Integrand of ``N`` on the half line ``[0, inf)``.

    Here the points ``x`` and ``-x`` are summed, because ``k0(|x|)`` is even.
    Integrating over ``[0, inf)`` therefore equals integrating
    ``tanh(v + theta*x) * k0(|x|) / pi`` over the whole real line.
    """
    return (tanh(v + theta * x) + tanh(v - theta * x)) * k0(abs(x)) / pi


def N(theta, v):
    """Return the quadrature of ``func_N`` over ``bound`` for ``(theta, v)``.

    ``EM_population`` takes ``arctanh`` of this value as the next ``v``.
    The error estimate returned by ``scipy.integrate.quad`` is discarded.
    """
    return integrate.quad(func_N, bound[0], bound[1], args=(theta, v))[0]

def EM_population(theta0: float, v0: float, num_iteration: int):
    """Run the population EM map from ``(theta0, v0)`` and record its iterates.

    One step maps ``(theta, v)`` to ``(M(theta, v), arctanh(N(theta, v)))``.
    Both coordinates are evaluated at the current point.

    Args:
        theta0: starting ``theta``; it is recorded as the first alpha.
        v0: starting ``v``; ``tanh(v0)`` is recorded as the first beta.
        num_iteration: number of iterates to record. A non-positive value
            yields empty arrays.

    Returns:
        ``(alphas, betas)`` as NumPy arrays. ``alphas[t]`` is ``theta``
        before the t-th update, and ``betas[t]`` is ``tanh(v)`` at that
        same point.
    """
    def _iterates():
        # Yield the current state first, then advance. Resuming after the
        # final yield still performs one last (unrecorded) update.
        cur_theta, cur_v = theta0, v0
        for _ in range(num_iteration):
            yield cur_theta, tanh(cur_v)
            cur_theta, cur_v = M(cur_theta, cur_v), arctanh(N(cur_theta, cur_v))

    history = list(_iterates())
    alpha_values = [alpha for alpha, _ in history]
    beta_values = [beta for _, beta in history]
    return np.array(alpha_values), np.array(beta_values)

def draw_trajectory_theoretical(list_alpha0, colors_theory, num_iteration):
    """Plot the theoretical lower bound on ``alpha^t / alpha^0`` for each ``alpha^0``.

    For every ``(alpha0, color)`` pair this draws the dashed curve

        t -> (1/alpha0) / sqrt(6 t + 22 ln(1.2 t + 1) + (1/alpha0)^2)

    for ``t = 0, ..., num_iteration - 1`` on the current matplotlib axes.

    Args:
        list_alpha0: initial values ``alpha^0``, one curve per value.
        colors_theory: line colors, paired with ``list_alpha0`` by position.
            As with ``zip``, surplus entries of the longer sequence are ignored.
        num_iteration: number of iteration indices ``t`` to evaluate.
    """
    iterations = np.arange(num_iteration)
    for alpha0, color in zip(list_alpha0, colors_theory):
        inv_alpha0 = 1 / alpha0
        # Evaluate the bound for all t at once. sqrt, log and square are
        # NumPy ufuncs, so they broadcast over the integer array.
        lower_bound = inv_alpha0 / sqrt(
            6 * iterations + 22 * log(1.2 * iterations + 1) + square(inv_alpha0)
        )
        label = r'lower bounds for $\alpha^t/\alpha^0$ ($\alpha^0={}$)'.format(round(alpha0, 2))
        plt.plot(iterations, lower_bound, c=color, linestyle='dashed', label=label)

def save_plot(filename: str = 'balanced_LB.png', dpi: int = 300, show: bool = True) -> None:
    """Label and decorate the current figure, save it, and optionally show it.

    Adds the x-axis label, the title (the lower-bound formula), the legend
    and a dashed grid, then writes the figure to disk.

    Args:
        filename: output path passed to ``plt.savefig``.
        dpi: resolution passed to ``plt.savefig``.
        show: when True (the default), call ``plt.show()`` after saving.
            Pass False for non-interactive or batch runs.
    """
    plt.xlabel("iterations")
    plt.title(r"$\alpha^{t}$ vs. lower bounds $\frac{1}{\sqrt{6t+22\ln(1.2t+1)+(\frac{1}{\alpha^0})^2}}$")
    plt.legend()
    plt.grid(color='gray', linestyle='dashed')
    plt.savefig(filename, dpi=dpi)
    if show:
        plt.show()

if __name__ == "__main__":
    num_iteration = 200
    colors = ["#aab7e3", "#7990db", "#4064d9"]
    colors_theory = ["#FFCCCC", "#FF6666", "#FF0000"]
    markers = ["v", "^", "o"]
    list_alpha0, v0 = [0.02, 0.05, 0.1], 0
    # Simulated trajectories: population-EM iterates of alpha, normalized by alpha^0.
    for alpha0, c, m in zip(list_alpha0, colors, markers):
        alphas, _ = EM_population(alpha0, v0, num_iteration)  # betas are not plotted
        plt.plot(range(len(alphas)), alphas / alpha0, color=c, marker=m, markevery=25,
                 label=r'iterations of $\alpha^t/\alpha^0$ ($\alpha^0={}$)'.format(round(alpha0, 2)))
    # Overlay the theoretical lower bounds, then decorate and save the figure.
    draw_trajectory_theoretical(list_alpha0, colors_theory, num_iteration)
    save_plot()
