import numpy as np
from numpy import tanh, arctanh, exp, pi, inf, log, cosh
from numpy import array
from scipy.special import k0
from scipy import integrate
import matplotlib.pyplot as plt

# Lower/upper bound for the integrals in M and N.  The integrands below fold
# the negative half of the real line onto [0, inf), so the lower bound must
# stay at 0.
bound = [0, inf]


def func_M(x: float, theta: float, v: float) -> float:
    """Integrand of ``M`` on the half-line x >= 0.

    With the even weight w(x) = k0(|x|) / pi, the full-line integral
    int_{-inf}^{inf} tanh(v + theta*x) * x * w(x) dx equals the integral of
    this function over [0, inf): substituting x -> -x in the negative half
    turns tanh(v + theta*x) * x into -tanh(v - theta*x) * x.
    """
    return (tanh(v + theta * x) - tanh(v - theta * x)) * x * k0(abs(x)) / pi


def M(theta: float, v: float) -> float:
    """Return the integral of ``func_M`` over ``bound`` (the theta update used by EM_population)."""
    return integrate.quad(func_M, bound[0], bound[1], args=(theta, v))[0]


def func_N(x: float, theta: float, v: float) -> float:
    """Integrand of ``N`` on the half-line x >= 0.

    Same folding as ``func_M``: the full-line integral of
    tanh(v + theta*x) * k0(|x|) / pi equals the integral of this function
    over [0, inf); the x -> -x half contributes tanh(v - theta*x).
    k0 has a logarithmic (integrable) singularity at x = 0.
    """
    return (tanh(v + theta * x) + tanh(v - theta * x)) * k0(abs(x)) / pi


def N(theta: float, v: float) -> float:
    """Return the integral of ``func_N`` over ``bound``.

    EM_population maps this value through ``arctanh`` to get the next ``v``.
    """
    return integrate.quad(func_N, bound[0], bound[1], args=(theta, v))[0]

def EM_population(theta0: float, v0: float, num_iteration: int) -> tuple:
    """Iterate the population EM map starting from ``(theta0, v0)``.

    Each step applies ``theta <- M(theta, v)`` and ``v <- arctanh(N(theta, v))``.
    Both right-hand sides use the previous iterate. A non-positive
    ``num_iteration`` returns the starting point unchanged.

    Returns:
        The pair ``(theta, v)`` after ``num_iteration`` steps.
    """
    theta_cur = theta0
    v_cur = v0
    for _step in range(num_iteration):
        # Compute both updates from the current iterate before rebinding,
        # so that v's update does not see the freshly updated theta.
        theta_next = M(theta_cur, v_cur)
        v_next = arctanh(N(theta_cur, v_cur))  # inf/nan if |N| >= 1
        theta_cur, v_cur = theta_next, v_next
    return theta_cur, v_cur

def save_plot(filename='bound_mixweight.png'):
    """Label and frame the current figure, save it to ``filename``, then show it.

    Args:
        filename: Output path passed to ``plt.savefig``; defaults to the
            original hard-coded ``'bound_mixweight.png'``.
    """
    plt.xlabel(r'$\beta^{0}:=||\pi^0-\frac{1}{2}||_1$')
    plt.ylabel(r'$\beta^{T}:=||\pi^T-\frac{1}{2}||_1$')
    plt.title(r'relations of converged $\beta^T$ and $\beta^0$')
    plt.grid(color='gray', linestyle='dashed')
    # plt.axis('equal') uses adjustable='datalim', which lets matplotlib
    # override one of the explicit limits below to satisfy the aspect ratio.
    # Shrinking the axes box instead keeps exactly [0, 1] x [0, 1] at 1:1.
    plt.gca().set_aspect('equal', adjustable='box')
    plt.xlim([0., 1.])
    plt.ylim([0., 1.])
    plt.legend(loc='upper left')
    plt.savefig(filename, bbox_inches='tight', dpi=300)
    plt.show()

if __name__ == "__main__":
    # For each initial alpha^0, sweep beta^0 over (0, 1), run population EM,
    # and plot the resulting beta^T against beta^0.
    num_iteration = 100
    list_alpha0 = [0.1, 0.3, 0.5]
    curve_styles = zip(list_alpha0,
                       ["#aab7e3", "#7990db", "#4064d9"],
                       ["v", "^", "o"])
    for alpha0, c, m in curve_styles:
        list_beta0 = np.linspace(0.01, 0.99, num=10)
        # beta and v are linked by beta = tanh(v): start from
        # v^0 = arctanh(beta^0) and report beta^T = tanh(v^T).
        list_betaT = np.array([
            tanh(EM_population(alpha0, arctanh(beta0), num_iteration)[1])
            for beta0 in list_beta0
        ])
        curve_label = (r"$\alpha^0:=||\theta^0||/\sigma={}$"
                       .format(round(alpha0, 1)))
        plt.plot(list_beta0, list_betaT, "--", color=c, marker=m,
                 label=curve_label)
    save_plot()
