import math
import numpy as np
import scipy as sp
import random
import copy
import gc
from tqdm import tqdm

class GNTK(object):
    """
    Implement the Graph Neural Tangent Kernel (GNTK) for a single graph.

    Every public method takes the node features of one graph (one row per
    vertex) and its adjacency matrix, and produces an N x N kernel matrix
    over the vertices of that graph.
    """
    def __init__(self, num_layers, num_mlp_layers, jk, scale):
        """
        num_layers: number of layers in the neural networks (including the input layer)
        num_mlp_layers: number of MLP layers applied after each aggregation
        jk: a bool variable indicating whether to add jumping knowledge
        scale: the scale used to aggregate neighbors, one of ['uniform', 'degree']
        """
        self.num_layers = num_layers
        self.num_mlp_layers = num_mlp_layers
        self.jk = jk
        self.scale = scale
        assert scale in ['uniform', 'degree'], "scale must be 'uniform' or 'degree'"

    def _scale_mat(self, adj):
        """
        Return the factor multiplied into every aggregated covariance entry.

        'uniform' -> the scalar 1.
        'degree'  -> matrix with entry (i, j) = 1 / (row_sum_i * col_sum_j) of adj.

        An explicit outer product is used so the result is N x N for dense
        ndarrays, np.matrix and scipy sparse inputs alike; multiplying two
        plain 1-D sums would be element-wise and silently give a wrong shape.
        """
        if self.scale == 'uniform':
            return 1.
        row_deg = np.asarray(adj.sum(axis=1)).reshape(-1)
        col_deg = np.asarray(adj.sum(axis=0)).reshape(-1)
        return 1. / np.outer(row_deg, col_deg)

    def __next_diag(self, S):
        """
        Go through one normal layer, for the covariance of a graph with itself.

        S: covariance of the last layer (square)
        Returns (new covariance, derivative covariance DS, sqrt of diag(S)).
        """
        diag = np.sqrt(np.diag(S))
        S, DS = self.__next(S, diag, diag)
        return S, DS, diag

    def __adj_diag(self, S, adj_block, N, scale_mat):
        """
        Go through one adj layer for an N x N covariance.

        S: the covariance
        adj_block: the adjacency relation (Kronecker product of adj with itself)
        N: number of vertices
        scale_mat: scaling matrix (or scalar)
        """
        return self.__adj(S, adj_block, N, N, scale_mat)

    def __next(self, S, diag1, diag2):
        """
        Go through one normal layer, for all elements.

        S: covariance of the last layer
        diag1, diag2: row / column normalizers (square roots of the diagonal
            covariances) that turn S into a correlation before the update
        Returns (new covariance, derivative covariance DS).

        The update is the closed-form arc-cosine expression below
        (presumably a ReLU layer — confirm against the model definition).
        """
        S = S / diag1[:, None] / diag2[None, :]
        # Keep arccos/sqrt well-defined despite round-off just outside [-1, 1].
        S = np.clip(S, -1, 1)
        # dot sigma
        DS = (math.pi - np.arccos(S)) / math.pi
        S = (S * (math.pi - np.arccos(S)) + np.sqrt(1 - S * S)) / np.pi
        S = S * diag1[:, None] * diag2[None, :]
        return S, DS

    def __adj(self, S, adj_block, N1, N2, scale_mat):
        """
        Go through one adj layer, for all elements.

        With adj_block = kron(A, A) and row-major flattening this equals
        (A @ S @ A.T) * scale_mat, reshaped to N1 x N2.
        """
        return adj_block.dot(S.reshape(-1)).reshape(N1, N2) * scale_mat

    def dropres(self, adj, num_drop=3900):
        """
        Randomly drop edges from adj and build the aggregation operator.

        adj: adjacency matrix; a deep copy is perturbed, adj itself is untouched.
        num_drop: number of edge positions to sample for removal. Sampling uses
            random.choices (with replacement), so fewer than num_drop distinct
            edges may be removed. Each sampled (i, j) is removed together with
            (j, i).
        Returns (adj_block, scale_mat) for the perturbed adjacency, whose
        diagonal is set to 1.
        """
        N = adj.shape[0]
        adj_copy = copy.deepcopy(adj)

        # Exclude the diagonal from the candidate edges.
        adj_copy[range(N), range(N)] = 0
        ax, ay = adj_copy.nonzero()
        # random.choices raises on an empty population; nothing to drop then.
        if len(ax) > 0 and num_drop > 0:
            id_drop = random.choices(np.arange(len(ax)), k=num_drop)
            adj_copy[ax[id_drop], ay[id_drop]] = 0
            adj_copy[ay[id_drop], ax[id_drop]] = 0
        adj_copy[range(N), range(N)] = 1
        adj_block = sp.sparse.kron(adj_copy, adj_copy)
        return adj_block, self._scale_mat(adj_copy)

    def diag_new(self, node_features, adj, num_samples=100, num_drop=3900):
        """
        Compute the GNTK of a graph with itself under random edge dropping.

        Each of num_samples draws uses a freshly perturbed adjacency (from
        dropres) at every aggregation step. Per-draw kernels are summed, not
        averaged.

        node_features: feature matrix with one row per vertex
        adj: adjacency matrix
        num_samples: number of random draws to accumulate
        num_drop: forwarded to dropres
        Returns the summed jumping-knowledge kernel if self.jk, otherwise the
        summed final-layer kernel.
        """
        N = adj.shape[0]
        ntk_total = 0
        ntk_jk_total = 0

        for _ in tqdm(range(num_samples)):
            adj_block, scale_mat = self.dropres(adj, num_drop)
            sigma = np.matmul(node_features, node_features.T)
            # Non-in-place additions so integer features still promote to float.
            jump_ntk = sigma
            sigma = self.__adj_diag(sigma, adj_block, N, scale_mat)
            ntk = np.copy(sigma)

            for layer in range(1, self.num_layers):
                for _mlp_layer in range(self.num_mlp_layers):
                    sigma, dot_sigma, _ = self.__next_diag(sigma)
                    ntk = ntk * dot_sigma + sigma

                jump_ntk = jump_ntk + ntk

                # if not last layer: aggregate with a newly perturbed graph
                if layer != self.num_layers - 1:
                    adj_block, scale_mat = self.dropres(adj, num_drop)
                    sigma = self.__adj_diag(sigma, adj_block, N, scale_mat)
                    ntk = self.__adj_diag(ntk, adj_block, N, scale_mat)
            ntk_total = ntk_total + ntk
            ntk_jk_total = ntk_jk_total + jump_ntk

        if self.jk:
            return ntk_jk_total
        return ntk_total

    def diag(self, node_features, adj, *, return_diag_list=False):
        """
        Compute the final-layer GNTK of a graph with itself.

        node_features: feature matrix with one row per vertex
        adj: adjacency matrix
        return_diag_list: if True, also return the list of sqrt-diagonals of
            the covariance at every MLP layer, which is the diag_list argument
            expected by gntk.
        Returns ntk, or (ntk, diag_list) when return_diag_list is True.
        Note: self.jk is not applied here.
        """
        N = adj.shape[0]
        scale_mat = self._scale_mat(adj)
        adj_block = sp.sparse.kron(adj, adj)
        diag_list = []

        # input covariance
        sigma = np.matmul(node_features, node_features.T)
        sigma = self.__adj_diag(sigma, adj_block, N, scale_mat)
        ntk = np.copy(sigma)

        for layer in range(1, self.num_layers):
            for _mlp_layer in range(self.num_mlp_layers):
                sigma, dot_sigma, diag_vec = self.__next_diag(sigma)
                diag_list.append(diag_vec)
                ntk = ntk * dot_sigma + sigma
            # if not last layer
            if layer != self.num_layers - 1:
                sigma = self.__adj_diag(sigma, adj_block, N, scale_mat)
                ntk = self.__adj_diag(ntk, adj_block, N, scale_mat)

        if return_diag_list:
            return ntk, diag_list
        return ntk

    def gntk(self, node_features, diag_list, adj):
        """
        Compute the GNTK value Theta for a graph, given precomputed diagonals.

        node_features: feature matrix with one row per vertex
        diag_list: sqrt-diagonals of the covariance at every MLP layer, indexed
            (layer - 1) * num_mlp_layers + mlp_layer (as produced by
            diag(..., return_diag_list=True))
        adj: adjacency matrix
        Returns the jumping-knowledge kernel if self.jk, otherwise the
        final-layer kernel.
        """
        n1 = adj.shape[0]
        n2 = adj.shape[0]
        scale_mat = self._scale_mat(adj)
        adj_block = sp.sparse.kron(adj, adj)

        sigma = np.matmul(node_features, node_features.T)
        # Non-in-place additions so integer features still promote to float.
        jump_ntk = sigma
        sigma = self.__adj(sigma, adj_block, n1, n2, scale_mat)
        ntk = np.copy(sigma)

        for layer in range(1, self.num_layers):
            for mlp_layer in range(self.num_mlp_layers):
                idx = (layer - 1) * self.num_mlp_layers + mlp_layer
                sigma, dot_sigma = self.__next(sigma, diag_list[idx], diag_list[idx])
                ntk = ntk * dot_sigma + sigma
            jump_ntk = jump_ntk + ntk
            # if not last layer
            if layer != self.num_layers - 1:
                sigma = self.__adj(sigma, adj_block, n1, n2, scale_mat)
                ntk = self.__adj(ntk, adj_block, n1, n2, scale_mat)

        if self.jk:
            return jump_ntk
        return ntk
