'''
Copyright (C) 2010-2021 Alibaba Group Holding Limited.

This file is modified from:
https://github.com/VITA-Group/TENAS
'''

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import time

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

import global_utils
import ModelLoader
from PlainNet.basic_blocks import RELU

def cross_entropy(logit, target):
    """Return minus the mean row-wise inner product of two softmax distributions.

    Both ``logit`` and ``target`` are softmax-normalised along dim 1, and the
    result is ``-mean_i sum_c softmax(target)[i, c] * softmax(logit)[i, c]``
    as a 0-dim tensor.

    NOTE(review): despite the name, this is not the textbook cross-entropy.
    No logarithm is applied to the predicted probabilities, and ``target`` is
    itself passed through a softmax rather than used as a one-hot label.
    """
    p_target = target.softmax(dim=1)
    p_pred = logit.softmax(dim=1)
    overlap_per_row = (p_target * p_pred).sum(dim=1)
    return overlap_per_row.mean().neg()


def _default_device():
    """Return the current CUDA device if CUDA is available, else the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda', torch.cuda.current_device())
    return torch.device('cpu')


def _centered_onehot(targets, num_classes):
    """One-hot encode integer ``targets`` and subtract the per-class batch mean."""
    onehot = F.one_hot(targets, num_classes=num_classes).float()
    return onehot - onehot.mean(0)


def _per_sample_weight_grads(network, inputs, device, gpu):
    """Return one flattened weight-gradient vector per sample in ``inputs``.

    Runs a single forward pass, then back-propagates each sample's logit row
    separately (all-ones upstream gradient, i.e. the gradient of the row sum).
    It concatenates the gradients of every parameter whose name contains
    ``'weight'`` and whose ``.grad`` is populated.
    """
    network.zero_grad()
    if gpu is not None:
        inputs_ = inputs.clone().to(device, non_blocking=True)
    else:
        inputs_ = inputs.clone()

    logit = network(inputs_)
    if isinstance(logit, tuple):
        logit = logit[1]  # 201 networks: return features and logits

    per_sample = []
    for idx in range(len(inputs_)):
        row = logit[idx:idx + 1]
        # retain_graph: the same forward graph is reused for every sample.
        row.backward(torch.ones_like(row), retain_graph=True)
        per_sample.append(torch.cat([
            W.grad.view(-1).detach()
            for name, W in network.named_parameters()
            if 'weight' in name and W.grad is not None
        ], -1))
        network.zero_grad()
        if gpu is not None:
            torch.cuda.empty_cache()
    return per_sample


def get_ntk_n_zen(xloader, vloader, networks, recalbn=0, train_mode=False, num_batch=None, batch_size=None, image_size=None, gpu=None, mixup_gamma=1e-2, num_classes=100):
    """Compute NTK-based zero-shot scores for every network in ``networks``.

    For each network:

    1. Build an empirical NTK ``K_xx`` from per-sample weight gradients on
       batches from ``xloader``.
    2. On each ``vloader`` batch, take up to 8 clean samples plus the same
       samples perturbed by ``mixup_gamma``-scaled Gaussian noise.
    3. Predict the batch-centred one-hot labels of those probes by kernel
       regression ``PY = K_yx K_xx^{-1} Y_x``.

    Args:
        xloader: iterable of ``(inputs, targets)`` batches used to build the NTK.
        vloader: iterable of ``(inputs, targets)`` batches used as probes.
        networks: iterable of modules; each forward may return logits or a
            ``(features, logits)`` tuple.
        recalbn, batch_size, image_size: unused; kept for call compatibility.
        train_mode: put the networks in train mode (otherwise eval mode).
        num_batch: maximum number of batches read from each loader; ``None``
            or a value ``<= 0`` reads the whole loader.
        gpu: when not ``None``, inputs are re-copied to the device before each
            forward pass and the CUDA cache is emptied after every sample.
        mixup_gamma: scale of the Gaussian noise added to the probe inputs.
        num_classes: number of classes used to one-hot encode the targets.

    Returns:
        ``(score_2, score_3, score_4)``, each a list with one entry per network:

        - score_2: fraction of the NTK trace carried by its 8 largest
          eigenvalues (NaN/inf mapped through ``np.nan_to_num``).
        - score_3: negated ``cross_entropy`` between the predictions for the
          noisy probes and for the clean probes.
        - score_4: mean squared error between the clean-probe predictions and
          their centred one-hot targets.
    """
    device = _default_device()
    for network in networks:
        if train_mode:
            network.train()
        else:
            network.eval()

    # ``None`` previously crashed on ``None > 0``; treat it like "no limit".
    limit = num_batch if num_batch is not None and num_batch > 0 else None

    # ---- NTK on the training batches ------------------------------------
    grads_x = [[] for _ in networks]
    targets_x = []
    for i, (inputs, targets) in enumerate(xloader):
        if limit is not None and i >= limit:
            break
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        targets_x.append(_centered_onehot(targets, num_classes))
        for net_idx, network in enumerate(networks):
            grads_x[net_idx].extend(_per_sample_weight_grads(network, inputs, device, gpu))
    # Concatenate once, after all batches (doing it inside the loop broke
    # ``append`` on the second batch).
    targets_x = torch.cat(targets_x, 0)
    grads_x = [torch.stack(g, 0) for g in grads_x]
    ntks = [torch.einsum('nc,mc->nm', [g, g]) for g in grads_x]

    score_2 = []
    for ntk in ntks:
        eigenvalues = torch.linalg.eigvalsh(ntk, UPLO='U')  # ascending order
        score_2.append(np.nan_to_num((eigenvalues[-8:].sum() / eigenvalues.sum()).item(), copy=True))

    # ---- Val / Test probes: clean samples + noise-perturbed copies -------
    grads_y = [[] for _ in networks]
    targets_y = []
    clean_rows, noisy_rows = [], []  # row indices into the stacked grads_y
    offset = 0
    for i, (inputs, targets) in enumerate(vloader):
        if limit is not None and i >= limit:
            break
        input_1 = inputs[:8]
        n_probe = input_1.size(0)
        noise = torch.randn(input_1.shape, device=input_1.device)
        input_mix = input_1 + mixup_gamma * noise
        inputs = torch.cat((input_1, input_mix), dim=0).to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        # Centre over the whole batch, but only the clean probes are scored.
        targets_y.append(_centered_onehot(targets, num_classes)[:n_probe])
        clean_rows.extend(range(offset, offset + n_probe))
        noisy_rows.extend(range(offset + n_probe, offset + 2 * n_probe))
        offset += 2 * n_probe
        for net_idx, network in enumerate(networks):
            grads_y[net_idx].extend(_per_sample_weight_grads(network, inputs, device, gpu))
    targets_y = torch.cat(targets_y, 0)
    grads_y = [torch.stack(g, 0) for g in grads_y]

    score_3, score_4 = [], []
    for net_idx in range(len(networks)):
        ntk_yx = torch.einsum('nc,mc->nm', [grads_y[net_idx], grads_x[net_idx]])
        # Kernel-regression prediction of the centred labels for every probe.
        PY = torch.einsum('jk,kl,lm->jm', ntk_yx, torch.inverse(ntks[net_idx]), targets_x)
        py_clean, py_noisy = PY[clean_rows], PY[noisy_rows]
        score_3.append(-1 * cross_entropy(py_noisy, py_clean).item())
        score_4.append(((py_clean - targets_y) ** 2).sum(1).mean(0).item())

    return score_2, score_3, score_4


def compute_NTK_score(xloader, vloader, gpu, model, resolution, batch_size):
    """Score a single ``model`` with the NTK zero-shot proxies.

    Calls ``get_ntk_n_zen`` for ``[model]`` with train mode enabled, one batch
    per loader and ``recalbn=0``. It then returns the three per-network scores
    it produces, each negated, as a 3-tuple.
    """
    # One network was passed, so every score list holds exactly one entry.
    first, second, third = (scores[0] for scores in get_ntk_n_zen(
        xloader, vloader, [model],
        recalbn=0, train_mode=True, num_batch=1,
        batch_size=batch_size, image_size=resolution, gpu=gpu,
    ))
    return -1 * first, -1 * second, -1 * third


def parse_cmd_options(argv):
    """Parse this module's benchmark options out of ``argv``.

    Flags this parser does not recognise (for example those meant for
    ``global_utils`` or ``ModelLoader``) are ignored via ``parse_known_args``.
    Only the namespace of recognised options is returned.
    """
    option_specs = (
        ('--batch_size', dict(type=int, default=16, help='number of instances in one mini-batch.')),
        ('--input_image_size', dict(type=int, default=None,
                                    help='resolution of input image, usually 32 for CIFAR and 224 for ImageNet.')),
        ('--repeat_times', dict(type=int, default=32)),
        ('--gpu', dict(type=int, default=None)),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_known_args(argv)[0]

if __name__ == "__main__":
    # Benchmark the NTK proxies for the model described on the command line,
    # feeding it synthetic data (random images / uniformly random labels).
    opt = global_utils.parse_cmd_options(sys.argv)
    args = parse_cmd_options(sys.argv)
    if args.input_image_size is None:
        raise SystemExit('--input_image_size is required')
    if args.repeat_times < 1:
        raise SystemExit('--repeat_times must be >= 1')

    the_model = ModelLoader.get_model(opt, sys.argv)
    if args.gpu is not None:
        # get_ntk_n_zen places its inputs on the *current* CUDA device.
        torch.cuda.set_device(args.gpu)
        the_model = the_model.cuda(args.gpu)

    num_classes = 100  # matches the default of get_ntk_n_zen
    image_shape = (args.batch_size, 3, args.input_image_size, args.input_image_size)

    start_timer = time.time()
    for _ in range(args.repeat_times):
        # compute_NTK_score needs loaders; build one synthetic batch for each.
        xloader = [(torch.randn(image_shape), torch.randint(0, num_classes, (args.batch_size,)))]
        vloader = [(torch.randn(image_shape), torch.randint(0, num_classes, (args.batch_size,)))]
        # The former RN term (compute_RN_score) was undefined and always
        # overwritten with 0, so it is dropped here.
        the_score = compute_NTK_score(xloader=xloader, vloader=vloader, gpu=args.gpu,
                                      model=the_model, resolution=args.input_image_size,
                                      batch_size=args.batch_size)
    time_cost = (time.time() - start_timer) / args.repeat_times

    # compute_NTK_score returns a 3-tuple of scores; print each of them.
    score_text = ', '.join(f'{s:.4g}' for s in the_score)
    print(f'ntk={score_text}, time cost={time_cost:.4g} second(s)')