import torch
import numpy as np

import matplotlib.pyplot as plt
from pathlib import Path

# Experiment configuration. K, train_samples, batch_size and epochs are
# interpolated into the results directory path of each model; epochs is also
# the right-hand end of the horizontal baseline lines.
K = 1
train_samples = 1000
batch_size = 10
epochs = 100

# Sub-directory names (upper-cased) of the learned models to plot.
models = ['e2e', 'ubg', 'pnp']

# plt.get_cmap replaces plt.cm.get_cmap, which was deprecated in
# Matplotlib 3.7 and removed in 3.9.
cmaps = plt.get_cmap('tab10')

# One panel per row; each row is [p, n, sparsity_degree].
settings = [
    [20, 20, 0.95],
    [50, 50, 0.95],
    [100, 100, 0.95],
]

# Start from a clean slate, then create one row with a panel per setting.
plt.close('all')
fig, ax = plt.subplots(
    1, len(settings), layout="constrained", figsize=(9.02, 2.2), sharex=True)
# plt.subplots returns a bare Axes (not an array) when there is only one
# panel; normalise so that ax[j] indexing works for any number of settings.
ax = np.atleast_1d(ax)

# Draw one panel per setting: test-loss curves of the learned models, GLAD
# NMSE curves, and the constant NMSE levels of GLasso, Ledoit-Wolf and OAS,
# each with a +/- one standard deviation band.
lr = 1e-2

# Line colour of each learned model.
model_colors = {'ubg': 'darkblue', 'pnp': 'steelblue', 'e2e': 'royalblue'}

# Value substituted for `L` in the GLAD result filenames.
L = 1

# The classical baselines are constants drawn with hlines from 0 to `epochs`;
# their bands use the same two end points so that line and band always cover
# the same range, independently of which model curves were found.
baseline_x = [0.0, epochs]

for j, setting in enumerate(settings):
    p, n, sparsity_degree = setting[:3]
    panel = ax[j]

    for model in models:
        path = Path(
            f"./{model.upper()}/tests_UBG"
            f"/train_samples_{train_samples}/test_samples_{100}"
            f"/p_{p}/n_{n}"
            f"/train_batch_size_{batch_size}/test_batch_size_{100}"
            f"/precision_sparsity_{sparsity_degree}/dataloader_shuffle_{1}"
            f"/K_{K}/training_seed_{101}/epochs_{epochs}/lr_{lr}"
            f"/scheduler_patience_{10000}/scheduler_factor_{0.8}"
            f"/scheduler_min_lr_{0.0001}/loss_discount_factor_{1}/"
        )
        try:
            avg_train_losses = torch.load(path / 'avg_train_losses.pt')
        except FileNotFoundError:
            # No saved run for this model/setting: leave it off the panel.
            continue

        # Only the test-loss mean/std are plotted below; the remaining
        # series are loaded but not used in this script.
        std_train_losses = torch.load(path / 'std_train_losses.pt')
        avg_test_losses = torch.load(path / 'avg_test_losses.pt')
        # Converted like the other series so the band arithmetic below is
        # plain ndarray math.
        std_test_losses = np.array(torch.load(path / 'std_test_losses.pt'))
        avg_logdet_losses = np.array(
            torch.load(path / 'avg_logdet_losses.pt'))
        std_logdet_losses = np.array(
            torch.load(path / 'std_logdet_losses.pt'))
        avg_train_f1_scores = np.array(
            torch.load(path / 'avg_train_f1_scores.pt'))
        avg_test_f1_scores = np.array(
            torch.load(path / 'avg_test_f1_scores.pt'))
        std_train_f1_scores = np.array(
            torch.load(path / 'std_train_f1_scores.pt'))
        std_test_f1_scores = np.array(
            torch.load(path / 'std_test_f1_scores.pt'))
        avg_train_losses = np.array(avg_train_losses)
        avg_test_losses = np.array(avg_test_losses)

        color = model_colors[model]
        panel.semilogy(avg_test_losses, linestyle='solid', color=color, lw=2,
                       label=f"$\\bf{{{model.upper()} (ours)}}$")
        panel.fill_between(x=np.arange(len(avg_test_losses)),
                           y1=avg_test_losses + std_test_losses,
                           y2=avg_test_losses - std_test_losses,
                           color=color, alpha=0.2)

    # Load GLAD's results
    glad_dir = './GLAD/glad_results'
    glad_suffix = f'D{p}_n{n}_sparsity{sparsity_degree}_L{L}_return_Z.pt'
    try:
        glad_Theta_avg_nmses = np.array(torch.load(
            f'{glad_dir}/test_avg_Theta_NMSE_{glad_suffix}'))
        glad_Theta_std_nmses = np.array(torch.load(
            f'{glad_dir}/test_std_Theta_NMSE_{glad_suffix}'))

        glad_Z_avg_nmses = np.array(torch.load(
            f'{glad_dir}/test_avg_Z_NMSE_{glad_suffix}'))
        glad_Z_std_nmses = np.array(torch.load(
            f'{glad_dir}/test_std_Z_NMSE_{glad_suffix}'))
    except FileNotFoundError:
        # GLAD results are optional. Skip only the GLAD curves; the
        # classical baselines, axis labels, title and grid below must still
        # be drawn for this panel.
        pass
    else:
        glad_curves = [
            (glad_Z_avg_nmses, glad_Z_std_nmses,
             "darkslategray", "GLAD-Z (non-SPD)"),
            (glad_Theta_avg_nmses, glad_Theta_std_nmses,
             "lightseagreen", "GLAD-$\\Theta$ (non-sparse)"),
        ]
        for glad_avg, glad_std, glad_color, glad_label in glad_curves:
            panel.semilogy(glad_avg, linestyle='dotted', color=glad_color,
                           lw=2, label=glad_label)
            panel.fill_between(x=np.arange(len(glad_avg)),
                               y1=glad_avg + glad_std,
                               y2=glad_avg - glad_std,
                               color=glad_color, alpha=0.2)

    # Load traditional solvers results (only the NMSE arrays are plotted;
    # the F1 arrays are loaded but not used in this script).
    prefix = (f'./optimal_glasso_results/'
              f'{sparsity_degree}_p_{p}_n_{n}_testsamples{100}')
    glasso_nmses_by_cv = np.load(f'{prefix}_NMSEs_by_cv.npy')
    glasso_f1s_by_cv = np.load(f'{prefix}_F1s_by_cv.npy')

    nmses_by_ledoitwolf = np.load(f'{prefix}_NMSEs_by_ledoitwolf.npy')
    f1s_by_ledoitwolf = np.load(f'{prefix}_F1s_by_ledoitwolf.npy')

    nmses_by_oas = np.load(f'{prefix}_NMSEs_by_oas.npy')
    f1s_by_oas = np.load(f'{prefix}_F1s_by_oas.npy')

    baselines = [
        (glasso_nmses_by_cv, 'tab:red', "solid", "GLasso"),
        (nmses_by_ledoitwolf, 'darkred', "dotted", "Ledoit-Wolf (non-sparse)"),
        (nmses_by_oas, 'tomato', "dotted", "OAS (non-sparse)"),
    ]
    for nmses, base_color, base_style, base_label in baselines:
        center = nmses.mean()
        spread = nmses.std()
        panel.hlines(y=center, xmin=0.0, xmax=epochs, color=base_color,
                     linestyle=base_style, lw=2, label=base_label)
        panel.fill_between(x=baseline_x,
                           y1=center + spread,
                           y2=center - spread,
                           color=base_color, alpha=0.2)

    # Only the left-most panel carries the y-axis label.
    if j == 0:
        panel.set_ylabel("NMSE", fontsize=14)
    panel.set_xlabel("Epochs", fontsize=14)

    if sparsity_degree == 0.7:
        panel.set_title(
            f"$p={p}, n={n}$\nWeakly sparse", fontsize='14')
    elif sparsity_degree == 0.95:
        panel.set_title(
            f"$p={p}, n={n}$", fontsize='14')

    panel.grid(which='both', alpha=0.9)


# Attach a single legend to the right of the last panel, built from that
# panel's labelled artists.
legend_ax = ax[-1]
handles, labels = legend_ax.get_legend_handles_labels()
# Hand-picked permutation for a fully populated panel (three models, two
# GLAD curves, three classical baselines = eight labelled artists).
legend_order = [6, 7, 3, 4, 5, 1, 2, 0]
if sorted(legend_order) != list(range(len(handles))):
    # Some curves are missing (or extra ones were added), so the permutation
    # no longer matches; fall back to plotting order instead of raising
    # IndexError.
    legend_order = list(range(len(handles)))
legend_ax.legend([handles[k] for k in legend_order],
                 [labels[k] for k in legend_order],
                 bbox_to_anchor=(1.04, 1), loc="upper left")
plt.show(block=False)
