import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

# Matrices read from tensor/: B0/C0 are the "(0)" snapshot and B/C the "(t)"
# snapshot referenced in the plot titles. Assumed 2-D — TODO confirm shapes.
B0, C0 = np.load('tensor/B0.npy'), np.load('tensor/C0.npy')
B, C = np.load('tensor/B.npy'), np.load('tensor/C.npy')

# Mean / std curves plotted per recurrent step in the cosine-similarity figure.
mean, std = np.load('sim_mean.npy'), np.load('sim_std.npy')
# Mean / std loss values plotted against token length in the loss figures.
means, stds = np.load('tensor/means.npy'), np.load('tensor/stds.npy')

# C^T B at both snapshots; the `@` operator dispatches to np.matmul.
CB0 = C0.T @ B0
CB = C.T @ B

# Base size (inches) from which every figure's figsize is scaled.
p = 3

# Global text rendering configuration, applied in a single update.
# All text goes through an external LaTeX install. The preamble loads
# amsmath and bm; bm provides the \bm macro used in the titles and labels.
# NOTE: while text.usetex is on, the mathtext.* entries are not used for
# rendering. They only take effect if usetex is switched off.
mpl.rcParams.update({
    "text.usetex": True,
    "text.latex.preamble": r"""
        \usepackage{amsmath}
        \usepackage{bm}
    """,
    "mathtext.fontset": "custom",
    "mathtext.it": "Times New Roman:italic",
    # Bold math must map to the bold face. It was previously mapped to the
    # italic face, which made bold symbols render italic.
    "mathtext.bf": "Times New Roman:bold",
})

# Figures 1-2: titled heatmaps of C^T B at step 0 and at time t, saved at 1200 dpi.
for _matrix, _title, _out_path in (
    (CB0, r'$\bm{C}^{\top}(0)\bm{W}_B(0)$', 'fig/exp1.png'),
    (CB, r'$\bm{C}^{\top}(t)\bm{W}_B(t)$', 'fig/exp2.png'),
):
    plt.figure(figsize=(1.1*p, 1.2*p))
    plt.imshow(_matrix, cmap='viridis', interpolation='nearest')
    plt.title(_title, fontsize=16)
    plt.axis('off')
    plt.savefig(_out_path, dpi=1200)

# Figure 3: cosine similarity between the hidden state and w at each
# recurrent step, drawn as mean with a +/- one std band.
plt.figure(figsize=(1.3*p, 1.2*p))
# Use one integer x-value per recorded step. The previous
# np.linspace(0, N, N) spaced its N points N/(N-1) apart, so they did not
# fall on integer steps. Deriving N from the data also keeps x and mean
# the same length. This assumes `mean` is 1-D and that steps are 0-indexed
# (TODO confirm against the simulation that wrote sim_mean.npy).
N = len(mean)
x = np.arange(N)
plt.plot(x, mean)
plt.fill_between(x, mean-std, mean+std, color='blue', alpha=0.2)
plt.xlabel(r'recurrent step $l$', fontsize=16)
# \cos sets the function name upright. A bare "cos" would typeset as the
# italic product c*o*s.
plt.ylabel(r'$\cos(\tilde{\bm{h}}_l, \bm{w})$', fontsize=16)
plt.savefig('fig/exp3.png', dpi=1200, bbox_inches='tight')

# Figure 4: experimental loss vs. the reference curve 3 d (d + 1) / (2 N) on
# log-log axes. Assumes means[i] is the loss measured at token length Ns[i].
d = 4  # d in the reference formula; the theoretical-loss code below also reads it
Ns = np.arange(4, 82, 2)  # token lengths N = 4, 6, ..., 80
# Derive the reference curve from d instead of hard-coding 30 / N. That
# constant equals 3 d (d + 1) / (2 N) only for d = 4, and would silently
# disagree with the legend label if d changed.
y_ = 3 * d * (d + 1) / (2 * Ns)

plt.figure(figsize=(1.3*p, 1.2*p))
plt.xscale("log")
plt.yscale("log")
# Fixed tick positions, labelled as plain numbers.
xticks = [4, 10, 80]
yticks = [0.2, 1, 5]

plt.xticks(xticks, labels=[f"{x}" for x in xticks])
plt.yticks(yticks, labels=[f"{y}" for y in yticks])
plt.plot(Ns, means, label='experimental loss')
plt.plot(Ns, y_, label='y = 3 d(d+1) / (2 N)')
plt.legend()
plt.xlabel(r'token length $N$', fontsize=16)
plt.ylabel('loss', fontsize=16)
plt.savefig('fig/exp4.png', dpi=1200, bbox_inches='tight')

def _theoretical_loss(n_tokens, dim):
    """Return the closed-form theoretical loss for one token length.

    The decay factor ``alpha`` is chosen so that ``alpha ** n_tokens == 1/2``.
    """
    alpha = np.exp(- np.log(2) / n_tokens)
    a_sq = alpha ** 2
    a_n = alpha ** n_tokens
    beta1 = a_sq * (1 - a_n) ** 2 + (dim + 1) * a_sq * (1 - alpha) * (1 - alpha ** (2 * n_tokens)) / (1 + alpha)
    beta3 = alpha * (1 - a_n)
    return dim / 2 * (1 - (beta3 ** 2 / beta1))


# Figure 5: experimental loss (means with a +/- stds band) against the
# theoretical curve evaluated at every token length in Ns.
y = [_theoretical_loss(n_tokens, d) for n_tokens in Ns]

plt.figure(figsize=(1.3*p, 1.2*p))
plt.plot(Ns, means, label='experimental loss')
plt.plot(Ns, y, label='theoretical loss', linestyle='--')
plt.fill_between(Ns, means-stds, means+stds, color='blue', alpha=0.2)
plt.legend()
plt.xlabel(r'token length $N$', fontsize=16)
plt.ylabel('loss', fontsize=16)
plt.savefig('fig/exp5.png', dpi=1200, bbox_inches='tight')


# Untitled, axis-free renderings of both matrices. aspect='auto' stretches
# each image to fill a 6x6-inch canvas; files are saved at the default dpi.
for _matrix, _out_path in ((CB0, 'fig/CB0.png'), (CB, 'fig/CB.png')):
    plt.figure(figsize=(6, 6))
    plt.axis('off')
    plt.imshow(_matrix, cmap='viridis', interpolation='nearest', aspect='auto')
    plt.savefig(_out_path)
