import numpy as np
from numpy import tanh, arctanh, exp, pi, inf, log, cosh, sqrt, square
from numpy import array
from scipy.special import k0
from scipy import integrate
from typing import List
import matplotlib.pyplot as plt

bound = [0, inf]  # lower/upper bound of the integration variable x


def func_M(x, theta, v):
    """Return the integrand of M at x.

    The integrand is (tanh(v + theta*x) - tanh(v - theta*x)) * x * K0(|x|) / pi,
    where K0 is scipy.special.k0 (modified Bessel function of the second
    kind, order 0).
    """
    return (tanh(v + theta * x) - tanh(v - theta * x)) * x * k0(abs(x)) / pi


def M(theta, v):
    """Integrate func_M(x, theta, v) over x in [bound[0], bound[1]].

    Uses scipy.integrate.quad and returns only the integral value; the
    quadrature error estimate is discarded.
    """
    return integrate.quad(func_M, bound[0], bound[1], args=(theta, v))[0]


def M0(theta):
    """Return M(theta, v=0).

    With v = 0 the integrand reduces to 2 * tanh(theta*x) * x * K0(x) / pi.
    As theta -> inf, tanh(theta*x) -> 1 for x > 0. Because the integral of
    x * K0(x) over [0, inf) equals 1, this gives M0(inf) = 2/pi.
    """
    return M(theta, v=0)


def save_plot(num_iteration=None, filename="init.png"):
    """Finish the current matplotlib figure, save it, then show it.

    The function:

    - sets the axis limits, x-label, title, legend and dashed grid;
    - writes the figure to ``filename`` at 300 dpi;
    - calls ``plt.show()``.

    Args:
        num_iteration: Right x-limit. When omitted, the module-level
            ``num_iteration`` is used if it exists (the ``__main__`` block
            defines it). Otherwise the right x-limit is left automatic.
        filename: Output path passed to ``plt.savefig``.
    """
    if num_iteration is None:
        # Look the global up lazily. This lets save_plot also work when no
        # module-level num_iteration exists (e.g. when imported), instead of
        # raising NameError.
        num_iteration = globals().get("num_iteration")
    plt.xlim([0, num_iteration])  # a None upper limit leaves it unchanged
    plt.ylim([0, None])
    plt.xlabel("iterations")
    plt.title(r"Initialization of $\alpha^{t}$ with $\pi^0=(\frac{1}{2}, \frac{1}{2}), \alpha^0\to\infty$")
    plt.legend()
    plt.grid(color='gray', linestyle='dashed')
    plt.savefig(filename, dpi=300)
    plt.show()


if __name__ == "__main__":
    # Fixed-point iteration theta^{t+1} = M0(theta^t), started from infinity.
    theta = inf  # initial value |theta^0|/sigma
    list_theta = []
    num_iteration = 36  # also read as a global by save_plot() for the x-limit
    for i in range(num_iteration):
        theta = M0(theta)
        list_theta.append(theta)
        print(f"iteration {i+1}: {round(theta, 4)}")
    plt.plot(range(1, num_iteration + 1), list_theta)

    # Highlight selected iterations, labelled with their approximate values.
    colors = ["#FFCCCC", "#FF6666", "#FF0000"]
    markers = ["v", "^", "o"]
    values = [r"$\approx \frac{2}{\pi}$", "<0.31", "<0.1"]
    for iteration, c, m, note in zip([1, 3, 20], colors, markers, values):
        value = list_theta[iteration - 1]
        plt.scatter(iteration, value,
                    color=c, marker=m,
                    label=rf'$\alpha^{{{iteration}}} \approx {round(value, 3)}$' + note)

    # Shade the area under the curve from fill_start to the last iteration.
    # Both the x range and the y slice derive from num_iteration, so they
    # always have matching lengths.
    fill_start = 9
    plt.fill_between(
        range(fill_start, num_iteration + 1),  # X range
        list_theta[fill_start - 1:],  # Y values
        color="blue",
        alpha=0.2,
        label="Estimated Lower/Upper Bounds",
    )

    # Only show ticks at the highlighted iterations and the shading bounds.
    plt.xticks([1, 3, fill_start, 20, num_iteration])
    # Set yticks for specific values
    plt.yticks(
        [2 / pi, 0.31, 0.1],
        [r"$\frac{2}{\pi}$", r"$0.31$", r"$0.1$"],  # Custom labels
    )
    # save_plot() already calls plt.show(), so no second call is needed.
    save_plot()