import os
import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import os
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.autograd import grad

from as_models.resnet_cifar import WideResNetCIFAR, ResNetV2CIFAR
from as_models.model_utils import save_checkpoint, load_model
from as_utils.tools import accuracy, AverageMeter
from as_data_reader.dataset_reader import get_torchvision_dataset

from noise_self_distil.args import args


def get_model_id():
    """Return the ``.npz`` file name under which this run's traces are saved.

    The name is built from, in order: the second-to-last ``'/'``-separated
    component of ``args.teacher_model``, ``args.dataset``, ``args.subset``,
    ``args.jac_type`` and ``args.part_index``.
    """
    teacher_dir = args.teacher_model.split('/')[-2]
    prefix = f"{teacher_dir}-{args.dataset}-{args.subset}-{args.jac_type}"
    return f"{prefix}-traces-{args.part_index}.npz"


def _load_subset():
    """Return the dataset split selected by ``args.subset``.

    ``'train'`` selects the training split, truncated to its first 50000
    samples; any other value selects the validation split.
    """
    train_set, val_set = get_torchvision_dataset(0, args.dataset)
    if args.subset == 'train':
        train_set.data = train_set.data[:50000]
        return train_set
    return val_set


def _build_model():
    """Instantiate the (untrained) network named by ``args.network``.

    Supported names look like ``wrn_<layers>_<widen>`` or
    ``resnet_v2_<layers>``.

    Raises:
        ValueError: if ``args.network`` matches neither pattern.
    """
    num_classes = 100 if args.dataset == 'cifar100' else 10
    if 'wrn' in args.network:  # like wrn_28_10
        layers = int(args.network.split('_')[-2])
        widen_factor = int(args.network.split('_')[-1])
        return WideResNetCIFAR(layers, num_classes, widen_factor, init=args.init)
    if 'resnet_v2' in args.network:  # like resnet_v2_110
        layers = int(args.network.split('resnet_v2')[-1].split('_')[1])
        return ResNetV2CIFAR(layers, num_classes, init=args.init)
    raise ValueError('Not supported {} yet'.format(args.network))


def _sq_grad_norms_per_class(outputs, params):
    """Return a 1-D tensor whose k-th entry is ||d outputs[0, k] / d params||^2.

    The squared L2 norm is taken over all parameters concatenated.
    ``outputs`` must have a computation graph back to ``params``; the graph is
    released after the last class is differentiated.
    """
    num_classes = outputs.shape[-1]
    sq_norms = []
    for k in range(num_classes):
        # Keep the graph alive for every class except the last one.
        grads = grad(outputs[0, k], params, retain_graph=k < num_classes - 1)
        flat = torch.cat([g.reshape(-1) for g in grads])
        sq_norms.append((flat ** 2).sum())
    return torch.stack(sq_norms)


def compute_cov(pool_index, part=10):
    """Compute per-sample squared Jacobian norms of the teacher for one partition.

    The dataset selected by ``args.subset`` is split into ``part`` equal,
    contiguous partitions. For every sample in partition ``pool_index``, the
    gradient of each class output w.r.t. all model parameters is computed.
    The class output is the logit, or ``-log softmax`` when
    ``args.jac_type == 'log'``. The squared L2 norm of that gradient is then
    recorded.

    Args:
        pool_index: index of the partition to process, in ``[0, part)``.
        part: number of partitions; must divide the dataset size.

    Returns:
        dict with keys ``"all"`` (sum over all classes), ``"0"`` (class 0),
        ``"gt"`` (ground-truth class) and ``"pred"`` (argmax class). Each value
        is a float array of length ``len(data_set)``. Only entries of the
        processed partition are filled; all others stay 0.

    Raises:
        ValueError: if ``part`` is not a positive divisor of the dataset size,
            or ``pool_index`` is outside ``[0, part)``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_set = _load_subset()

    model = _build_model()
    load_model(model, args.teacher_model)
    model = model.to(device)
    model.eval()

    num_samples = len(data_set)
    traces = np.zeros(num_samples)
    traces_0 = np.zeros(num_samples)
    traces_gt = np.zeros(num_samples)
    traces_pred = np.zeros(num_samples)

    print(args.dataset, f'with {num_samples} samples.')

    if part <= 0 or num_samples % part != 0:
        raise ValueError(f"dataset size {num_samples} is not divisible by part={part}")
    if not 0 <= pool_index < part:
        raise ValueError(f"pool_index must be in [0, {part}), got {pool_index}")
    gap = num_samples // part
    i_start = pool_index * gap
    i_end = i_start + gap

    params = list(model.parameters())  # hoisted: reused for every class of every sample
    batch_time = AverageMeter()

    # Index only this partition instead of loading (and transforming) every sample.
    for computed_num, i in enumerate(range(i_start, i_end), start=1):
        image, target = data_set[i]
        start_time = time.time()

        image = image.to(device, non_blocking=True).unsqueeze(0)
        logits, _ = model(image)
        if args.jac_type == 'log':
            # log_softmax is the numerically stable form of log(softmax(.));
            # the latter can underflow to log(0) = -inf and yield NaN gradients.
            outputs = -torch.log_softmax(logits, -1)
        else:
            outputs = logits
        pred_class = int(torch.argmax(logits, dim=-1))

        sq_norms = _sq_grad_norms_per_class(outputs, params).detach().cpu().double()
        traces[i] = sq_norms.sum().item()
        traces_0[i] = sq_norms[0].item()
        traces_gt[i] = sq_norms[int(target)].item()
        traces_pred[i] = sq_norms[pred_class].item()

        batch_time.update(time.time() - start_time)

        # Running means over the samples processed so far (not the full,
        # mostly-zero arrays).
        done = slice(i_start, i + 1)
        print(f"{pool_index}:{i},{computed_num}/{gap}: "
              f"all classes: {traces[done].mean(): .2f}, first class: {traces_0[done].mean(): .2f}, "
              f"gt: {traces_gt[done].mean(): .2f}, pred: {traces_pred[done].mean(): .2f}")
        print(batch_time.avg, batch_time.val)

    traces_dict = {"all": traces, "0": traces_0, "gt": traces_gt, "pred": traces_pred}
    return traces_dict


if __name__ == '__main__':
    print(args)

    # Pass the partition count explicitly so the summary below can locate
    # the slice of samples this run actually processed.
    num_parts = 10
    collected_traces = compute_cov(args.part_index, part=num_parts)

    os.makedirs(args.outdir, exist_ok=True)
    save_path = os.path.join(args.outdir, get_model_id())
    # The full-length arrays are saved unchanged; entries outside this
    # partition are zero.
    np.savez(save_path, **collected_traces)

    # Summarise only the processed partition. Averaging the full-length arrays
    # would dilute every mean with the zeros of the other partitions.
    gap = len(collected_traces['all']) // num_parts
    processed = slice(args.part_index * gap, (args.part_index + 1) * gap)
    print(f"all classes: {np.mean(collected_traces['all'][processed]): .2f}, "
          f"first class: {np.mean(collected_traces['0'][processed]): .2f}, "
          f"gt: {np.mean(collected_traces['gt'][processed]): .2f}, "
          f"pred: {np.mean(collected_traces['pred'][processed]): .2f}")
