import argparse
import tabulate
import time
import numpy as np
import os
import torch

from curvature import data, models, losses, utils
from curvature.methods.swag import SWAG


def compute_loss_stats(
        dataset: str,
        data_path: str,
        model: str,
        checkpoint_path: tuple,
        use_test: bool = True,
        batch_size: int = 128,
        num_workers: int = 4,
        num_subsamples: int = None,
        subsample_seed: int = None,
        stats_batch: int = 128,
        swag: bool = True,
        save_path: str = None,
        seed: int = None,
        curvature_matrix: str = 'hessian',
        device: str = 'cuda',
) -> dict:
    """
    Compute the loss statistics, provided checkpoint(s) of the model saved using train_network.py.

    Parameters
    ----------
    dataset: str: ['CIFAR10', 'CIFAR100', 'MNIST', 'ImageNet32'*]: the dataset on which the model was trained.
    For ImageNet 32, we use the downsampled 32 x 32 Full ImageNet dataset. We do not provide a download due to
    proprietary issues; please drop the ImageNet 32 data in the 'data/' folder.

    data_path: str: the path string of the dataset.

    model: str: the name of the neural network architecture. It is looked up as an attribute of the 'models' package.
    Example: VGG16BN, PreResNet110 (Preactivated ResNet - 110 layers)

    checkpoint_path: tuple/list of str, or a single str: path(s) to the checkpoints generated by train_network, each
    containing a 'state_dict' entry. If exactly one path is given and it is a directory, every '*.pt' file inside it
    (in sorted order) is processed.

    use_test: bool: if True, the model is evaluated on the test set. If not, a portion of the training data will be
    assigned as the validation set.

    batch_size: int: the minibatch size used for evaluation.

    num_workers: int: number of workers for the dataloaders.

    num_subsamples: int: number of subsamples to draw randomly from the training dataset. If None, the entire dataset
    will be used.

    subsample_seed: int: the pseudorandom number seed for the subsample draw from above.

    stats_batch: int: the batch size used when computing the loss statistics. The higher the stats_batch, the higher
    the computation speed but also the higher the VRAM/RAM demand.

    swag: bool: whether to use Stochastic Weight Averaging (Gaussian). If True, the checkpoint is loaded into a SWAG
    wrapper and its SWA solution is evaluated.

    save_path: str: if provided, an additional copy of the loss stats dictionary is saved via numpy.savez at the
    specified path.

    seed: int: if not None, a manual seed for the pseudo-random number generation will be used.

    curvature_matrix: str: forwarded unchanged to utils.loss_stats (default 'hessian').

    device: ['cpu', 'cuda']: the device on which the model and all computations are performed. Strongly recommend
    'cuda' for GPU acceleration on CUDA-enabled Nvidia devices. Falls back to 'cpu' if CUDA is unavailable.

    Returns
    -------
    A dictionary mapping each label below to a numpy array with one entry per checkpoint:

        'train_loss', 'train_acc', 'test_loss', 'test_acc', (literal meaning)
        'train_err', 'test_err': 100 - the corresponding accuracy
        'loss_mean', 'loss_var': mean and variance of the losses
        'grad_mean_norm_sq', 'grad_var': squared norm of the mean and variance of the *gradient*
        'hess_mean_norm_sq', 'hess_var', 'hess_mu': squared norm of the mean, variance and mean of the *Hessian*
        'delta', 'alpha': Hessian confidence
        'weight_norm_l2', 'weight_norm_linf': the L2 and L-inf norms of the weights
    """
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'
    use_cuda = str(device).startswith('cuda')

    torch.backends.cudnn.benchmark = True
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    print('Using model %s' % model)
    model_cfg = getattr(models, model)

    # Full training set with the training-time transforms; used below for bn_update and the train metrics.
    full_datasets, _ = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_train,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
    )

    full_loader = torch.utils.data.DataLoader(
        full_datasets['train'],
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    # (Optionally subsampled) training set with the *test* transforms; used for the loss statistics.
    datasets, num_classes = data.datasets(
        dataset,
        data_path,
        transform_train=model_cfg.transform_test,
        transform_test=model_cfg.transform_test,
        use_validation=not use_test,
        train_subset=num_subsamples,
        train_subset_seed=subsample_seed,
    )

    loader = torch.utils.data.DataLoader(
        datasets['train'],
        batch_size=stats_batch,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    # NOTE(review): batch_loader is not used anywhere in this function; kept to preserve the original construction.
    batch_loader = torch.utils.data.DataLoader(
        datasets['train'],
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    test_loader = torch.utils.data.DataLoader(
        datasets['test'],
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    print('Preparing model')
    print(*model_cfg.args, dict(**model_cfg.kwargs))
    # `net` is the network that is actually evaluated; it is kept separate from the `model` name argument.
    if not swag:
        net = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)
        net.to(device)
        swag_model = None
    else:
        # Positional args are placed before the keywords for readability; the call binds exactly as before.
        swag_model = SWAG(model_cfg.base, *model_cfg.args,
                          subspace_type='random', num_classes=num_classes, **model_cfg.kwargs)
        swag_model.to(device)
        net = None

    criterion = losses.cross_entropy

    stat_labels = [
        'train_loss', 'train_acc', 'test_loss', 'test_acc',
        'loss_mean', 'loss_var',
        'grad_mean_norm_sq', 'grad_var',
        'hess_mean_norm_sq', 'hess_var', 'hess_mu',
        'delta', 'alpha',
        'weight_norm_l2', 'weight_norm_linf'
    ]

    # Accept a bare string as a single checkpoint path, and work on a private list copy.
    if isinstance(checkpoint_path, str):
        checkpoint_path = [checkpoint_path]
    else:
        checkpoint_path = list(checkpoint_path)

    # If the single path is a directory, expand it to the '.pt' files it contains.
    # The directory name must be captured *before* the list is rebuilt.
    if len(checkpoint_path) == 1 and os.path.isdir(checkpoint_path[0]):
        ckpt_dir = checkpoint_path[0]
        checkpoint_path = sorted(
            os.path.join(ckpt_dir, filename)
            for filename in os.listdir(ckpt_dir)
            if filename.endswith(".pt")
        )
        print("File list: ", checkpoint_path)

    K = len(checkpoint_path)
    stat_dict = {label: np.zeros(K) for label in stat_labels}

    columns = ['#'] + stat_labels + ['time']

    for i, ckpt_path in enumerate(checkpoint_path):
        start_time = time.time()
        print('Loading %s' % ckpt_path)
        # map_location lets checkpoints saved on GPU be loaded when running on CPU (and vice versa).
        checkpoint = torch.load(ckpt_path, map_location=device)
        if not swag:
            net.load_state_dict(checkpoint['state_dict'])
        else:
            swag_model.load_state_dict(checkpoint['state_dict'], strict=False)
            swag_model.set_swa()
            net = swag_model.base_model

        # Presumably refreshes the BatchNorm running statistics on the training data - see utils.bn_update.
        utils.bn_update(full_loader, net)
        train_res = utils.eval(full_loader, net, criterion)
        test_res = utils.eval(test_loader, net, criterion)

        stat_dict['train_loss'][i] = train_res['loss']
        stat_dict['train_acc'][i] = train_res['accuracy']
        stat_dict['test_loss'][i] = test_res['loss']
        stat_dict['test_acc'][i] = test_res['accuracy']

        # `cuda` follows the resolved device instead of being hard-coded to True.
        loss_stats = utils.loss_stats(loader, net, criterion, cuda=use_cuda, verbose=False,
                                      bn_train_mode=True, curvature_matrix=curvature_matrix)
        w = torch.cat([param.detach().cpu().view(-1) for param in net.parameters()])
        w_l2_norm = torch.norm(w).item()
        w_linf_norm = torch.norm(w, float('inf')).item()

        for label, value in loss_stats.items():
            stat_dict[label][i] = value
        # Write into slot i; assigning to the key itself would discard the per-checkpoint array.
        stat_dict['weight_norm_l2'][i] = w_l2_norm
        stat_dict['weight_norm_linf'][i] = w_linf_norm
        ckpt_time = time.time() - start_time

        values = ['%d/%d' % (i + 1, K)] + [stat_dict[label][i] for label in stat_labels] + [ckpt_time]

        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='0.2g')
        table = table.split('\n')
        # Repeat the separator line above the header so each checkpoint prints as a framed row.
        table = '\n'.join([table[1]] + table)
        print(table)

    stat_dict['train_err'] = 100.0 - stat_dict['train_acc']
    stat_dict['test_err'] = 100.0 - stat_dict['test_acc']

    # With swag=True and no checkpoints processed, `net` was never bound to the SWAG base model.
    if net is None and swag_model is not None:
        net = swag_model.base_model
    num_parameters = sum(p.numel() for p in net.parameters())

    if save_path is not None:
        np.savez(
            save_path,
            checkpoints=checkpoint_path,
            num_parameters=num_parameters,
            **stat_dict
        )
    return stat_dict
