""" Training GLAD on our data. """
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score
from torch.optim import Adam
from torch.utils.data import DataLoader

from spodnet.GLAD.notebooks.glad import glad
from spodnet.GLAD.notebooks.glad_model import glad_model

# Seed torch's global RNG once at import time so the run is repeatable.
torch.manual_seed(0)

# The data files are the ones used for the SpodNet experiments; this value
# selects which file is loaded and is also embedded in the result file names.
precision_sparsity = 0.95

# One [D, n] pair per experiment. Both numbers pick the data file
# (`p_{D}_n_{n}_...`) and D is additionally handed to `glad`.
settings = [[100, n_value] for n_value in (10, 20, 100, 200, 500)]

# Values of L to sweep; each one is forwarded to `glad_model` and `glad`.
Ls = [1]

# Number of passes over the training loader for every (setting, L) pair.
epochs = 100

# Directory receiving every result file written by this script.
RESULTS_DIR = "./glad_results"

# Metrics for which the per-epoch mean and standard deviation over the test
# set are stored, under the names `test_avg_<metric>` / `test_std_<metric>`.
MEAN_STD_METRICS = (
    "Z_NMSE",
    "Theta_NMSE",
    "Theta_sparsified_NMSE",
    "Theta_F1_losses",
    "Theta_sparsified_F1_losses",
    "Theta_sparsities",
    "Theta_sparsified_sparsities",
)


def _batched_nmse(estimate, target):
    """Return, per matrix of the batch, ||estimate - target||_F^2 / ||target||_F^2."""
    squared_error = torch.linalg.matrix_norm(estimate - target, ord='fro')**2
    return squared_error / torch.linalg.matrix_norm(target, ord='fro')**2


def _batched_sparsity(matrices, n_entries):
    """Return, per matrix of the batch, 1 - (number of non-zero entries) / n_entries."""
    nonzero_counts = torch.count_nonzero(matrices.flatten(1), dim=1)
    return 1 - nonzero_counts / n_entries


def _support_indicators(matrices):
    """Return a (batch, entries) array holding 1. where an entry is non-zero, else 0."""
    return (matrices != 0.).to(matrices.dtype).flatten(1).numpy()


def _per_sample_f1(true_support, estimated_support):
    """Return one F1 score per row, comparing estimated and true supports."""
    return torch.tensor([
        f1_score(true_row, estimated_row)
        for true_row, estimated_row in zip(true_support, estimated_support)
    ])


# For every (D, n) setting and every L: train a GLAD model, evaluate it on the
# test set after each epoch, and finally save the per-epoch metric histories.
#
# NOTE(review): the DataLoaders use num_workers=4 and this script has no
# `if __name__ == "__main__":` guard. PyTorch documents that such a guard is
# required when worker processes are started with `spawn` -- confirm the
# platforms this script is meant to run on.
for setting in settings:
    D, n = setting
    for L in Ls:
        # Re-seed so that every configuration starts from the same RNG state.
        torch.manual_seed(0)

        lr_glad = 1e-2  # Out of the box

        INIT_DIAG = 0.1
        lambda_init = 1  # OUT OF THE BOX

        p = D

        USE_CUDA = False

        train_samples = 1000
        test_samples = 100
        train_batch_size = 10
        test_batch_size = test_samples

        # Create the output directory up front: otherwise a missing directory
        # is only discovered by torch.save after all epochs have been trained.
        os.makedirs(RESULTS_DIR, exist_ok=True)

        train_matrices = torch.load(
            f'../data/train/p_{p}_n_{n}_density_{precision_sparsity}_size_{10_000}_random_state_{0}.pt')
        train_set = train_matrices[:train_samples]

        test_matrices = torch.load(
            f'../data/test/p_{p}_n_{n}_density_{precision_sparsity}_size_{10_000}_random_state_{1}.pt')
        test_set = test_matrices[:test_samples]

        train_loader = DataLoader(
            dataset=train_set,
            batch_size=train_batch_size,
            shuffle=True,
            num_workers=4
        )
        test_loader = DataLoader(
            dataset=test_set,
            batch_size=test_batch_size,
            shuffle=False,
            num_workers=4
        )

        # Initialize the model
        model_glad = glad_model(L=L,
                                theta_init_offset=0.1,
                                nF=3,
                                H=3,
                                USE_CUDA=USE_CUDA)

        optimizer_glad = Adam(model_glad.parameters(), lr=lr_glad)

        criterion_graph = nn.MSELoss()

        # Per-epoch histories, keyed by the stem of the file they are saved to.
        history = {
            f"test_{stat}_{metric}": []
            for metric in MEAN_STD_METRICS
            for stat in ("avg", "std")
        }
        history.update({
            "test_Theta_min_eigvals": [],
            "test_Z_min_eigvals": [],
            "test_Theta_sparsified_min_eigvals": [],
            "non_spd_percentage": [],
        })

        for epoch in range(epochs):
            print(f"Epoch {epoch}")

            # ---- Training ----
            model_glad.train()
            # Each batch is a 5-tuple; only S and Theta_true are used here.
            for (S, Theta_true, Sigma_true, X, alpha) in train_loader:
                train_rho_vals_list = []

                optimizer_glad.zero_grad()

                S = S.type(torch.FloatTensor)
                Theta_true = Theta_true.type(torch.FloatTensor)

                # `glad` returns six values; only the loss is needed to train.
                _, _, glad_loss, _, _, _ = glad(S, Theta_true, model_glad, [
                    D, INIT_DIAG, lambda_init, L], criterion_graph, rho_vals_list=train_rho_vals_list)

                glad_loss.backward()
                optimizer_glad.step()

            # ---- Testing the model ----
            model_glad.eval()

            # Per-sample metric values, one tensor per test batch.
            per_sample = {metric: [] for metric in MEAN_STD_METRICS}
            theta_min_eigvals = []
            z_min_eigvals = []
            sparsified_min_eigvals = []
            non_spd_count = 0

            with torch.no_grad():
                for (S, Theta_true, Sigma_true, X, alpha) in test_loader:
                    test_rho_vals_list = []

                    S = S.type(torch.FloatTensor)
                    Theta_true = Theta_true.type(torch.FloatTensor)

                    Theta, Z, _, _, _, _ = glad(S, Theta_true, model_glad, [
                        D, INIT_DIAG, lambda_init, L], criterion_graph, rho_vals_list=test_rho_vals_list)

                    # Theta with the zero pattern of Z imposed on it.
                    Theta_sparsified = Theta.clone()
                    Theta_sparsified[Z == 0.] = 0.

                    # Sparsities
                    per_sample["Theta_sparsities"].append(
                        _batched_sparsity(Theta, p**2))
                    per_sample["Theta_sparsified_sparsities"].append(
                        _batched_sparsity(Theta_sparsified, p**2))

                    # F1 scores between estimated and true supports
                    true_support = _support_indicators(Theta_true)
                    per_sample["Theta_F1_losses"].append(
                        _per_sample_f1(true_support, _support_indicators(Theta)))
                    per_sample["Theta_sparsified_F1_losses"].append(
                        _per_sample_f1(true_support, _support_indicators(Theta_sparsified)))

                    # NMSEs
                    # We are training GLAD on Z, so first taking Z against Theta_true,
                    # then Theta, then Theta_sparsified.
                    per_sample["Z_NMSE"].append(_batched_nmse(Z, Theta_true))
                    per_sample["Theta_NMSE"].append(
                        _batched_nmse(Theta, Theta_true))
                    per_sample["Theta_sparsified_NMSE"].append(
                        _batched_nmse(Theta_sparsified, Theta_true))

                    # Smallest eigenvalues, accumulated per batch so that all
                    # test matrices are covered, not only the last batch.
                    theta_min_eigvals.append(
                        torch.linalg.eigvalsh(Theta).min(dim=1)[0])
                    z_min_eigvals.append(
                        torch.linalg.eigvalsh(Z).min(dim=1)[0])
                    sparsified_eigvals = torch.linalg.eigvalsh(Theta_sparsified)
                    sparsified_min_eigvals.append(
                        sparsified_eigvals.min(dim=1)[0])

                    # Matrices with at least one eigenvalue <= 0 after sparsification.
                    non_spd_count += int(
                        (sparsified_eigvals <= 0.).any(dim=1).sum())

            # Logging the metrics after passing all test matrices
            epoch_values = {
                metric: torch.cat(chunks) for metric, chunks in per_sample.items()
            }

            print(
                f"Test avg Z NMSE : {torch.mean(epoch_values['Z_NMSE'])}")
            print(
                f"Test avg Theta NMSE : {torch.mean(epoch_values['Theta_NMSE'])}")
            print(
                f"Test avg Theta_sparsified NMSE : {torch.mean(epoch_values['Theta_sparsified_NMSE'])}")

            for metric, values in epoch_values.items():
                history[f"test_avg_{metric}"].append(torch.mean(values).item())
                history[f"test_std_{metric}"].append(torch.std(values).item())

            history["test_Theta_min_eigvals"].append(
                torch.cat(theta_min_eigvals))
            history["test_Z_min_eigvals"].append(torch.cat(z_min_eigvals))
            history["test_Theta_sparsified_min_eigvals"].append(
                torch.cat(sparsified_min_eigvals))

            # This is a count of test matrices; it equals a percentage only
            # because test_samples is 100.
            history["non_spd_percentage"].append(non_spd_count)
            print(history["non_spd_percentage"][-1])

        # Saving every history under its own file name.
        for name, series in history.items():
            torch.save(
                series,
                f"{RESULTS_DIR}/{name}_D{D}_n{n}_sparsity{precision_sparsity}_L{L}_return_Z.pt")
