
import math
import copy
import torch
import torch.nn as nn
from torch import Tensor

class PolyConvFrame(nn.Module):
    '''
    A framework for polynomial graph signal filters.

    The module owns ``depth + 1`` learnable scalar coefficients. On every
    forward pass each one is squashed to ``alpha * tanh(param)``. The squashed
    list is handed to ``conv_fn``, which builds the basis terms one order at a
    time. The module returns the sum of all ``depth + 1`` terms.

    Args:
        conv_fn: the basis function, e.g. ChebyshevConv. It is called as
            ``conv_fn(L, xs, adj, alphas)`` and must return the order-``L``
            term, given the list ``xs`` of previously computed terms.
        depth (int): the order of the polynomial.
        cached (bool): stored as ``self.cached``; not read anywhere in this
            class.
        alpha (float): scale applied to ``tanh`` of each learnable coefficient.
        beta (float): initialisation parameter. Coefficient ``i < depth``
            starts at ``beta * (1 - beta) ** i``; the last one starts at
            ``beta ** depth``.
        fixed (bool): accepted for API compatibility but currently ignored;
            all coefficients are trainable.
    '''
    def __init__(self,
                 conv_fn,
                 depth: int = 3,
                 cached: bool = True,
                 alpha: float = 1.0,
                 beta: float = 0.2,
                 fixed: bool = True,):
        super().__init__()
        self.depth = depth
        self.basealpha = alpha

        # Geometrically decaying start values for the first `depth` coefficients.
        # float() guards against an integer `beta` (e.g. 0 or 1). Without it the
        # value would become an integer tensor, which cannot require gradients.
        init_values = [float(beta * (1 - beta) ** i) for i in range(depth)]
        # NOTE(review): GPR-style schemes usually start the last coefficient at
        # (1 - beta) ** depth; this keeps the original beta ** depth — confirm intent.
        init_values.append(float(beta ** depth))
        # nn.Parameter already sets requires_grad=True.
        self.alphas = nn.ParameterList(
            [nn.Parameter(torch.tensor(value)) for value in init_values])

        self.cached = cached
        # The following attributes are kept for compatibility; this class never reads them.
        self.adj = None
        self.conv_fn = conv_fn
        self.H = []
        self.tx_lst = []

    def forward(self, x: Tensor, adj: Tensor):
        '''
        Args:
            x: node embeddings, presumably of shape
                (number of nodes, node feature dimension) — confirm with callers.
            adj: propagation matrix, passed unchanged to ``conv_fn``
                (ChebyshevConv uses it as ``adj @ x``).
        Returns:
            The sum of the ``depth + 1`` terms produced by ``conv_fn``.
        '''
        alphas = [self.basealpha * torch.tanh(param) for param in self.alphas]
        xs = [self.conv_fn(0, [x], adj, alphas)]

        for L in range(1, self.depth + 1):
            # Each term may depend on all earlier ones, so pass the growing list.
            xs.append(self.conv_fn(L, xs, adj, alphas))

        return sum(xs)

def ChebyshevConv(L, xs, adj, alphas):
    '''
    Return the order-``L`` term of a scaled Chebyshev-style recursion.

    Term 0 is ``xs[0]`` itself. For ``L >= 1`` the term is
    ``2 * a_{L-1} * (adj @ xs[-1])``. For ``L >= 2`` the term additionally has
    ``a_{L-1} * a_{L-2} * xs[-2]`` subtracted, where ``a = alphas``.

    Args:
        L: order of the term to build.
        xs: previously computed terms; the last two are used.
        adj: propagation matrix, applied via ``adj @``.
        alphas: per-order scale factors.
    '''
    if L == 0:
        return xs[0]

    scale = alphas[L - 1]
    term = (2 * scale) * (adj @ xs[-1])
    if L <= 1:
        return term

    # Three-term recurrence: remove the contribution two steps back.
    term -= (scale * alphas[L - 2]) * xs[-2]
    return term
