import os
import argparse
import numpy as np
import torch
from torch.func import functional_call, vmap, grad
from pyhessian import hessian
import pickle
import glob
from svhn import SVHN
from utils import *


# Compute device shared by every loader in this module: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class NestedDict(dict):
    """Dictionary that creates nested dictionaries on first access.

    Looking up a missing key stores a new, empty instance of the same class
    under that key and returns it, so ``d[a][b] = value`` works without
    creating ``d[a]`` beforehand.
    """

    def __missing__(self, key):
        # Called by dict.__getitem__ when `key` is absent.
        child = type(self)()
        self[key] = child
        return child


def get_load(path, args):
    """Rebuild a model from a training checkpoint.

    Args:
        path: Checkpoint file readable by ``torch.load``; it must contain the
            keys ``'model'`` (a state dict) and ``'lr'``.
        args: Namespace whose ``model`` attribute is ``'resnet'`` or ``'vgg'``.

    Returns:
        Tuple ``(model, optimizer, lr, criterion)``. ``model`` lives on the
        module-level ``device``; ``optimizer`` is always the placeholder ``0``
        (the optimizer state is not restored); ``lr`` is the value stored in
        the checkpoint; ``criterion`` is a fresh ``CrossEntropyLoss``.

    Raises:
        NotImplementedError: If ``args.model`` is not a supported architecture.
    """
    checkpoint = torch.load(path, map_location=device)

    arch = args.model
    if arch not in ('resnet', 'vgg'):
        raise NotImplementedError()

    # The architecture modules are imported lazily so only the one in use is needed.
    if arch == 'resnet':
        from resnet import ResNet18
        model = ResNet18()
    else:
        from vgg import vgg11
        model = vgg11()

    model.load_state_dict(checkpoint['model'])
    model = model.to(device)

    loss_fn = torch.nn.CrossEntropyLoss()
    learning_rate = checkpoint['lr']
    placeholder_optimizer = 0  # kept only so the return shape stays stable

    return model, placeholder_optimizer, learning_rate, loss_fn


def get_gradient(model, loss, args):
    """Backpropagate ``loss`` and return all parameter gradients as one flat CPU vector.

    The vector follows the order of ``model.parameters()``. A parameter that
    received no gradient (``param.grad is None``, e.g. it did not take part in
    the loss) contributes zeros, so the layout and length of the vector stay
    the same from call to call.

    Note that ``loss.backward()`` *accumulates* into ``param.grad``; callers
    that want the gradient of this loss alone must clear the gradients first.

    Args:
        model: Module whose parameters are collected.
        loss: Scalar tensor to backpropagate.
        args: Unused; kept for interface compatibility.

    Returns:
        1-D CPU tensor with one entry per parameter element.
    """
    loss.backward()

    pieces = []
    for param in model.parameters():
        g = param.grad
        if g is None:
            # No gradient reached this parameter: keep its slot, filled with zeros.
            pieces.append(torch.zeros(param.numel(), dtype=param.dtype))
        else:
            # reshape (not view) also handles non-contiguous gradient tensors.
            pieces.append(g.detach().reshape(-1).cpu())

    if not pieces:
        # torch.cat rejects an empty list; a parameter-free model has an empty gradient.
        return torch.zeros(0)
    return torch.cat(pieces)


def get_sample_grad(sample, target, model, criterion):
    """Measure how much per-sample gradients vary across a batch.

    Per-sample gradients are obtained with ``torch.func`` by differentiating
    the loss of a single example with respect to the (detached) parameters and
    mapping that over dimension 0 of ``sample`` and ``target``.

    Args:
        sample: Batch of inputs; dimension 0 indexes the samples.
        target: Batch of labels; dimension 0 indexes the samples.
        model: Module evaluated through ``functional_call``.
        criterion: Loss applied to a one-example batch.

    Returns:
        Tuple ``(grad_disper, det)``: the sum over all parameter coordinates of
        the (biased) variance of the per-sample gradients, and the number of
        coordinates whose variance exceeds ``1e-6``.
    """
    frozen_params = {name: p.detach() for name, p in model.named_parameters()}
    frozen_buffers = {name: b.detach() for name, b in model.named_buffers()}

    def single_example_loss(p, b, x, y):
        # Under vmap the batch dimension is stripped, so restore a batch of one.
        logits = functional_call(model, (p, b), (x.unsqueeze(0),))
        return criterion(logits, y.unsqueeze(0))

    # Parameters and buffers are shared (None); inputs and labels are mapped over dim 0.
    per_example_grad_fn = vmap(torch.func.grad(single_example_loss), in_dims=(None, None, 0, 0))
    per_example_grads = per_example_grad_fn(frozen_params, frozen_buffers, sample, target)

    active_coords = 0
    dispersion = 0
    for g in per_example_grads.values():
        centered = g - torch.mean(g, dim=0)
        coord_var = torch.mean(centered ** 2, dim=0).view(-1)
        active_coords += coord_var[coord_var > 1e-6].shape[0]
        dispersion += torch.sum(coord_var)

    return dispersion, active_coords

def get_logGNC(model, train_list, test_list, criterion):
    """Compare per-sample gradient dispersion on a train batch and a test batch.

    Args:
        model: Module to analyse; it is switched to train mode.
        train_list: ``(inputs, labels)`` pair for the training batch.
        test_list: ``(inputs, labels)`` pair for the test batch.
        criterion: Loss passed on to ``get_sample_grad``.

    Returns:
        Tuple of
        ``log(|test_dispersion - train_dispersion| / train_dispersion + 1)``,
        the training dispersion, and the number of parameter coordinates whose
        per-sample gradient variance on the training batch exceeds the
        threshold used by ``get_sample_grad``. There is no guard against a
        zero training dispersion.
    """
    model.train()
    model.zero_grad()

    train_x, train_y = train_list
    train_disp, active_dims = get_sample_grad(train_x, train_y, model, criterion)
    model.zero_grad()

    test_x, test_y = test_list
    test_disp = get_sample_grad(test_x, test_y, model, criterion)[0]
    model.zero_grad()

    relative_gap = torch.abs(test_disp - train_disp) / train_disp
    return torch.log(relative_gap + 1), train_disp, active_dims


def get_initmodel(args):
    """Build the model at its saved initial weights.

    The weights are read from ``./init/resnet.pth`` or ``./init/vgg.pth``
    depending on ``args.model``. They are mapped to the CPU while loading so a
    file written on a CUDA machine can still be read on a CPU-only host; the
    returned model is never moved off the CPU here.

    Args:
        args: Namespace whose ``model`` attribute is ``'resnet'`` or ``'vgg'``.

    Returns:
        The freshly constructed model with the initial weights loaded.

    Raises:
        NotImplementedError: If ``args.model`` is not a supported architecture.
    """
    if args.model == 'resnet':
        from resnet import ResNet18
        model = ResNet18()
        init_path = './init/resnet.pth'
    elif args.model == 'vgg':
        from vgg import vgg11
        model = vgg11()
        init_path = './init/vgg.pth'
    else:
        raise NotImplementedError(f'unsupported model: {args.model!r}')

    model.load_state_dict(torch.load(init_path, map_location='cpu'))
    return model


def get_distance(model0, model):
    """Return the squared weight norm of ``model`` and its squared distance to ``model0``.

    Parameters are paired in the order given by ``parameters()`` and compared
    on the CPU. Both results are sums of squares (no square root is taken).

    Args:
        model0: Reference model (e.g. the initialisation).
        model: Model being measured.

    Returns:
        Tuple ``(weight_norm, distance)``: the sum of squared parameter values
        of ``model`` and the sum of squared element-wise differences between
        ``model`` and ``model0``.
    """
    sq_norm = 0
    sq_dist = 0
    for ref, cur in zip(model0.parameters(), model.parameters()):
        cur_cpu = cur.data.cpu()
        delta = cur_cpu - ref.data.cpu()
        sq_norm = sq_norm + cur_cpu.pow(2).sum()
        sq_dist = sq_dist + delta.pow(2).sum()

    return sq_norm, sq_dist


def get_hessian(model, train_list):
    """Estimate the Hessian trace and top eigenvalue of the cross-entropy loss.

    The model is switched to eval mode and analysed with PyHessian on the
    single batch given.

    Args:
        model: Module to analyse.
        train_list: ``(inputs, labels)`` pair forming one batch.

    Returns:
        Tuple ``(trace, top_eigenvalue)``: the mean of the values returned by
        ``hessian.trace()`` and the last entry of the eigenvalue list returned
        by ``hessian.eigenvalues()``.
    """
    model.eval()
    inputs, labels = train_list

    loss_fn = torch.nn.CrossEntropyLoss()
    use_cuda = torch.cuda.is_available()
    analyzer = hessian(model, loss_fn, data=(inputs, labels), cuda=use_cuda)

    trace_estimates = analyzer.trace()
    eigenvalues, _eigenvectors = analyzer.eigenvalues()

    return np.mean(trace_estimates), eigenvalues[-1]


def _mean_batch_stats(model, batches, criterion, args):
    """Return the mean loss, mean accuracy and mean flat gradient over ``batches``."""
    loss_sum, acc_sum, grad_sum = 0, 0, 0
    for x, y in batches:
        # backward() accumulates into .grad, so clear it first; otherwise each
        # batch's vector would also contain every earlier batch's gradient.
        model.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        grad_sum += get_gradient(model, loss, args)
        acc_sum += accuracy(out, y).item()
        loss_sum += loss.item()

    n_batches = len(batches)
    return loss_sum / n_batches, acc_sum / n_batches, grad_sum / n_batches


def get_acc(model, train_list, test_list, criterion, args):
    """Compute loss, accuracy and gradient statistics on the train and test batches.

    For each split, the loss, the accuracy and the flattened gradient are
    averaged over its batches. Gradients are cleared before every batch so the
    averaged vector is the mean of the per-batch gradients for both splits.

    NOTE(review): the model stays in train mode for both splits (an ``eval()``
    call was disabled in the original code), so mode-dependent layers behave
    as in training - confirm this is intended.

    Args:
        model: Module to evaluate.
        train_list: Sized iterable of ``(inputs, labels)`` training batches.
        test_list: Sized iterable of ``(inputs, labels)`` test batches.
        criterion: Loss function applied to ``(outputs, labels)``.
        args: Passed through to ``get_gradient``.

    Returns:
        Tuple ``(train_acc, train_loss, test_loss, test_acc, grad_var,
        grad_norm)`` where ``grad_var`` is the L2 norm of the difference
        between the mean train and mean test gradients and ``grad_norm`` is
        the L2 norm of the mean train gradient.

    Raises:
        ZeroDivisionError: If either batch list is empty.
    """
    model.train()

    train_loss, train_acc, train_grad = _mean_batch_stats(model, train_list, criterion, args)
    test_loss, test_acc, test_grad = _mean_batch_stats(model, test_list, criterion, args)

    grad_var = torch.norm(train_grad - test_grad, 2.0)
    grad_norm = torch.norm(train_grad, 2.0)

    return train_acc, train_loss, test_loss, test_acc, grad_var, grad_norm


def _find_checkpoints(dir_path):
    """Return ``(step, path)`` pairs for the ``iter-<step>.pth.tar`` files in ``dir_path``, ordered by step."""
    prefix, suffix = 'iter-', '.pth.tar'
    # Escape the directory so characters such as '[' in its name are taken literally.
    pattern = os.path.join(glob.escape(dir_path), prefix + '*' + suffix)
    checkpoints = []
    for path in glob.glob(pattern):
        step_text = os.path.basename(path)[len(prefix):-len(suffix)]
        if step_text.isdecimal():
            checkpoints.append((int(step_text), path))
    checkpoints.sort()
    return checkpoints


def get_results(n_seeds, dir_head, dataset, args):
    """Evaluate every checkpoint of every seed and collect the metrics.

    For seed ``i`` the checkpoints are read from
    ``<args.logdir>/<dir_head>_seed_<i>/iter-<step>.pth.tar``. The files that
    actually exist are discovered and processed in increasing step order, so
    the save interval does not have to be known and unrelated ``*.pth.tar``
    files are ignored. Seeds without a directory are reported and skipped.

    Args:
        n_seeds: Number of seeds; seeds ``0 .. n_seeds - 1`` are visited.
        dir_head: Run-directory prefix shared by all seeds.
        dataset: Object providing ``getTestList``, ``getTrainList`` and, when
            gradient-trace or Hessian metrics are requested, ``getTrainBatch``
            and ``getTestBatch``.
        args: Namespace with ``logdir``, ``model`` and the boolean switches
            ``gradtrace``, ``distance`` and ``hessian``.

    Returns:
        ``NestedDict`` indexed as ``results[j][i]`` where ``j`` is the position
        of the checkpoint in step order and ``i`` is the seed. Each entry is a
        dict of metrics; metrics whose switch is off are stored as ``0``.
    """
    use_cuda = torch.cuda.is_available()
    # presumably the first argument is the batch size - confirm against SVHN
    test_list = dataset.getTestList(1000, use_cuda)
    train_list = dataset.getTrainList(1000, use_cuda)
    if args.gradtrace or args.hessian:
        small_train = dataset.getTrainBatch(100, use_cuda)
        small_test = dataset.getTestBatch(100, use_cuda)
    if args.distance:
        model0 = get_initmodel(args)
    results = NestedDict()  # indexing with checkpoint index, seed

    for i in range(n_seeds):
        print(f'seed number {i}')
        dir_name = f'{dir_head}_seed_{i}'
        dir_path = os.path.join(args.logdir, dir_name)
        if not os.path.exists(dir_path):
            print(f"Did not find results for {dir_name}")
            continue

        for j, (step, pth_dir) in enumerate(_find_checkpoints(dir_path)):
            print(f'step number {step}')
            model, _, lr, criterion = get_load(pth_dir, args)

            # Optional metrics default to 0 when their switch is off.
            logTrace = 0
            weight_norm = 0
            distance = 0
            hess_tr = 0
            hess_eig = 0
            edge_sta = 0
            grad_dis, d_inc = 0, 0

            train_acc, train_loss, test_loss, test_acc, grad_var, grad_norm = get_acc(
                model, train_list, test_list, criterion, args)

            if args.gradtrace:
                logTrace, grad_dis, d_inc = get_logGNC(model, small_train, small_test, criterion)

            if args.distance:
                weight_norm, distance = get_distance(model0, model)

            if args.hessian:
                hess_tr, hess_eig = get_hessian(model, small_train)
                edge_sta = 2 / lr - hess_eig

            results[j][i] = {
                'train_acc': train_acc,
                'train_loss': train_loss,
                'test_loss': test_loss,
                'test_acc': test_acc,
                'acc_gap': np.abs(test_acc - train_acc),
                'loss_gap': np.abs(test_loss - train_loss),
                'grad_var': grad_var,
                'grad_norm': grad_norm,
                'log_trace': logTrace,
                'grad_dis': grad_dis,
                'd_inc': d_inc,
                'weight_norm': weight_norm,
                'distance': distance,
                'hessian_trace': hess_tr,
                'hessian_eigen': hess_eig,
                # NOTE: the key is misspelled, but it is part of the stored format;
                # kept as-is so existing readers of results.pkl keep working.
                'egde_stability': edge_sta
            }

    return results


def main():
    """Parse the command line, analyse all checkpoints and pickle the results.

    The results are written to ``<logdir>/<model><method>/results.pkl``. The
    output directory is created before the analysis starts so that a bad
    ``--logdir`` fails immediately rather than after all checkpoints have been
    processed.

    NOTE: the output path depends only on the model and the method, so runs
    that differ in learning rate, weight decay or initialisation overwrite the
    same ``results.pkl``.

    Raises:
        OSError: If the output directory cannot be created or the results file
            cannot be written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--method', type=str, default='sgd', choices=['ggdCov', 'sgd', 'sgdCov'])
    parser.add_argument('--batchsize', type=int, default=100)  # 500 for ggdCov
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--wd', type=float, default=5e-4)
    parser.add_argument('--aug', action='store_true', default=False)
    parser.add_argument('--randinit', action='store_true', default=False)
    parser.add_argument('--gradtrace', action='store_true', default=False)
    parser.add_argument('--distance', action='store_true', default=False)
    parser.add_argument('--hessian', action='store_true', default=False)
    parser.add_argument('--n-seeds', type=int, default=10)
    parser.add_argument('--model', type=str, default='vgg', choices=['vgg', 'resnet'])
    parser.add_argument('--datadir', type=str, default='/home/geothe9/datasets/SVHN/train25k_test70k')
    parser.add_argument('--logdir', type=str, default='../logs/SVHN')
    parser.set_defaults(parse=True)
    args = parser.parse_args()
    print(args)

    # Prefix of the per-seed run directories; get_results appends '_seed_<i>'.
    dir_head = '{}_{}_{}_lr_{}_wd_{}'.format(
        args.model,
        args.method,
        'rand' if args.randinit else 'fix',
        args.lr, args.wd)

    exp_name = args.model + args.method
    results_file_path = os.path.join(args.logdir, exp_name)
    # Create the destination first: failing here is much cheaper than failing
    # after the whole analysis has run.
    os.makedirs(results_file_path, exist_ok=True)

    dataset = SVHN(args.datadir)
    results = get_results(args.n_seeds, dir_head, dataset, args)

    output_file = os.path.join(results_file_path, 'results.pkl')
    with open(output_file, 'wb') as f:
        pickle.dump(results, f)

# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()