import numpy as np
import torch


def _collect_grads_and_targets(loader, networks, num_batch, device, gpu, num_classes):
    """Gather per-sample weight gradients for every network plus batch-centred one-hot targets.

    For each sample, the gradient of the sum of that sample's logits is taken
    w.r.t. every parameter whose name contains 'weight' (and that received a
    gradient). Those gradients are flattened and concatenated into one vector.

    Returns:
        (grads, targets): ``grads[k]`` is a (num_samples, num_weight_params)
        tensor for ``networks[k]``. ``targets`` is a (num_samples, num_classes)
        float tensor of one-hot labels, each centred by its own batch mean.

    Raises:
        ValueError: if ``loader`` yields no batches within the ``num_batch`` limit.
    """
    grads = [[] for _ in networks]
    centred_targets = []
    for i, (inputs, targets) in enumerate(loader):
        # num_batch of None or <= 0 means "use every batch".
        if num_batch is not None and 0 < num_batch <= i:
            break
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        onehot = torch.nn.functional.one_hot(targets, num_classes=num_classes).float()
        # Centring is done per batch, not over the whole concatenated set.
        centred_targets.append(onehot - onehot.mean(0))

        for net_idx, network in enumerate(networks):
            network.zero_grad()
            logit = network(inputs.clone())
            if isinstance(logit, tuple):
                logit = logit[1]  # 201 networks: return features and logits

            for idx in range(len(inputs)):
                sample_logit = logit[idx:idx + 1]
                # retain_graph: the same forward graph is reused for every sample.
                sample_logit.backward(torch.ones_like(sample_logit), retain_graph=True)
                grad = [
                    W.grad.reshape(-1).detach()
                    for name, W in network.named_parameters()
                    if 'weight' in name and W.grad is not None
                ]
                grads[net_idx].append(torch.cat(grad, -1))
                network.zero_grad()
                if gpu is not None:
                    torch.cuda.empty_cache()

    if not centred_targets:
        raise ValueError("loader yielded no batches; cannot compute NTK")
    return [torch.stack(g, 0) for g in grads], torch.cat(centred_targets, 0)


def get_ntk_n(xloader, vloader, networks, recalbn=0, train_mode=False, num_batch=None, gpu=None, num_classes=100):
    """Compute NTK-based scores for each network.

    The empirical NTK of each network is built from per-sample gradients
    (w.r.t. parameters whose name contains 'weight') on ``xloader``. Two
    scores are derived from it.

    Args:
        xloader: iterable of (inputs, integer targets) batches used as the "train" set.
        vloader: iterable of (inputs, integer targets) batches used as the "val" set.
        networks: list of modules to score. Each returns logits, or a tuple
            whose element 1 is the logits.
        recalbn: unused; kept for interface compatibility.
        train_mode: call ``network.train()`` if True, else ``network.eval()``.
        num_batch: maximum number of batches to read from each loader.
            None or <= 0 reads all batches.
        gpu: when not None, ``torch.cuda.empty_cache()`` is called after each
            per-sample backward pass.
        num_classes: number of classes for the one-hot targets.

    Returns:
        (conds_x, prediction_mses), one entry per network:
        - ``conds_x[k]``: smallest / largest eigenvalue of network k's NTK on
          ``xloader``, passed through ``np.nan_to_num``.
        - ``prediction_mses[k]``: squared error of the kernel-regression
          prediction ``K_yx @ K_xx^{-1} @ Y_x``, summed over classes and
          averaged over ``vloader`` samples. The comparison is against the
          batch-centred one-hot ``vloader`` targets.

    Raises:
        ValueError: if either loader yields no batches.
    """
    # Inputs go to the current CUDA device when available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
    else:
        device = torch.device('cpu')

    for network in networks:
        if train_mode:
            network.train()
        else:
            network.eval()

    grads_x, targets_x = _collect_grads_and_targets(
        xloader, networks, num_batch, device, gpu, num_classes)
    ntks = [g @ g.T for g in grads_x]

    conds_x = []
    for ntk in ntks:
        # eigvalsh returns eigenvalues in ascending order.
        eigenvalues = torch.linalg.eigvalsh(ntk, UPLO='U')
        conds_x.append(np.nan_to_num((eigenvalues[0] / eigenvalues[-1]).item(), copy=True))

    grads_y, targets_y = _collect_grads_and_targets(
        vloader, networks, num_batch, device, gpu, num_classes)

    prediction_mses = []
    for ntk_xx, g_x, g_y in zip(ntks, grads_x, grads_y):
        ntk_yx = g_y @ g_x.T
        # solve(K_xx, Y_x) == K_xx^{-1} @ Y_x, without forming the explicit inverse.
        prediction = ntk_yx @ torch.linalg.solve(ntk_xx, targets_x)
        prediction_mses.append(((prediction - targets_y) ** 2).sum(1).mean(0).item())

    return conds_x, prediction_mses


def compute_NTK_score(xloader, vloader, model, gpu, resolution, batch_size, num_classes=100):
    """Score a single model with NTK-based metrics on the first batch of each loader.

    Calls ``get_ntk_n`` with ``model`` in train mode and ``num_batch=1``.

    Args:
        xloader: loader whose first batch builds the NTK.
        vloader: loader whose first batch is used for the prediction error.
        model: the network to score.
        gpu: forwarded to ``get_ntk_n``.
        resolution: unused; kept for interface compatibility.
        batch_size: unused; kept for interface compatibility.
        num_classes: number of classes for one-hot targets. Defaults to 100,
            the value previously hard-coded here.

    Returns:
        (cond, neg_mse):
        - ``cond``: the smallest / largest NTK eigenvalue ratio reported by
          ``get_ntk_n``.
        - ``neg_mse``: the negated kernel-regression prediction MSE, so that
          higher is better.
    """
    ntk_score, prediction_mses = get_ntk_n(
        xloader, vloader, [model], recalbn=0, train_mode=True,
        num_batch=1, gpu=gpu, num_classes=num_classes)
    return ntk_score[0], -1 * prediction_mses[0]