import numpy as np
from numpy import tanh, arctanh, exp, pi, inf, log, cosh
from numpy import array
from scipy.special import k0
from scipy import integrate
import matplotlib.pyplot as plt

bound = [0, inf]  # integration limits [lower, upper] shared by M and N


def func_M(x, theta, v):
    """Integrand of M: (tanh(v + theta*x) - tanh(v - theta*x)) * x * k0(|x|) / pi."""
    tanh_difference = tanh(v + theta * x) - tanh(v - theta * x)
    return tanh_difference * x * k0(abs(x)) / pi


def M(theta, v):
    """Integrate func_M over [bound[0], bound[1]] with scipy's quad; return only the value."""
    quad_result = integrate.quad(func_M, bound[0], bound[1], args=(theta, v))
    return quad_result[0]


def func_N(x, theta, v):
    """Integrand of N: (tanh(v + theta*x) + tanh(v - theta*x)) * k0(|x|) / pi."""
    tanh_sum = tanh(v + theta * x) + tanh(v - theta * x)
    return tanh_sum * k0(abs(x)) / pi


def N(theta, v):
    """Integrate func_N over [bound[0], bound[1]] with scipy's quad; return only the value."""
    quad_result = integrate.quad(func_N, bound[0], bound[1], args=(theta, v))
    return quad_result[0]

def EM_population(theta0: float, v0: float, num_iteration: int):
    """Iterate the update (theta, v) -> (M(theta, v), arctanh(N(theta, v))).

    Parameters
    ----------
    theta0, v0 : float
        Starting point; the recorded beta sequence starts at tanh(v0).
    num_iteration : int
        Number of iterates to record (0 or negative yields empty arrays).

    Returns
    -------
    (alphas, betas) : tuple of 1-D numpy arrays, each of length max(num_iteration, 0)
        alphas[t] = theta^t and betas[t] = tanh(v^t) for t = 0 .. num_iteration-1.
    """
    theta, v = theta0, v0
    list_alpha, list_beta = [], []
    for step in range(num_iteration):
        list_alpha.append(theta)
        list_beta.append(tanh(v))
        if step == num_iteration - 1:
            # The next update would never be recorded; skip its two quad integrals.
            break
        # N averages tanh values against a weight integrating to 1, so |N| <= 1
        # mathematically; clip quadrature round-off so arctanh cannot yield NaN.
        beta_next = np.clip(N(theta, v), -1.0, 1.0)
        theta, v = M(theta, v), arctanh(beta_next)
    return np.array(list_alpha), np.array(list_beta)

def _decorate_panel(axis, title, xlabel, ylabel, lim):
    """Title/label one panel, pin both axes to ``lim`` and draw the dashed y = x reference."""
    axis.set_title(title)
    axis.set_xlabel(xlabel)
    axis.set_ylabel(ylabel)
    axis.set_xlim(lim)
    axis.set_ylim(lim)
    # Dashed diagonal y = x, labelled "approximate dynamics", for comparison with the EM points.
    axis.plot(lim, lim, c='#FF0000', linestyle='dashed', label="approximate dynamics")
    axis.grid(color='gray', linestyle='dashed')
    axis.legend(loc='upper left')


def save_plot(ax, layout='vertical', *, filename='dynamic.png', dpi=300, show=True):
    """Decorate both panels, save the figure that owns them, and optionally display it.

    Parameters
    ----------
    ax : sequence of two matplotlib Axes
        ax[0]: relative alpha change vs [beta^t]^2; ax[1]: relative beta change vs alpha^t alpha^{t+1}.
    layout : str
        Kept for backward compatibility; unused in this function.
    filename : str
        Path handed to ``Figure.savefig`` (default ``'dynamic.png'``).
    dpi : int
        Resolution of the saved image (default 300).
    show : bool
        Whether to call ``plt.show()`` after saving (default True).
    """
    _decorate_panel(
        ax[0],
        r'$(\alpha^{t}-\alpha^{t+1})/\alpha^{t}$ vs. $[\beta^t]^2$',
        r'$[\beta^t]^2$',
        r'$(\alpha^{t}-\alpha^{t+1})/\alpha^{t}$',
        [0., 1.],
    )
    _decorate_panel(
        ax[1],
        r'$(\beta^{t}-\beta^{t+1})/\beta^{t}$ vs. $\alpha^t\alpha^{t+1}$',
        r'$\alpha^t\alpha^{t+1}$',
        r'$(\beta^{t}-\beta^{t+1})/\beta^{t}$',
        [1e-3, 1e-2],
    )
    # Save the figure these axes belong to, not whatever pyplot considers "current".
    ax[0].figure.savefig(filename, dpi=dpi)
    if show:
        plt.show()

if __name__ == "__main__":
    num_iteration = 2
    alpha0 = 0.1
    # Initial beta values over [0.1, 1); the top value stays just below 1 so
    # that arctanh(beta0) is finite.
    list_beta0 = np.linspace(0.1, 1 - np.finfo(float).eps, num=10)

    # Choose layout: 'vertical' or 'horizontal'.
    layout = 'horizontal'  # Change this to 'vertical' for stacked panels

    # Size the figure so each panel keeps roughly the same ~7.5 x 3 inch
    # footprint in either layout.
    if layout == 'vertical':
        fig, ax = plt.subplots(2, 1, figsize=(7.5, 6))  # Taller figure for vertical layout
        fig.subplots_adjust(hspace=0.5)
    elif layout == 'horizontal':
        fig, ax = plt.subplots(1, 2, figsize=(15, 3))  # Wider figure for horizontal layout
        # Adjust margins: left, bottom, right, top
        fig.subplots_adjust(left=0.05, bottom=0.2, right=0.95, top=0.9, wspace=0.3)
    else:
        raise ValueError(f"layout must be 'vertical' or 'horizontal', got {layout!r}")

    color, marker = "#4064d9", "v"  # shared style for every EM trajectory
    for ind, beta0 in enumerate(list_beta0):
        alphas, betas = EM_population(alpha0, arctanh(beta0), num_iteration)
        # Label only the first trajectory so the legend shows a single entry.
        label = "EM updates" if ind == 0 else None
        # Consecutive (t, t+1) pairs along the trajectory.
        a_t, a_next = alphas[:-1], alphas[1:]
        b_t, b_next = betas[:-1], betas[1:]
        ax[0].scatter(b_t * b_t, (a_t - a_next) / a_t, color=color, marker=marker, label=label)
        ax[1].scatter(a_t * a_next, (b_t - b_next) / b_t, color=color, marker=marker, label=label)
    save_plot(ax, layout)