
import math
import wandb
import copy
import torch
import torch.nn as nn
from torch import Tensor
# from scipy.linalg import eigvals
# from torch.linalg import eigvals


class PolyConvFrame(nn.Module):
    '''
    A framework for polynomial graph signal filters.

    The module owns ``depth + 1`` learnable scalar coefficients. On every
    forward pass each coefficient is squashed with ``tanh`` and scaled by
    ``alpha``; the basis terms produced by ``conv_fn`` are then summed.

    Args:
        conv_fn: the basis recurrence, e.g. PowerConv, LegendreConv, ...
            It is called as ``conv_fn(L, xs, adj, alphas)`` where ``xs`` is
            the list of terms computed so far and ``alphas`` is the list of
            scaled coefficients.
        depth (int): the order of the polynomial.
        cached (bool): stored as ``self.cached``; this class itself never
            reads it.
        alpha (float): scale applied to ``tanh(coefficient)`` in ``forward``.
        beta (float): sets the initial coefficients: ``beta * (1 - beta)**i``
            for ``i < depth`` and ``beta**depth`` for the last one.
        fixed (bool): accepted for interface compatibility but currently
            ignored; the coefficients are always trainable.
    '''
    def __init__(self,
                 conv_fn,
                 depth: int = 3,
                 cached: bool = True,
                 alpha: float = 1.0,
                 beta: float = 0.2,
                 fixed: bool = True,):
        super().__init__()
        self.depth = depth
        self.basealpha = alpha

        init_values = [beta * (1 - beta) ** i for i in range(self.depth)]
        init_values.append(beta ** self.depth)
        # Cast to float: an integer ``beta`` would otherwise yield an integer
        # tensor, which cannot require gradients.
        self.alphas = nn.ParameterList(
            [nn.Parameter(torch.tensor(float(v))) for v in init_values]
        )

        self.cached = cached
        self.adj = None
        self.conv_fn = conv_fn
        self.H = []

    def forward(self, x: Tensor, adj: Tensor) -> Tensor:
        '''
        Apply the polynomial filter.

        Args:
            x: node embeddings, of shape (number of nodes, node feature dimension).
            adj: propagation matrix handed unchanged to ``conv_fn``.

        Returns:
            The sum of the ``depth + 1`` terms returned by ``conv_fn``.
        '''
        # tanh bounds every coefficient to (-basealpha, basealpha).
        alphas = [self.basealpha * torch.tanh(coef) for coef in self.alphas]

        xs = [self.conv_fn(0, [x], adj, alphas)]
        for L in range(1, self.depth + 1):
            # Each call sees all previously computed terms.
            xs.append(self.conv_fn(L, xs, adj, alphas))
        return sum(xs)


def PowerConv(L, xs, adj, alphas):
    '''
    Monomial bases.

    Order 0 returns the first entry of ``xs`` unchanged. Any other order
    propagates the most recent term through ``adj`` and scales it by
    ``alphas[L]``.
    '''
    if L != 0:
        propagated = adj @ xs[-1]
        return alphas[L] * propagated
    return xs[0]


def ChebyshevConv(L, xs, adj, alphas):
    '''
    Chebyshev bases, evaluated by a coefficient-weighted recurrence.

    Order 0 returns ``xs[0]``. Higher orders compute
    ``2 * alphas[L-1] * (adj @ xs[-1])`` and, from order 2 onwards, subtract
    ``alphas[L-1] * alphas[L-2] * xs[-2]``.
    '''
    if L == 0:
        return xs[0]

    lead_coef = 2 * alphas[L - 1]
    nx = lead_coef * (adj @ xs[-1])
    if L <= 1:
        # Order 1 has no second-previous term to subtract.
        return nx

    tail_coef = alphas[L - 1] * alphas[L - 2]
    nx -= tail_coef * xs[-2]
    return nx
    

# def ChebyshevConv(i,xs, x, alphas):
#     if i==0:
#         return torch.eye(x.shape[-1], device=x.device)
#     elif i==1:
#         return x
#     else:
#         T0=1
#         T1=x
#         for ii in range(2,i+1):
#             T2=2*x*T1-T0
#             T0,T1=T1,T2
#         return T2

def _laplacian(W):
    """Return the normalized graph Laplacian ``I - D^{-1/2} W D^{-1/2}``.

    Args:
        W: adjacency / weight tensor whose last two dimensions are square,
            i.e. shape ``(..., N, N)``. Leading dimensions are treated as
            batch dimensions.

    Returns:
        A tensor with the same trailing ``(N, N)`` layout as ``W``.
    """
    # Degree of every node = sum over its row of W.
    deg = W.sum(dim=-1)
    deg_inv_sqrt = deg.pow(-0.5)
    # Zero degrees give +inf and negative degrees give nan. Map all of them
    # to 0 so isolated nodes contribute nothing; the default nan_to_num
    # would turn +inf into the largest finite float instead.
    deg_inv_sqrt = torch.nan_to_num(deg_inv_sqrt, nan=0.0, posinf=0.0, neginf=0.0)
    deg_inv_sqrt = deg_inv_sqrt.diag_embed()
    A_norm = deg_inv_sqrt @ W @ deg_inv_sqrt

    # L = I - A_norm.
    identity = torch.eye(A_norm.shape[-1], device=A_norm.device)
    return identity - A_norm


def get_scaled_laplacian(adj):
    """Return the rescaled Laplacian ``2 * L / lambda_max - I`` of ``adj``.

    ``L`` is the Laplacian returned by ``_laplacian(adj)`` and ``lambda_max``
    is the largest real part among its eigenvalues. The rescaling maps real
    eigenvalues in ``[0, lambda_max]`` into ``[-1, 1]``.

    Args:
        adj: adjacency / weight tensor accepted by ``_laplacian``.

    Returns:
        A tensor with the same shape as the Laplacian of ``adj``.
    """
    L_tilde = _laplacian(adj)
    # Use torch's eigenvalue routine: no ``eigvals`` name is imported by this
    # module, and torch.linalg.eigvals takes its input positionally.
    # The max is taken over every element, so batched input shares a single
    # lambda_max.
    lambda_max = torch.max(torch.linalg.eigvals(L_tilde).real)

    identity = torch.eye(L_tilde.shape[-1], device=adj.device)
    scaled_laplacian = 2 * L_tilde / lambda_max - identity
    return scaled_laplacian

def JacobiConv(L, xs, adj, alphas, a=1.0, b=1.0, l=-1.0, r=1.0):
    '''
    Jacobi bases, evaluated by a coefficient-weighted three-term recurrence.

    Args:
        L: order of the term to compute.
        xs: previously computed terms; ``xs[-1]`` and ``xs[-2]`` are used.
        adj: propagation matrix.
        alphas: per-order coefficients.
        a, b: Jacobi parameters entering the recurrence coefficients.
        l, r: interval bounds used to rescale the recurrence.

    Returns:
        ``xs[0]`` for order 0, otherwise the newly computed term.
    '''
    if L == 0:
        return xs[0]

    prev = xs[-1]
    propagated = adj @ prev

    if L == 1:
        # First-order term: affine combination of x and adj @ x.
        shift = (a - b) / 2 - (a + b + 2) / 2 * (l + r) / (r - l)
        shift = shift * alphas[0]
        slope = (a + b + 2) / (r - l)
        slope = slope * alphas[0]
        return shift * prev + slope * propagated

    # Coefficients of the Jacobi three-term recurrence for order L >= 2.
    denom = 2 * L * (L + a + b) * (2 * L - 2 + a + b)
    num_prop = (2 * L + a + b - 1) * (2 * L + a + b) * (2 * L + a + b - 2)
    num_prev = (2 * L + a + b - 1) * (a**2 - b**2)
    num_prev2 = 2 * (L - 1 + a) * (L - 1 + b) * (2 * L + a + b)

    w_prop = alphas[L - 1] * (num_prop / denom)
    w_prev = alphas[L - 1] * (num_prev / denom)
    w_prev2 = alphas[L - 1] * alphas[L - 2] * (num_prev2 / denom)

    # Rescale for the interval [l, r].
    w_prop_scaled = w_prop * (2 / (r - l))
    w_prev_scaled = w_prop * ((r + l) / (r - l)) + w_prev

    nx = w_prop_scaled * propagated - w_prev_scaled * prev
    nx -= w_prev2 * xs[-2]
    return nx


def LegendreConv(L, xs, adj, alphas):
    '''
    Legendre bases, evaluated by a coefficient-weighted recurrence.

    Order 0 returns ``xs[0]``. Higher orders compute
    ``alphas[L-1] * (2 - 1/L) * (adj @ xs[-1])`` and, from order 2 onwards,
    subtract ``alphas[L-1] * alphas[L-2] * (1 - 1/L) * xs[-2]``.
    '''
    if L == 0:
        return xs[0]

    lead_coef = alphas[L - 1] * (2 - 1 / L)
    nx = lead_coef * (adj @ xs[-1])
    if L <= 1:
        # Order 1 has no second-previous term to subtract.
        return nx

    tail_coef = alphas[L - 1] * alphas[L - 2] * (1 - 1 / L)
    nx -= tail_coef * xs[-2]
    return nx

# def ChebyshevBaseConv(i,xs, x, alphas):
#     if i==0:
#         return torch.eye(x.shape[-1], device=x.device)
#     elif i==1:
#         return x
#     else:
#         T0=1
#         T1=x
#         for ii in range(2,i+1):
#             T2=2*x*T1-T0
#             T0,T1=T1,T2
#         return T2

# def ChebyshevBaseConv(L, xs, adj, alphas):
#     '''
#     Chebyshev Bases. Please refer to our paper for the form of the bases.
#     '''
#     if L == 0: return xs[0]
#     nx = (2 * alphas[L - 1]) * (adj @ xs[-1])
#     if L > 1:
#         nx -= (alphas[L - 1] * alphas[L - 2]) * xs[-2]



#     coe_tmp=torch.nn.Funtional.relu(alphas)
#     coe=coe_tmp.clone()
    
#     for i in range(self.K+1):
#         coe[i]=coe_tmp[0]*cheby(i,math.cos((self.K+0.5)*math.pi/(self.K+1)))
#         for j in range(1,self.K+1):
#             x_j=math.cos((self.K-j+0.5)*math.pi/(self.K+1))
#             coe[i]=coe[i]+coe_tmp[j]*cheby(i,x_j)
#         coe[i]=2*coe[i]/(self.K+1)


#     #L=I-D^(-0.5)AD^(-0.5)
#     # edge_index1, norm1 = get_laplacian(edge_index, edge_weight,normalization='sym', dtype=x.dtype, num_nodes=x.size(self.node_dim))

#     # L_tilde=L-I
#     # edge_index_tilde, norm_tilde= add_self_loops(edge_index1,norm1,fill_value=-1.0,num_nodes=x.size(self.node_dim))

#     Tx_0=x
#     # Tx_1=self.propagate(edge_index_tilde,x=x,norm=norm_tilde,size=None)

#     out=coe[0]/2*Tx_0+coe[1]*Tx_1

#     for i in range(2,self.K+1):
#         Tx_2=self.propagate(edge_index_tilde,x=Tx_1,norm=norm_tilde,size=None)
#         Tx_2=2*Tx_2-Tx_0
#         out=out+coe[i]*Tx_2
#         Tx_0,Tx_1 = Tx_1, Tx_2

#     return nx


