import torch

# Precomputed results produced outside this script. Each X_* object supplies the
# x-coordinates and the matching f*_* object the values plotted against them in
# the right-hand panels below (e.g. `ax1.plot(X_mlcf_f0, fMLCF_f0, ...)`).
# map_location='cpu' lets the files load on CPU-only machines and hands
# matplotlib CPU data even if the tensors were saved from a GPU session.
# NOTE(review): torch.load unpickles its input — only use on trusted local files.
fMLCF_f0 = torch.load('./fMLCF_f0.pt', map_location='cpu')
X_mlcf_f0 = torch.load('./X_mlcf_f0.pt', map_location='cpu')
fMLCF_f1 = torch.load('./fMLCF_f1.pt', map_location='cpu')
X_mlcf_f1 = torch.load('./X_mlcf_f1.pt', map_location='cpu')
fMLCF_f2 = torch.load('./fMLCF_f2.pt', map_location='cpu')
X_mlcf_f2 = torch.load('./X_mlcf_f2.pt', map_location='cpu')
fCF_f2 = torch.load('./fCF_f2.pt', map_location='cpu')
X_CF_f2 = torch.load('./X_CF_f2.pt', map_location='cpu')

import matplotlib.pyplot as plt
from illustrative.fl import *

# Explicit import: do not rely on `np` leaking out of the star import above.
import numpy as np


def f_true(t):
    """Return the reference function f(t) = exp(-t), evaluated elementwise."""
    return np.exp(-t)


# Dense grid on [0, 1] used to draw the smooth curves.
t_f = np.linspace(0, 1, 1000)
y_f = f_true(t_f)
# Finer uniform grid on [0, 1]. The mean of an integrand over it is drawn below
# as the reference value of that integrand's integral on [0, 1].
tgrid = np.linspace(0, 1, 20000)

# Overall layout: a 2x3 grid whose left column is one tall panel (ax_big) and
# whose right 2x2 block is filled further down.
fig = plt.figure(figsize=(14, 5))
plt.rcParams['font.size'] = '12'

ax_big = plt.subplot2grid((2, 3), (0, 0), rowspan=2, colspan=1)

# The reference function, drawn thick and black. Every sample except the first
# is shifted up by 0.01 — presumably a visual nudge so it stays distinguishable
# from the curves drawn over it; confirm intent.
y_f_mod = np.concatenate(([y_f[0]], y_f[1:] + 0.01))
ax_big.plot(t_f, y_f_mod, label="f", linewidth=5, color='black')

# For each level: its curve on the dense grid followed by its sample points,
# in the order f_2, f_1, f_0 so the legend reads f, f_2, f_1, f_0.
_levels = (
    (f2, t2, y2, "$f_2$", "tab:orange"),
    (f1, t1, y1, "$f_1$", "tab:green"),
    (f0, t0, y0, "$f_0$", "tab:blue"),
)
for _approx, _t_pts, _y_pts, _label, _colour in _levels:
    ax_big.plot(t_f, _approx(t_f), label=_label, linewidth=2, color=_colour)
    ax_big.plot(_t_pts, _y_pts, 'o', markersize=4, color=_colour)

ax_big.set_xlabel("x")
ax_big.legend()

# Right-hand 2x2 block of the 2x3 grid.
#   ax1: f_0, ax2: f_1 - f_0, ax3: f_2 - f_1 — each comparing the function on the
#        dense grid (labelled MLMC, blue) with the loaded MLCF result (green).
#   ax4: the loaded single-level CF result for f_2 (purple).
# In every panel the dashed red line is the mean of the plotted integrand over
# `tgrid`, i.e. a fine-grid approximation of its integral Pi[.] on [0, 1].
ax1 = plt.subplot2grid((2, 3), (0, 1))
ax2 = plt.subplot2grid((2, 3), (0, 2))
ax3 = plt.subplot2grid((2, 3), (1, 1))
ax4 = plt.subplot2grid((2, 3), (1, 2))

ax1.plot(t_f, f0(t_f), label="$f_0$ (MLMC)", linewidth=2, color='tab:blue')
ax1.axhline(y=np.mean(f0(tgrid)), linestyle='--', color='tab:red', linewidth=2,
            label=r'$\Pi[f_0]$')
ax1.plot(X_mlcf_f0, fMLCF_f0, label="$f_0$ (MLCF)", linewidth=2, color='tab:green')
ax1.legend()

ax2.plot(t_f, f1(t_f) - f0(t_f), label="$f_1-f_0$ (MLMC)", linewidth=2, color='tab:blue')
ax2.axhline(y=np.mean(f1(tgrid) - f0(tgrid)), linestyle='--', color='tab:red',
            linewidth=2, label=r'$\Pi[f_1-f_0]$')
ax2.plot(X_mlcf_f1, fMLCF_f1, label="$f_1-f_0$ (MLCF)", linewidth=2, color='tab:green')
ax2.legend()

ax3.plot(t_f, f2(t_f) - f1(t_f), label="$f_2-f_1$ (MLMC)", linewidth=2, color='tab:blue')
ax3.axhline(y=np.mean(f2(tgrid) - f1(tgrid)), linestyle='--', color='tab:red',
            linewidth=2, label=r'$\Pi[f_2-f_1]$')
ax3.plot(X_mlcf_f2, fMLCF_f2, label="$f_2-f_1$ (MLCF)", linewidth=2, color='tab:green')
ax3.legend()
ax3.set_xlabel("x")

# Mathtext label, consistent with "$f_2$" used in every other legend entry.
ax4.plot(X_CF_f2, fCF_f2, label="$f_2$ (CF)", linewidth=2, color='tab:purple')
ax4.axhline(y=np.mean(f2(tgrid)), linestyle='--', color='tab:red', linewidth=2,
            label=r'$\Pi[f_2]$')
ax4.legend()
ax4.set_xlabel("x")


# Tag every panel with a bold letter placed just outside its top-left corner.
# The coordinates are axes fractions, because the text uses the panel's
# transAxes transform.
_panel_tags = zip(
    (ax_big, ax1, ax2, ax3, ax4),
    ("(a)", "(b)", "(c)", "(d)", "(e)"),
)
for _panel, _tag in _panel_tags:
    _panel.text(
        -0.15, 1.05, _tag,
        transform=_panel.transAxes,
        fontsize=16, fontweight="bold",
        va="top", ha="right",
    )


# Tidy spacing, write the figure to disk, then display it interactively.
# For a vector format such as PDF, dpi only affects any rasterised content.
fig.tight_layout()
_output_path = "illustration.pdf"
plt.savefig(_output_path, dpi=600)
plt.show()

