'''
Copyright (C) 2010-2021 Alibaba Group Holding Limited.

This file is modified from:
https://github.com/VITA-Group/TENAS
'''

import os
import sys
import weakref
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from torch import nn
import global_utils, argparse, ModelLoader, time

from PlainNet.basic_blocks import RELU

import torch.nn.functional as F

def cross_entropy(logit, target):
    """Return the negated mean overlap between two softmax distributions.

    Both ``logit`` and ``target`` are passed through a softmax over ``dim=1``;
    the per-row inner product of the two results is averaged over the batch
    and negated. Note that, despite the name, no logarithm is involved, so
    this is not the standard cross-entropy loss.

    Args:
        logit: 2-D tensor of raw network outputs, one row per sample.
        target: tensor of the same shape; callers in this file pass one-hot
            rows.

    Returns:
        A scalar tensor.
    """
    predicted = F.softmax(logit, dim=1)
    reference = F.softmax(target, dim=1)
    overlap_per_row = (reference * predicted).sum(dim=1)
    return -overlap_per_row.mean()

def logdet(K):
    """Return the natural log of ``|det(K)|`` as reported by ``np.linalg.slogdet``.

    The sign component returned by ``slogdet`` is discarded.
    """
    _sign, log_abs_det = np.linalg.slogdet(K)
    return log_abs_det

# Counts the number of linear activation regions of a neural network.
class LinearRegionCount(object):
    """Counts distinct activation patterns (linear regions) over a set of samples.

    ``update2D`` stores one binary activation pattern per sample in a
    pre-allocated ``(n_samples, n_neuron)`` buffer; ``getLinearReginCount``
    lazily reduces that buffer to the number of distinct rows. ``update1D`` is
    an independent, dictionary-based recorder that stores one pattern string
    per call in ``ActPattern``.

    Args:
        n_samples: number of rows reserved in the pattern buffer.
        gpu: CUDA device passed to ``Tensor.cuda`` for the buffer, or ``None``
            to keep the buffer on the CPU.
    """
    def __init__(self, n_samples, gpu=None):
        self.ActPattern = {}        # pattern string -> 1, filled by update1D
        self.n_LR = -1              # cached count; -1 means "not computed yet"
        self.n_samples = n_samples
        self.ptr = 0                # next free row of the pattern buffer
        self.activations = None     # allocated on the first update2D call
        self.gpu = gpu

    @torch.no_grad()
    def update2D(self, activations):
        """Append the sign pattern of a ``(batch, n_neuron)`` activation tensor.

        The values are expected to be post-ReLU (non-negative), so their sign
        is a 0/1 pattern.
        """
        n_batch = activations.size()[0]
        n_neuron = activations.size()[1]
        self.n_neuron = n_neuron
        if self.activations is None:
            self.activations = torch.zeros(self.n_samples, n_neuron)
            if self.gpu is not None:
                self.activations = self.activations.cuda(self.gpu)
        self.activations[self.ptr:self.ptr+n_batch] = torch.sign(activations)  # after ReLU
        self.ptr += n_batch

    @torch.no_grad()
    def calc_LR(self):
        """Reduce the recorded patterns to a region count and free the buffer.

        The result is stored in ``n_LR``.
        """
        # Only rows that were actually written take part; the untouched tail of
        # the buffer is all zeros and would otherwise count as a pattern.
        patterns = self.activations[:self.ptr]
        # res[i, j] = number of neurons active in sample i but not in sample j.
        res = torch.matmul(patterns, (1 - patterns).T)
        # Symmetrise: res[i, j] is now the Hamming distance between i and j.
        res = res + res.T
        # 1 where two samples share the same pattern, 0 otherwise.
        res = 1 - torch.sign(res)
        # Per sample: how many samples (itself included) share its pattern.
        res = res.sum(1)
        # Each sample contributes 1/multiplicity, so every distinct pattern
        # adds up to exactly one.
        res = 1. / res.float()
        self.n_LR = res.sum().item()
        del patterns, res
        self.activations = None
        if self.gpu is not None:
            torch.cuda.empty_cache()

    @torch.no_grad()
    def update1D(self, activationList):
        """Record the on/off pattern of one sample in ``ActPattern``.

        Args:
            activationList: mapping whose values are 1-D tensors; every
                element contributes ``'1'`` if it is positive, else ``'0'``,
                in iteration order.
        """
        code_string = ''.join(
            '1' if value[i] > 0 else '0'
            for value in activationList.values()
            for i in range(value.size()[0])
        )
        if code_string not in self.ActPattern:
            self.ActPattern[code_string] = 1

    def getLinearReginCount(self):
        """Return the region count, computing it on first use."""
        if self.n_LR == -1:
            self.calc_LR()
        return self.n_LR

# Collects the number of linear activation regions of one or more models.
class Linear_Region_Collector:
    """Feeds random inputs through models and counts their linear regions.

    Forward hooks are attached to the ``RELU`` blocks of every model. Each
    call to ``forward`` gathers the hooked outputs of one batch and hands them
    to that model's ``LinearRegionCount``.

    The hooks only hold a weak reference to the collector and are detached
    when the collector is garbage collected (or when ``remove_hooks`` is
    called), so a discarded collector does not keep capturing feature maps.

    Args:
        models: list of models to analyse. Each model is expected to expose
            ``module_list`` whose items expose ``module_list`` themselves
            (project convention — see ``reinit``). ``None`` means no models.
        input_size: shape of the random input batch, BCHW.
        gpu: CUDA device index, or ``None`` for CPU.
        sample_batch: number of random batches drawn by
            ``forward_batch_sample``.
        dataset, data_path: stored but not used by the code in this class.
        seed: stored as ``self.seed``; see the note in ``reinit``.
    """
    def __init__(self, models=None, input_size=(64, 3, 32, 32), gpu=None,
                 sample_batch=1, dataset=None, data_path=None, seed=0):
        self.models = []
        self.input_size = input_size  # BCHW
        self.sample_batch = sample_batch
        self.interFeature = []
        self.dataset = dataset
        self.data_path = data_path
        self.seed = seed
        self.gpu = gpu
        self.device = torch.device('cuda:{}'.format(self.gpu)) if self.gpu is not None else torch.device('cpu')

        # Handles of every hook registered by this collector. The list object
        # is shared with the finalizer below, so it must only ever be mutated
        # in place, never rebound.
        self._hook_handles = []
        weakref.finalize(self, Linear_Region_Collector._detach, self._hook_handles)

        self.reinit([] if models is None else models, input_size, sample_batch, seed)

    @staticmethod
    def _detach(handles):
        """Remove every hook in ``handles`` and empty the list in place."""
        while handles:
            handles.pop().remove()

    def remove_hooks(self):
        """Detach all forward hooks this collector has registered."""
        self._detach(self._hook_handles)

    def reinit(self, models=None, input_size=None, sample_batch=None, seed=None):
        """Replace models and/or sampling parameters and reset captured features.

        Arguments left as ``None`` keep their current value.
        """
        # Apply the sizes first so that counters created below are dimensioned
        # with the new values rather than the previous ones.
        if input_size is not None:
            self.input_size = input_size  # BCHW
        if sample_batch is not None:
            self.sample_batch = sample_batch
        if models is not None:
            assert isinstance(models, list)
            # Drop hooks from the previous model set before hooking the new one.
            self.remove_hooks()
            self.models = models
            for model in self.models:
                # Flatten the two-level ``module_list`` structure so that the
                # individual blocks become direct children that
                # ``register_hook`` can inspect.
                blocks = []
                for mlist in model.module_list:
                    blocks.extend(list(mlist.module_list))
                self.register_hook(nn.Sequential(*blocks))
            self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch, gpu=self.gpu)
                             for _ in range(len(models))]
        # NOTE: __init__ stores ``seed`` before calling reinit, so this branch
        # is skipped during construction; it only fires when reinit is later
        # called with a different seed.
        if seed is not None and seed != self.seed:
            self.seed = seed
            torch.manual_seed(seed)
            if self.gpu is not None:
                torch.cuda.manual_seed(seed)
        self.interFeature = []
        if self.gpu is not None:
            torch.cuda.empty_cache()

    def clear(self):
        """Reset the per-model counters and the captured features."""
        self.LRCounts = [LinearRegionCount(self.input_size[0]*self.sample_batch, gpu=self.gpu)
                         for _ in range(len(self.models))]
        self.interFeature = []
        if self.gpu is not None:
            torch.cuda.empty_cache()

    def _weak_hook(self):
        """Return a forward hook that forwards to ``hook_in_forward`` weakly.

        Going through a weak reference keeps the hooked modules from holding
        the collector (and its captured tensors) alive.
        """
        collector_ref = weakref.ref(self)

        def _hook(module, input, output):
            collector = collector_ref()
            if collector is not None:
                collector.hook_in_forward(module, input, output)

        return _hook

    def register_hook(self, model):
        """Attach the capture hook to every direct ``RELU`` child of ``model``."""
        hook = self._weak_hook()
        for _name, module in model.named_children():
            if isinstance(module, RELU):
                self._hook_handles.append(module.register_forward_hook(hook=hook))

    def hook_in_forward(self, module, input, output):
        """Store the detached output of a hooked module fed with a 4-D input."""
        if isinstance(input, tuple) and len(input[0].size()) == 4:
            self.interFeature.append(output.detach())  # for ReLU

    def forward_batch_sample(self):
        """Run ``sample_batch`` random batches through every model.

        Returns:
            One linear-region count per model, in the order of ``self.models``.
        """
        for _ in range(self.sample_batch):
            inputs = torch.randn(self.input_size, device=self.device)

            for model, LRCount in zip(self.models, self.LRCounts):
                self.forward(model, LRCount, inputs)

        return [LRCount.getLinearReginCount() for LRCount in self.LRCounts]

    def _model_device(self, model):
        """Return the device of ``model``'s first parameter, else ``self.device``."""
        try:
            return next(model.parameters()).device
        except (AttributeError, StopIteration):
            return self.device

    def forward(self, model, LRCount, input_data):
        """Forward one batch and feed the captured activations to ``LRCount``.

        Nothing is recorded when no hook fired during the forward pass.
        """
        self.interFeature = []
        with torch.no_grad():
            # Run on the device the model lives on instead of an unconditional
            # ``.cuda()``, which fails on CPU and ignores the GPU index.
            model.forward(input_data.to(self._model_device(model)))
            if len(self.interFeature) == 0:
                return
            # One row per sample, all hooked feature maps concatenated.
            feature_data = torch.cat([f.view(input_data.size(0), -1) for f in self.interFeature], 1)
            LRCount.update2D(feature_data)


def compute_RN_score(model: nn.Module,  batch_size=None, image_size=None, num_batch=None, gpu=None):
    """Return the linear-region count of ``model`` as a float.

    A ``Linear_Region_Collector`` is built for the single model with random
    3-channel square inputs and asked for ``num_batch`` sampled batches.

    Args:
        model: network to score.
        batch_size: number of random samples per batch.
        image_size: height and width of the random inputs.
        num_batch: number of batches to sample.
        gpu: CUDA device index forwarded to the collector, or ``None``.
    """
    probe_shape = (batch_size, 3, image_size, image_size)
    collector = Linear_Region_Collector(models=[model], input_size=probe_shape,
                                        gpu=gpu, sample_batch=num_batch)
    region_counts = collector.forward_batch_sample()
    score = float(region_counts[0])
    del collector
    torch.cuda.empty_cache()
    return score



import numpy as np
import torch

# Recompute the running statistics of the BatchNorm layers.
def recal_bn(network, xloader, recalbn, device):
    """Re-estimate BatchNorm2d running statistics from the first batches of a loader.

    Every ``BatchNorm2d`` layer that tracks running statistics has them zeroed
    and its ``momentum`` set to ``None``. The network is then switched to
    training mode and fed ``recalbn`` batches without gradient tracking.

    Args:
        network: module to recalibrate; modified in place and left in
            training mode.
        xloader: iterable yielding ``(inputs, targets)`` pairs; ``targets``
            are ignored.
        recalbn: number of batches to use.
        device: device the inputs are moved to. ``None`` keeps the previous
            behaviour of moving them to the current CUDA device.

    Returns:
        The same ``network`` object.
    """
    for m in network.modules():
        # Layers built without running statistics have nothing to reset.
        if isinstance(m, torch.nn.BatchNorm2d) and m.running_mean is not None:
            m.running_mean.data.fill_(0)
            m.running_var.data.fill_(0)
            m.num_batches_tracked.data.zero_()
            m.momentum = None
    network.train()
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(xloader):
            if i >= recalbn:
                break
            if device is None:
                inputs = inputs.cuda(non_blocking=True)
            else:
                # ``.to`` accepts CPU as well as CUDA devices.
                inputs = inputs.to(device, non_blocking=True)
            # Only the side effect on the BN buffers matters; the output may
            # be a tensor or a tuple and is discarded.
            network(inputs)
    return network


def _ntk_gram_matrices(xloader, networks, train_mode, num_batch, gpu):
    """Build one empirical NTK Gram matrix per network from per-sample gradients.

    For every sample, the module-level ``cross_entropy`` between its logits
    and its one-hot label is back-propagated, and the gradients of all
    parameters whose name contains ``'weight'`` are flattened into one vector.
    The Gram matrix of a network is the matrix of inner products between
    these vectors.

    Args:
        xloader: iterable yielding ``(inputs, targets)``; ``targets`` are
            integer class indices.
        networks: list of networks. A network may return a tensor of logits
            or a tuple whose second item is the logits.
        train_mode: put the networks in training mode if true, else eval mode.
        num_batch: when positive, the number of batches to consume; ``None``
            or a non-positive value consumes the whole loader.
        gpu: CUDA device index, or ``None`` for CPU.

    Returns:
        ``(ntks, targets)``: a list with one ``(N, N)`` tensor per network and
        the ``N`` concatenated target indices.
    """
    device = torch.device('cuda:{}'.format(gpu)) if gpu is not None else torch.device('cpu')

    for network in networks:
        if train_mode:
            network.train()
        else:
            network.eval()

    grads = [[] for _ in range(len(networks))]
    seen_targets = []
    for i, (inputs, targets) in enumerate(xloader):
        if num_batch is not None and num_batch > 0 and i >= num_batch:
            break
        # ``.to`` works for CPU and CUDA devices alike.
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        seen_targets.append(targets)
        for net_idx, network in enumerate(networks):
            network.zero_grad()
            logit = network(inputs.clone())
            if isinstance(logit, tuple):
                logit = logit[1]  # 201 networks: return features and logits
            # Size the one-hot labels after the network's own output width
            # instead of assuming a fixed number of classes.
            targets_onehot = torch.nn.functional.one_hot(targets, num_classes=logit.size(1)).float()

            for _idx in range(len(inputs)):
                loss = cross_entropy(logit[_idx:_idx+1], targets_onehot[_idx:_idx+1])
                # The graph is shared by all samples of the batch.
                loss.backward(retain_graph=True)
                grad = [W.grad.view(-1).detach()
                        for name, W in network.named_parameters()
                        if 'weight' in name and W.grad is not None]
                # torch.cat copies, so the vector survives the reset below.
                grads[net_idx].append(torch.cat(grad, -1))
                network.zero_grad()
                if gpu is not None:
                    torch.cuda.empty_cache()

    stacked = [torch.stack(_grads, 0) for _grads in grads]
    ntks = [torch.einsum('nc,mc->nm', [_grads, _grads]) for _grads in stacked]
    return ntks, torch.cat(seen_targets, 0)


def get_ntk_n(xloader, networks, recalbn=0, train_mode=False, num_batch=None,
              batch_size=None, image_size=None, gpu=None):
    """Return the NTK condition number of every network.

    The condition number is the ratio of the largest to the smallest
    eigenvalue of the network's NTK Gram matrix, passed through
    ``np.nan_to_num``.

    Args:
        xloader, networks, train_mode, num_batch, gpu: see
            ``_ntk_gram_matrices``.
        recalbn, batch_size, image_size: accepted for interface compatibility;
            not used.

    Returns:
        A list with one condition number per network.
    """
    ntks, _ = _ntk_gram_matrices(xloader, networks, train_mode, num_batch, gpu)
    conds = []
    for ntk in ntks:
        eigenvalues = torch.linalg.eigvalsh(ntk, UPLO='U')  # ascending order
        conds.append(np.nan_to_num((eigenvalues[-1] / eigenvalues[0]).item(), copy=True))
    return conds


def get_ntk_lga(xloader, networks, recalbn=0, train_mode=False, num_batch=-1,
                batch_size=None, image_size=None, gpu=None):
    """Return the label-gradient alignment score of every network.

    The score is the sum of the element-wise product of the mean-centred NTK
    Gram matrix and a mean-centred label kernel (+1 for sample pairs of the
    same class, -1 otherwise), divided by the product of their norms.

    Args:
        xloader, networks, train_mode, num_batch, gpu: see
            ``_ntk_gram_matrices``.
        recalbn, batch_size, image_size: accepted for interface compatibility;
            not used.

    Returns:
        A list with one score per network.
    """
    ntks, targets = _ntk_gram_matrices(xloader, networks, train_mode, num_batch, gpu)

    # The label kernel does not depend on the network: build it once.
    same_class = targets.reshape(-1, 1) == targets.reshape(1, -1)
    label_kernel = same_class.float() * 2 - 1
    label_kernel = label_kernel - torch.mean(label_kernel)
    label_norm = torch.norm(label_kernel, 2)

    lgas = []
    for mat in ntks:
        mat_normalized = mat - torch.mean(mat)
        score = mat_normalized * label_kernel / (torch.norm(mat_normalized, 2) * label_norm)
        lgas.append(torch.sum(score).item())
    return lgas


def get_ntk_fnorm(xloader, networks, recalbn=0, train_mode=False, num_batch=-1,
                  batch_size=None, image_size=None, gpu=None):
    """Return the Frobenius norm of every network's NTK Gram matrix.

    Args:
        xloader, networks, train_mode, num_batch, gpu: see
            ``_ntk_gram_matrices``.
        recalbn, batch_size, image_size: accepted for interface compatibility;
            not used.

    Returns:
        A list with one norm per network.
    """
    ntks, _ = _ntk_gram_matrices(xloader, networks, train_mode, num_batch, gpu)
    return [torch.norm(ntk, p="fro").item() for ntk in ntks]


def compute_NTK_score(xloader, gpu, model, resolution, batch_size):
    """Return the negated NTK condition number of ``model``.

    ``get_ntk_n`` is run on ``model`` alone, in training mode, limited to a
    single batch of ``xloader``; the first entry of its result is negated.
    """
    condition_numbers = get_ntk_n(
        xloader, [model],
        recalbn=0, train_mode=True, num_batch=1,
        batch_size=batch_size, image_size=resolution, gpu=gpu,
    )
    return -condition_numbers[0]
    
def compute_NTK_LGA_score(xloader, gpu, model, resolution, batch_size):
    """Return the negated label-gradient alignment score of ``model``.

    ``get_ntk_lga`` is run on ``model`` alone, in training mode, limited to a
    single batch of ``xloader``; the first entry of its result is negated.
    """
    alignment_scores = get_ntk_lga(
        xloader, [model],
        recalbn=0, train_mode=True, num_batch=1,
        batch_size=batch_size, image_size=resolution, gpu=gpu,
    )
    return -alignment_scores[0]

def compute_KNAS_score(xloader, gpu, model, resolution, batch_size):
    """Return the negated Frobenius norm of ``model``'s NTK Gram matrix.

    ``get_ntk_fnorm`` is run on ``model`` alone, in training mode, limited to
    a single batch of ``xloader``; the first entry of its result is negated.
    """
    frobenius_norms = get_ntk_fnorm(
        xloader, [model],
        recalbn=0, train_mode=True, num_batch=1,
        batch_size=batch_size, image_size=resolution, gpu=gpu,
    )
    return -frobenius_norms[0]


def parse_cmd_options(argv):
    """Parse this module's command-line options, ignoring unknown arguments.

    Args:
        argv: list of argument strings.

    Returns:
        The namespace holding ``batch_size``, ``input_image_size``,
        ``repeat_times`` and ``gpu``.
    """
    option_specs = (
        ('--batch_size', dict(type=int, default=16,
                              help='number of instances in one mini-batch.')),
        ('--input_image_size', dict(type=int, default=None,
                                    help='resolution of input image, usually 32 for CIFAR and 224 for ImageNet.')),
        ('--repeat_times', dict(type=int, default=32)),
        ('--gpu', dict(type=int, default=None)),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    # Arguments meant for other modules are tolerated and dropped.
    known_options, _unknown = parser.parse_known_args(argv)
    return known_options

if __name__ == "__main__":
    # Stand-alone benchmark: score one model repeatedly and report the last
    # NTK score together with the average time per repetition.
    opt = global_utils.parse_cmd_options(sys.argv)
    args = parse_cmd_options(sys.argv)
    if args.input_image_size is None:
        raise SystemExit('--input_image_size is required')
    the_model = ModelLoader.get_model(opt, sys.argv)
    if args.gpu is not None:
        the_model = the_model.cuda(args.gpu)

    # compute_NTK_score needs a loader of (inputs, targets) batches. A single
    # synthetic batch is used here; label 0 is a valid class index whatever
    # the number of classes of the model.
    probe_inputs = torch.randn(args.batch_size, 3, args.input_image_size, args.input_image_size)
    probe_targets = torch.zeros(args.batch_size, dtype=torch.long)
    probe_loader = [(probe_inputs, probe_targets)]

    the_score = float('nan')  # reported as-is when no repetition is run
    start_timer = time.time()

    for repeat_count in range(args.repeat_times):
        ntk = compute_NTK_score(probe_loader, gpu=args.gpu, model=the_model,
                                resolution=args.input_image_size, batch_size=args.batch_size)
        # The region count is evaluated (and timed) but, as before, does not
        # contribute to the reported score.
        compute_RN_score(model=the_model, batch_size=args.batch_size, image_size=args.input_image_size,
                         num_batch=1, gpu=args.gpu)
        the_score = ntk
    time_cost = (time.time() - start_timer) / max(args.repeat_times, 1)

    print(f'ntk={the_score:.4g}, time cost={time_cost:.4g} second(s)')