import argparse
import socket
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import numpy as np
from sklearn.metrics import f1_score

from spodnet import SpodNet


def save_data(avg_train,
              avg_test,
              std_train,
              std_test,
              sparsity,
              avg_train_f1_scores,
              std_train_f1_scores,
              avg_test_f1_scores,
              std_test_f1_scores,
              avg_logdet_losses,
              std_logdet_losses,
              avg_test_smallest_eigvals,
              std_test_smallest_eigvals,
              avg_test_largest_eigvals,
              std_test_largest_eigvals,
              path,
              net):
    """Write every logged curve and the trained network to disk with ``torch.save``.

    Each object is stored as ``path + '<name>.pt'``. ``path`` is used as a plain
    string prefix (string concatenation), so it is expected to end with a
    directory separator.

    Returns:
        The tuple ``(avg_train, avg_test, std_train, std_test, path)``, unchanged.
    """
    print("Saving data...")
    # (file stem, object) pairs, written in this order.
    logged_curves = (
        ('avg_train_losses', avg_train),
        ('avg_test_losses', avg_test),
        ('std_train_losses', std_train),
        ('std_test_losses', std_test),
        ('train_sparsity', sparsity),
        ('avg_train_f1_scores', avg_train_f1_scores),
        ('std_train_f1_scores', std_train_f1_scores),
        ('avg_test_f1_scores', avg_test_f1_scores),
        ('std_test_f1_scores', std_test_f1_scores),
        ('avg_logdet_losses', avg_logdet_losses),
        ('std_logdet_losses', std_logdet_losses),
        ('avg_test_smallest_eigvals', avg_test_smallest_eigvals),
        ('std_test_smallest_eigvals', std_test_smallest_eigvals),
        ('avg_test_largest_eigvals', avg_test_largest_eigvals),
        ('std_test_largest_eigvals', std_test_largest_eigvals),
    )
    for stem, payload in logged_curves:
        torch.save(payload, path + stem + '.pt')
    print("Data is saved.")

    print("Saving trained model...")
    torch.save(net, path + 'trained_model.pt')
    print("Model is saved.")

    return avg_train, avg_test, std_train, std_test, path


def get_off_diag(M):
    """Return a copy of ``M`` whose main diagonal (over the last two dims) is zero.

    Works for a single matrix or for any batch of matrices; ``M`` itself is not
    modified because the zeroing happens on a clone.
    """
    masked = M.clone()
    # torch.diagonal returns a view, so zeroing it in place edits `masked`.
    torch.diagonal(masked, dim1=-2, dim2=-1).zero_()
    return masked


def _cat_or_empty(chunks):
    """Concatenate a list of 1-D tensors, or return an empty tensor if the list is empty."""
    return torch.cat(chunks) if chunks else torch.tensor([])


def _support_f1_scores(theta_true, pred, decimals):
    """Return a 1-D tensor with one support-recovery F1 score per matrix in the batch.

    The support of a true matrix is its set of nonzero entries. The support of a
    prediction is the set of entries that are nonzero after rounding to
    ``decimals`` decimal places.
    """
    scores = []
    for true_mat, pred_mat in zip(theta_true, pred):
        true_support = (true_mat != 0.).double().flatten().numpy()
        pred_support = (torch.round(pred_mat, decimals=decimals) != 0.) \
            .double().flatten().numpy()
        scores.append(f1_score(true_support, pred_support))
    return torch.tensor(scores)


def train_spodnet(train_samples,
                  test_samples,
                  n,
                  p,
                  train_batch_size,
                  test_batch_size,
                  precision_sparsity,
                  dataloader_shuffle,
                  K,
                  training_seed,
                  epochs,
                  lr,
                  scheduler_patience,
                  scheduler_factor,
                  scheduler_min_lr,
                  loss_discount_factor,
                  learning_mode,
                  training_type,
                  zeta):
    """Train a SpodNet on pre-generated data and save the training/test logs.

    Workflow:
    - Loads the train and test data files from ``./data/train`` and
      ``./data/test`` with ``torch.load``.
    - Keeps the first ``train_samples`` / ``test_samples`` items of each.
    - Trains with Adam on the mean squared Frobenius error between the
      prediction and ``Theta_true``.
    - After every epoch, evaluates on the whole test set.
    - Writes all curves and the model with ``save_data`` into a directory whose
      name encodes the hyper-parameters.

    Each dataset item unpacks as ``(S, Theta_true, Sigma_true, X, alpha)``. Only
    ``S`` (network input) and ``Theta_true`` (target) are used.

    Logged per epoch:
    - per-matrix NMSE ``||pred - Theta||_F^2 / ||Theta||_F^2``;
    - support F1;
    - ``-log|det(pred)| + tr(pred @ S)`` on training batches (diagnostic only,
      not optimised);
    - the zero-fraction of the last training batch's prediction;
    - min/max eigenvalues of the test predictions.

    Note:
        ``scheduler_patience``, ``scheduler_factor``, ``scheduler_min_lr`` and
        ``loss_discount_factor`` only appear in the output directory name; no LR
        scheduler or discounted loss is built here. ``training_type`` is unused.

    Returns:
        ``save_data``'s return value:
        ``(avg_train_losses, avg_test_losses, std_train_losses, std_test_losses, path)``.
    """
    # Datasets
    train_matrices = torch.load(
        f'./data/train/p_{p}_n_{n}_density_{precision_sparsity}_size_{10_000}_random_state_{0}.pt')
    train_set = train_matrices[:train_samples]

    test_matrices = torch.load(
        f'./data/test/p_{p}_n_{n}_density_{precision_sparsity}_size_{10_000}_random_state_{1}.pt')
    test_set = test_matrices[:test_samples]

    # Dataloaders: training shuffle follows `dataloader_shuffle`; the test set is
    # never shuffled.
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=train_batch_size,
        shuffle=bool(dataloader_shuffle),
        num_workers=4
    )

    test_loader = DataLoader(
        dataset=test_set,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=4
    )

    device = "cpu"
    print(f"Training on {device}.")

    torch.manual_seed(training_seed)
    print(f"Using zeta = {zeta}.")
    net = SpodNet(K=K,
                  p=p,
                  layer_type=f'{learning_mode}_masks',
                  zeta=zeta,
                  device=device)

    net.to(device)

    nb_params = sum(param.numel() for param in net.parameters()
                    if param.requires_grad)
    print(f"Network has {nb_params} learnable parameters.")

    optimizer = Adam(net.parameters(), lr=lr)

    # Lists to log results (one entry per epoch)
    avg_logdet_losses = []
    std_logdet_losses = []

    avg_test_smallest_eigvals = []
    std_test_smallest_eigvals = []

    avg_test_largest_eigvals = []
    std_test_largest_eigvals = []

    avg_train_losses = []
    std_train_losses = []
    train_sparsity = []
    avg_train_f1_scores_list = []
    std_train_f1_scores_list = []
    avg_test_f1_scores_list = []
    std_test_f1_scores_list = []
    avg_test_losses = []
    std_test_losses = []

    # Saving directory
    path = f"./{learning_mode}/tests_UBG/train_samples_{train_samples}/test_samples_{test_samples}/p_{p}/n_{n}/train_batch_size_{train_batch_size}/test_batch_size_{test_batch_size}/precision_sparsity_{precision_sparsity}/dataloader_shuffle_{dataloader_shuffle}/K_{K}/training_seed_{training_seed}/epochs_{epochs}/lr_{lr}/scheduler_patience_{scheduler_patience}/scheduler_factor_{scheduler_factor}/scheduler_min_lr_{scheduler_min_lr}/loss_discount_factor_{loss_discount_factor}/"
    Path(path).mkdir(parents=True, exist_ok=True)

    # Training loop
    print(
        f"========== Training masked network: batch_size={train_batch_size}, K={K}. ========== ")

    for epoch in range(epochs):

        print(f"K={K} [{epoch}/{epochs}]")

        # ---------------- Training pass ----------------
        net.train()

        train_nmse_chunks = []
        train_f1_chunks = []
        batch_logdet_losses = []
        # NaN placeholder so an empty train loader cannot raise NameError below.
        sparsity_degree = torch.tensor(float('nan'))

        for S, Theta_true, _Sigma_true, _X, _alpha in train_loader:
            S = S.to(device)
            Theta_true = Theta_true.to(device)

            pred, _ = net(S)

            # Supervised objective: mean squared Frobenius error over the batch.
            sq_errors = torch.linalg.matrix_norm(pred - Theta_true, ord='fro')**2
            torch.mean(sq_errors).backward()
            optimizer.step()
            optimizer.zero_grad()

            # Diagnostics only: no autograd graph is needed from here on.
            with torch.no_grad():
                pred_d = pred.detach()

                # -log|det(pred)| + tr(pred @ S), averaged over the batch.
                logdet_loss = (-torch.slogdet(pred_d)[-1]
                               + torch.bmm(pred_d, S).diagonal(dim1=-2, dim2=-1).sum(-1))
                batch_logdet_losses.append(torch.mean(logdet_loss).item())

                # Fraction of numerically-zero entries. Divide by the actual
                # element count so a smaller final batch is handled correctly.
                sparsity_degree = (1 -
                                   torch.count_nonzero(torch.round(pred_d, decimals=10)) / pred_d.numel())

                # Per-matrix NMSE.
                train_nmse_chunks.append(
                    sq_errors.detach() /
                    torch.linalg.matrix_norm(Theta_true, ord='fro')**2)

                # NOTE(review): train F1 rounds to 7 decimals while the test F1
                # and the sparsity use 10; kept as-is, confirm it is intended.
                train_f1_chunks.append(
                    _support_f1_scores(Theta_true, pred_d, decimals=7))

        train_nmses = _cat_or_empty(train_nmse_chunks)
        train_f1 = _cat_or_empty(train_f1_chunks)

        avg_train_losses.append(torch.mean(train_nmses).item())
        std_train_losses.append(train_nmses.std().item())
        train_sparsity.append(sparsity_degree)
        avg_train_f1_scores_list.append(train_f1.mean())
        std_train_f1_scores_list.append(train_f1.std())

        avg_logdet_losses.append(np.array(batch_logdet_losses).mean())
        std_logdet_losses.append(np.array(batch_logdet_losses).std())

        # ---------------- Evaluation pass ----------------
        net.eval()
        test_nmse_chunks = []
        test_f1_chunks = []
        smallest_eig_chunks = []
        largest_eig_chunks = []

        with torch.no_grad():
            # Metrics are accumulated over every test batch, so the result does
            # not depend on whether the test set is evaluated full-batch.
            for S, Theta_true, _, _, _ in test_loader:
                S = S.to(device)
                Theta_true = Theta_true.to(device)

                pred, _ = net(S)

                test_nmse_chunks.append(
                    torch.linalg.matrix_norm(pred - Theta_true, ord='fro')**2 /
                    torch.linalg.matrix_norm(Theta_true, ord='fro')**2)
                test_f1_chunks.append(
                    _support_f1_scores(Theta_true, pred, decimals=10))

                # Compute the spectrum once per batch, then take the per-matrix
                # extremes.
                eigvals = torch.linalg.eigvalsh(pred)
                smallest_eig_chunks.append(eigvals.min(dim=-1).values)
                largest_eig_chunks.append(eigvals.max(dim=-1).values)

        test_nmses = _cat_or_empty(test_nmse_chunks)
        test_f1 = _cat_or_empty(test_f1_chunks)
        smallest_eigs = _cat_or_empty(smallest_eig_chunks)
        largest_eigs = _cat_or_empty(largest_eig_chunks)

        avg_test_losses.append(torch.mean(test_nmses).item())
        std_test_losses.append(test_nmses.std().item())
        avg_test_f1_scores_list.append(test_f1.mean())
        std_test_f1_scores_list.append(test_f1.std())

        # Store the mean as a float so each "avg_" entry matches its "std_" entry.
        avg_test_smallest_eigvals.append(smallest_eigs.mean().item())
        std_test_smallest_eigvals.append(smallest_eigs.std().item())

        avg_test_largest_eigvals.append(largest_eigs.mean().item())
        std_test_largest_eigvals.append(largest_eigs.std().item())

        print(f'Avg train logdet: {avg_logdet_losses[-1]}')
        print(
            f"Avg NMSE train loss : {torch.mean(train_nmses)}")
        print(
            f"Avg F1 train score : {torch.mean(train_f1)}")
        print(
            f"Avg NMSE test loss : {torch.mean(test_nmses)}")
        print(
            f"Avg F1 test score : {torch.mean(test_f1)}")

    print("========== Training masked network complete. ========== ")

    return save_data(avg_train_losses,
                     avg_test_losses,
                     std_train_losses,
                     std_test_losses,
                     train_sparsity,
                     avg_train_f1_scores_list,
                     std_train_f1_scores_list,
                     avg_test_f1_scores_list,
                     std_test_f1_scores_list,
                     avg_logdet_losses,
                     std_logdet_losses,
                     avg_test_smallest_eigvals,
                     std_test_smallest_eigvals,
                     avg_test_largest_eigvals,
                     std_test_largest_eigvals,
                     path,
                     net)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Train SpodNet')
    parser.add_argument('-train_samples',
                        '--train_samples',
                        nargs='?',
                        required=True,
                        type=int,
                        help='Number of samples to train on.')
    parser.add_argument('-test_samples',
                        '--test_samples',
                        nargs='?',
                        default="",
                        type=int,
                        help='Number of samples to test on.',
                        required=True)
    parser.add_argument('-p',
                        '--dimension_of_data_samples',
                        nargs='?',
                        default="",
                        type=int,
                        help='The dimension of samples to use to compute the empirical covariance matrix.',
                        required=True)
    parser.add_argument('-n',
                        '--number_of_data_samples',
                        nargs='?',
                        default="",
                        type=int,
                        help='The number of samples to use to compute the empirical covariance matrix.',
                        required=True)
    parser.add_argument('-train_batch_size',
                        '--train_batch_size',
                        nargs='?',
                        type=int,
                        help='Batch size for training dataloader.',
                        required=True)
    parser.add_argument('-test_batch_size',
                        '--test_batch_size',
                        nargs='?',
                        type=int,
                        help='Batch size for testing dataloader.',
                        required=True)
    parser.add_argument('-precision_sparsity',
                        '--precision_sparsity',
                        nargs='?',
                        type=float,
                        help='Sparsity degree of true precision matrices.',
                        required=True)
    parser.add_argument('-dataloader_shuffle',
                        '--dataloader_shuffle',
                        nargs='?',
                        default=1,
                        type=int,
                        help='Whether to shuffle data with dataloader or not.',
                        )
    parser.add_argument('-K',
                        '--K',
                        nargs='?',
                        type=int,
                        help='Number of unrolled iterations.',
                        required=True)
    parser.add_argument('-training_seed',
                        '--training_seed',
                        nargs='?',
                        type=int,
                        default=101,
                        help='Trainign seed.',
                        )
    parser.add_argument('-epochs',
                        '--epochs',
                        nargs='?',
                        type=int,
                        help='Number of Adam epochs to train for.',
                        required=True)
    parser.add_argument('-lr',
                        '--lr',
                        nargs='?',
                        type=float,
                        help='Learning rate of Adam.',
                        required=True)
    parser.add_argument('-scheduler_patience',
                        '--scheduler_patience',
                        nargs='?',
                        type=float,
                        help='Scheduler patience.',
                        default=10_000,
                        )
    parser.add_argument('-scheduler_factor',
                        '--scheduler_factor',
                        nargs='?',
                        type=float,
                        default=0.8,
                        help='Scheduler factor.',
                        )
    parser.add_argument('-scheduler_min_lr',
                        '--scheduler_min_lr',
                        nargs='?',
                        type=float,
                        default=1e-4,
                        help='Scheduler minimal learning rate.',
                        )
    parser.add_argument('-loss_discount_factor',
                        '--loss_discount_factor',
                        nargs='?',
                        type=float,
                        default=1,
                        help='Discount factor for the loss.',
                        )
    parser.add_argument('-learning_mode',
                        '--learning_mode',
                        nargs='?',
                        help='UBG, PNP or E2E.',
                        required=True)
    parser.add_argument('-training_type',
                        '--training_type',
                        nargs='?',
                        help='supervised or unsupervised.',
                        default='supervised',
                        )
    parser.add_argument('-zeta',
                        '--zeta',
                        nargs='?',
                        type=float,
                        default=1,
                        help='zeta value in normalization',
                        )

    print(socket.gethostname())

    args = parser.parse_args()

    (avg_train_losses,
     avg_test_losses,
     std_train_losses,
     std_test_losses,
     path) = train_spodnet(args.train_samples,
                           args.test_samples,
                           args.number_of_data_samples,  # n
                           args.dimension_of_data_samples,  # p
                           args.train_batch_size,
                           args.test_batch_size,
                           args.precision_sparsity,
                           args.dataloader_shuffle,
                           args.K,
                           args.training_seed,
                           args.epochs,
                           args.lr,
                           args.scheduler_patience,
                           args.scheduler_factor,
                           args.scheduler_min_lr,
                           args.loss_discount_factor,
                           args.learning_mode,
                           args.training_type,
                           args.zeta)
