import torch
import torch.nn as nn
from functorch import make_functional, vmap, jacrev
import numpy as np
import kmedoids


class NtkFcun(nn.Module):
    """Empirical neural tangent kernel (NTK) helpers for a model.

    The model is converted to functional form with ``make_functional`` so
    that Jacobians of its output with respect to its parameters can be
    taken per sample with ``vmap(jacrev(...))``.
    """

    def __init__(self, model):
        """Capture the functional form and parameters of ``model``."""
        super().__init__()
        fnet, params = make_functional(model)

        def fnet_single(param, x):
            # The functional model is called with a batch of one; the batch
            # axis is added for the call and removed from the output.
            return fnet(param, x.unsqueeze(0)).squeeze(0)

        self.fnet_single = fnet_single
        self.params = params

    def _param_jacobians(self, x):
        """Return per-sample parameter Jacobians, flattened from dim 2 on.

        One tensor is returned per entry of ``self.params``; the first axis
        is the sample axis of ``x``.
        """
        jac = vmap(jacrev(self.fnet_single), (None, 0))(self.params, x)
        return [j.flatten(2) for j in jac]

    def empirical_ntk(self, x, xt):
        """Return the empirical NTK between the rows of ``x`` and ``xt``.

        The result has shape ``(x.shape[0], xt.shape[0])``. For every
        parameter tensor the Jacobians are contracted over both the output
        axis and the flattened parameter axis, and the per-parameter
        contributions are summed.
        """
        jac1 = self._param_jacobians(x)    # J(x1)
        jac2 = self._param_jacobians(xt)   # J(x2)

        # J(x1) @ J(x2).T, also summed over the output axis 'a'.
        einsum_expr = 'Naf,Maf->NM'
        result = torch.stack(
            [torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)]
        )
        return result.sum(0)

    def single(self, x, xt):
        """Differentiate the kernel of one sample pair w.r.t. its inputs.

        ``x`` and ``xt`` are expected to hold a single sample each (shape
        ``(1, d)``), so that the kernel has exactly one element and can be
        passed to ``torch.autograd.grad`` without ``grad_outputs``.

        Note: ``requires_grad_()`` is called on both arguments in place.

        Returns:
            ntk1: gradient of the kernel w.r.t. ``x`` (graph retained so it
                can be differentiated again).
            ntk2: gradient of ``ntk1[0, 0]`` w.r.t. ``xt``.
        """
        x.requires_grad_()
        xt.requires_grad_()
        ntk = self.empirical_ntk(x, xt)
        # create_graph=True keeps ntk1 differentiable for the second pass.
        ntk1 = torch.autograd.grad(ntk, x, retain_graph=True, create_graph=True)[0]
        ntk2 = torch.autograd.grad(ntk1[0, 0], xt)[0]
        return ntk1, ntk2

    def ntkplus(self, x, xt):
        """Return input-derivative kernels for every pair of rows.

        For ``x`` of shape ``(n, d)`` and ``xt`` of shape ``(m, d)`` two
        ``(n, m)`` tensors are returned, created with the dtype and device
        of ``x``:

        * ``ntk1[i, j]`` - derivative of the kernel of ``(x[i], xt[j])``
          w.r.t. the first coordinate of ``x[i]``;
        * ``ntk2[i, j]`` - derivative of that quantity w.r.t. the first
          coordinate of ``xt[j]``.

        Every pair is handled by a separate call to :meth:`single`.
        """
        n, d = x.shape
        m = xt.shape[0]
        # new_zeros follows x's dtype/device instead of the global defaults,
        # so float64 results are not downcast when they are stored.
        ntk1 = x.new_zeros((n, m))
        ntk2 = x.new_zeros((n, m))
        for i in range(n):
            for j in range(m):
                ntk1_temp, ntk2_temp = self.single(
                    x[i, :].view(1, d), xt[j, :].view(1, d)
                )
                ntk1[i, j] = ntk1_temp[0, 0]
                ntk2[i, j] = ntk2_temp[0, 0]

        return ntk1, ntk2

    def _ntkplus_diag(self, x, xt):
        """Return only the diagonals of the two :meth:`ntkplus` kernels.

        Pairs ``(x[i], xt[i])`` are evaluated for ``i < min(n, m)``, which
        costs one :meth:`single` call per entry instead of one per pair.
        The values are detached from the autograd graph.
        """
        d = x.shape[1]
        count = min(x.shape[0], xt.shape[0])
        diag1 = x.new_zeros(count)
        diag2 = x.new_zeros(count)
        for i in range(count):
            ntk1_temp, ntk2_temp = self.single(
                x[i, :].view(1, d), xt[i, :].view(1, d)
            )
            diag1[i] = ntk1_temp[0, 0].detach()
            diag2[i] = ntk2_temp[0, 0].detach()
        return diag1, diag2

    def get_weight(self, x, xt):
        """Return trace-balancing weights as a NumPy array of length two.

        With ``Tr_regular`` the trace of the empirical NTK and
        ``Tr_gradient`` the trace of the second kernel of :meth:`ntkplus`,
        the result is
        ``(Tr_regular + Tr_gradient) * [1 / Tr_regular, 1 / Tr_gradient]``.

        Raises:
            ZeroDivisionError: if either trace is exactly zero.
        """
        ntk = self.empirical_ntk(x, xt)
        # Only the diagonal of the derivative kernel is needed for a trace.
        _, ntk2_diag = self._ntkplus_diag(x, xt)
        Tr_regular = torch.sum(torch.diag(ntk)).item()
        Tr_gradient = torch.sum(ntk2_diag).item()
        weight = (Tr_regular + Tr_gradient) * np.asarray(
            [1 / Tr_regular, 1 / Tr_gradient]
        )
        return weight

    def get_angle(self, x, xt):
        """Return ``diag(ntk1) / sqrt(diag(ntk2)) / sqrt(diag(ntk))``.

        ``ntk`` is the empirical NTK and ``ntk1``/``ntk2`` are the kernels
        of :meth:`ntkplus`; only their diagonals are computed. The result is
        a NumPy array.
        """
        ntk = self.empirical_ntk(x, xt)
        ntk1_diag, ntk2_diag = self._ntkplus_diag(x, xt)
        ntk_diag = np.diag(ntk.detach().cpu().numpy())
        ntk1_diag = ntk1_diag.cpu().numpy()
        ntk2_diag = ntk2_diag.cpu().numpy()
        angle = ntk1_diag / np.sqrt(ntk2_diag) / np.sqrt(ntk_diag)
        return angle

    def get_batch_weight(self, x, xt):
        """Return the empirical-NTK diagonal divided by its sum (a tensor)."""
        ntk = self.empirical_ntk(x, xt)
        diag = torch.diag(ntk)
        weight = diag / torch.sum(diag)
        return weight
    