import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.covariance import GraphicalLassoCV, LedoitWolf, OAS
from sklearn.metrics import f1_score

sparsity_degrees = [0.95]
ps = [100]
ns = [100]

test_samples = 100

RESULTS_DIR = './optimal_glasso_results'

# Method suffixes used in the saved file names, in the order the files are
# written. 'mincovdet' is never evaluated (no MinCovDet fit happens below), so
# its result arrays stay empty; they are still saved so the set of output
# files is unchanged.
METHODS = ('cv', 'S_inv', 'ledoitwolf', 'mincovdet', 'oas')

# Human-readable labels for the per-sample progress printout.
METHOD_LABELS = {
    'cv': 'GLasso',
    'S_inv': 'S inv',
    'ledoitwolf': 'Ledoit-Wolf',
    'oas': 'OAS',
}


def as_numpy(x):
    """Return ``x`` as a NumPy array.

    Torch tensors are detached and moved to the CPU first, so tensors that
    require grad or live on an accelerator are accepted as well.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def nmse_of(estimate, target):
    """Return ||estimate - target||_F^2 / ||target||_F^2 as a Python float.

    Both arguments may be NumPy arrays or torch tensors; the arithmetic is
    done in float64. Plain sums of squares are used instead of a 2-D matrix
    norm so that a leading singleton batch dimension on either argument
    broadcasts instead of failing.
    """
    estimate = as_numpy(estimate).astype(np.float64)
    target = as_numpy(target).astype(np.float64)
    return float(np.sum((estimate - target) ** 2) / np.sum(target ** 2))


def support_of(matrix):
    """Return a new float 0/1 array marking the non-zero entries of ``matrix``.

    The input is never modified.
    """
    return (as_numpy(matrix) != 0.).astype(np.float64)


def evaluate(precision, theta_true, theta_true_support):
    """Score an estimated precision matrix against the ground truth.

    Args:
        precision: estimated precision matrix.
        theta_true: true precision matrix.
        theta_true_support: ``support_of(theta_true)``, precomputed once per
            sample and shared across methods.

    Returns:
        ``(nmse, f1)`` where ``f1`` is the F1 score of the flattened estimated
        support against the flattened true support.
    """
    nmse = nmse_of(precision, theta_true)
    f1 = f1_score(theta_true_support.ravel(), support_of(precision).ravel())
    return nmse, f1


for sparsity_degree in sparsity_degrees:
    for p in ps:
        for n in ns:
            data_path = f'./data/test/p_{p}_n_{n}_density_{sparsity_degree}_size_{10_000}_random_state_{1}.pt'
            try:
                # NOTE(review): torch.load unpickles the file -- only use it on trusted data.
                test_matrices = torch.load(data_path)
            except FileNotFoundError:
                print(f'Skipping sparsity={sparsity_degree}, p={p}, n={n}: {data_path} not found')
                continue
            test_set = test_matrices[:test_samples]

            nmses = {method: [] for method in METHODS}
            f1s = {method: [] for method in METHODS}

            # Each sample unpacks into five items; only S, Theta_true and X are used here.
            for test_matrix_id, (S, Theta_true, _sigma_true, X, _alpha) in enumerate(test_set):
                print(test_matrix_id)

                X_np = as_numpy(X)
                # S is squeezed to drop any singleton dimensions before pinv.
                cov = as_numpy(S).squeeze()
                theta_true_support = support_of(Theta_true)

                # NOTE(review): pinv / Ledoit-Wolf / OAS normally yield dense
                # matrices, so their support F1 is close to that of an all-ones
                # prediction; only the GLasso F1 is really informative.
                precisions = {
                    'cv': GraphicalLassoCV().fit(X_np).precision_,
                    'S_inv': np.linalg.pinv(cov),
                    'ledoitwolf': LedoitWolf().fit(X_np).precision_,
                    'oas': OAS().fit(X_np).precision_,
                }
                for method, precision in precisions.items():
                    nmse, f1 = evaluate(precision, Theta_true, theta_true_support)
                    nmses[method].append(nmse)
                    f1s[method].append(f1)
                    print(f"{METHOD_LABELS[method]}: {nmse}")

            # Convert everything to numpy arrays
            nmses = {method: np.array(values) for method, values in nmses.items()}
            f1s = {method: np.array(values) for method, values in f1s.items()}

            mean_nmse_cv = np.round(nmses['cv'].mean(), 3)
            mean_nmse_ledoitwolf = np.round(nmses['ledoitwolf'].mean(), 3)

            plt.close('all')
            fig, axs = plt.subplots(1, 2, sharey=True, tight_layout=True)
            plt.suptitle(
                f'sparsity = {sparsity_degree}, p = {p}, n = {n}\nOptimal by NMSE CV: {mean_nmse_cv}\nOptimal by NMSE LedoitWolf: {mean_nmse_ledoitwolf}')

            # NOTE(review): no data is drawn on these axes -- the lambda
            # histograms their titles describe are never plotted here.
            axs[0].set_title(r'$\lambda^{(i)}$ distribution')
            axs[1].set_title(
                r'$\frac{\lambda^{(i)}}{\lambda^{(i)_\mathrm{max}}}$ distribution')
            plt.semilogy()

            os.makedirs(RESULTS_DIR, exist_ok=True)
            prefix = f'{RESULTS_DIR}/{sparsity_degree}_p_{p}_n_{n}_testsamples{test_samples}'

            # Save before showing: some backends (e.g. Jupyter inline) close
            # the figure on show(), which would leave savefig a blank canvas.
            plt.savefig(f'{prefix}_c.pdf')
            plt.show(block=False)

            for method in METHODS:
                np.save(f'{prefix}_NMSEs_by_{method}.npy', nmses[method])
                np.save(f'{prefix}_F1s_by_{method}.npy', f1s[method])