import pickle
import matplotlib.pyplot as plt
import seaborn as sns
# Global plot styling: every text element uses the same font size and
# matplotlib's "text.usetex" option is switched on.
params = {
    key: 14
    for key in (
        "axes.labelsize",
        "font.size",
        "legend.fontsize",
        "xtick.labelsize",
        "ytick.labelsize",
    )
}
params["text.usetex"] = True
plt.rcParams.update(params)

# Seaborn palette and style used for every figure produced by this script.
sns.set_palette("colorblind")
sns.set_style("ticks")

# Experiment settings; they are embedded in the result file names below.
obj = "lstsq"
max_iter = 100_000
n_samples = 10
batch_size = 1

# One result file per (objective-name variant, file suffix) pair, in the
# order the loading code expects them.
file_names = [
    f"results/{obj}{variant}_T-{max_iter}_n-{n_samples}_bs-{batch_size}-{suffix}.pickle"
    for variant, suffix in (
        ("", "cst"),
        ("", "dec"),
        ("_interpol", "cst"),
        ("", "cst-img"),
    )
]

# Read every result file and collect its "iterates" and "derivatives" entries
# (interleaved, in file order) plus its "legends" entry.
# NOTE: pickle.load can execute arbitrary code while deserialising; only point
# this script at result files you produced yourself.
data, legends = [], []
for file_name in file_names:
    with open(file_name, "rb") as f:
        file_data = pickle.load(f)
    data += [file_data["iterates"], file_data["derivatives"]]
    legends += [file_data["legends"]]

# `data` holds one (iterates, derivatives) pair per file, flattened.
fs_cst, dys_cst, fs_dec, dys_dec = data[:4]
fs_interpol_cst, dys_interpol_cst, fs_cst_img, dys_cst_img = data[4:]

# Legends stored with the first two files; no active code in the visible part
# of this script reads them.
legends_cst, legends_dec = legends[:2]

# Figure layout: one column per experiment. The top row plots the fs_* series
# and the bottom row the dys_* series, both on a logarithmic y-axis.
fig, axs = plt.subplots(2, 4, figsize=(12, 4))

# (column title, top-row series, bottom-row series, extra semilogy format args)
panels = [
    ("Constant stepsize", fs_cst, dys_cst, ()),
    ("Decreasing stepsize", fs_dec, dys_dec, ("--",)),
    ("Double interpolation", fs_interpol_cst, dys_interpol_cst, ()),
    ("Simple interpolation", fs_cst_img, dys_cst_img, ()),
]
for col, (title, fs, dys, fmt) in enumerate(panels):
    axs[0, col].semilogy(fs, *fmt)
    axs[0, col].set_title(title)
    axs[1, col].semilogy(dys, *fmt)
    axs[1, col].set_xlabel("$k$")

# Only the leftmost column carries y-axis labels. Raw strings keep the LaTeX
# backslashes literal instead of relying on invalid escape sequences, which
# raise SyntaxWarning on recent Python versions.
axs[0, 0].set_ylabel(r"$f(x) - f(x^\star)$")
axs[1, 0].set_ylabel(r"$\|\partial_\theta f(x) - \partial_\theta f(x^\star)\|$")

# Shared legend placed below the panels. Only labels are passed, so matplotlib
# pairs them, in order, with the first line handles found on the figure's axes.
# NOTE(review): this presumes the first two panels hold 3 and 4 curves
# respectively -- confirm against the saved data.
# fig.legend already registers the legend on the figure, so no additional
# fig.add_artist call is needed (it would draw the legend a second time).
leg1 = fig.legend(
    [
        r"$\eta_0$",
        r"$\eta_0/2$",
        r"$\eta_0/10$",
        r"$\eta_0 k^{-1}$",
        r"$\eta_0 k^{-1/2}$",
        r"$\eta_0 k^{-1/4}$",
        r"$\eta_0/\log(k)$",
    ],
    bbox_to_anchor=(0.5, -0.07),
    loc="lower center",
    ncol=7,
)
fig.tight_layout()
plt.savefig(
    f"figures/{obj}_T-{max_iter}_n-{n_samples}_bs-{batch_size}.pdf", bbox_inches="tight"
)
plt.show()
