import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import torch
from torch_geometric.nn import global_add_pool


from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import math
from torch_scatter import scatter_add
from torch_geometric.utils import laplacian
from torch_geometric.utils import to_dense_batch
from torch_scatter import scatter_mean, scatter_sum
import numpy as np

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import OptTensor
from torch_geometric.utils import get_laplacian
from typing import Optional


class JacobiConv(MessagePassing):
    r"""Polynomial spectral graph convolution with a learnable three-term recurrence.

    ``K`` basis terms ``P_0, ..., P_{K-1}`` are built from the node features by
    repeated message passing over a rescaled graph Laplacian (see
    :meth:`__norm__`). Each term goes through its own bias-free
    :class:`Linear` layer and the projected terms are summed. The recurrence
    coefficients depend on two learnable scalars, ``alpha`` and ``beta``.

    Args:
        in_channels: Number of input features per node.
        out_channels: Number of output features per node.
        K: Number of basis terms (and of linear layers); must be positive.
        normalization: Laplacian normalization forwarded to
            ``get_laplacian``; one of ``None``, ``'sym'`` or ``'rw'``.
        bias: If ``True``, a learnable bias (initialized to zero) is added to
            the output.
        **kwargs: Forwarded to :class:`MessagePassing`; ``aggr`` defaults to
            ``'add'``.

    Note:
        A single ``BatchNorm1d(out_channels)`` instance is shared by every
        normalization step in :meth:`forward`. One of those steps applies it
        directly to the input ``x``, so the layer only works when
        ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        K: int,
        normalization: Optional[str] = 'sym',
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        assert K > 0
        assert normalization in [None, 'sym', 'rw'], 'Invalid normalization'

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalization = normalization
        # One bias-free projection per basis term.
        self.lins = torch.nn.ModuleList([
            Linear(in_channels, out_channels, bias=False,
                   weight_initializer='glorot') for _ in range(K)
        ])
        # Shared by every normalization step in forward().
        self.bn = nn.BatchNorm1d(out_channels)
        # Learnable recurrence parameters, drawn uniformly from [0, 2).
        self.alpha = Parameter(torch.rand(1) * 2)
        self.beta = Parameter(torch.rand(1) * 2)

        if bias:
            self.bias = Parameter(Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the linear layers, the shared batch norm and the bias.

        ``alpha`` and ``beta`` keep their current values.
        """
        super().reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()
        # Also clear the running statistics / affine params of the shared BN,
        # otherwise they would survive a reset.
        self.bn.reset_parameters()
        zeros(self.bias)

    def __norm__(
        self,
        edge_index: Tensor,
        num_nodes: Optional[int],
        edge_weight: OptTensor,
        normalization: Optional[str],
        lambda_max: OptTensor = None,
        dtype: Optional[torch.dtype] = None,
        batch: OptTensor = None,
    ):
        """Return ``(edge_index, edge_weight)`` of the rescaled Laplacian used in propagation.

        The Laplacian is built by ``get_laplacian``. Every stored entry ``w``
        is then mapped to ``(2 * w / lambda_max - 1) / 2``, and diagonal
        (self-loop) entries are lowered by a further 1.

        Args:
            edge_index: Graph connectivity.
            num_nodes: Number of nodes.
            edge_weight: Optional input edge weights.
            normalization: ``None``, ``'sym'`` or ``'rw'``.
            lambda_max: Scale of the Laplacian. It defaults to twice the
                largest Laplacian entry, and a Python number is converted to a
                tensor. When it holds more than one value and ``batch`` is
                given, each edge uses the value indexed by
                ``batch[edge_index[0]]``.
            dtype: dtype requested from ``get_laplacian``.
            batch: Graph assignment per node.

        Returns:
            The Laplacian's ``edge_index`` and the rescaled weights.
        """
        edge_index, edge_weight = get_laplacian(edge_index, edge_weight,
                                                normalization, dtype,
                                                num_nodes)
        assert edge_weight is not None

        if lambda_max is None:
            lambda_max = 2.0 * edge_weight.max()
        elif not isinstance(lambda_max, Tensor):
            lambda_max = torch.tensor(lambda_max, dtype=dtype,
                                      device=edge_index.device)
        assert lambda_max is not None

        if batch is not None and lambda_max.numel() > 1:
            lambda_max = lambda_max[batch[edge_index[0]]]

        # New tensor, so the in-place updates below never touch the caller's
        # weights.
        edge_weight = (2.0 * edge_weight) / lambda_max
        edge_weight.masked_fill_(edge_weight == float('inf'), 0)

        # NOTE(review): this subtracts 1 from every stored entry, off-diagonal
        # ones included, and then a further 1 from the diagonal. A plain shift
        # by the identity would only touch the diagonal -- confirm intended.
        edge_weight -= 1
        edge_weight /= 2

        loop_mask = edge_index[0] == edge_index[1]
        edge_weight[loop_mask] -= 1

        return edge_index, edge_weight

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: OptTensor = None,
        batch: OptTensor = None,
        lambda_max: OptTensor = None,
    ) -> Tensor:
        """Apply the polynomial filter to the node features ``x``.

        Args:
            x: Node features; the node axis is ``self.node_dim``.
            edge_index: Graph connectivity.
            edge_weight: Optional per-edge weights fed into the Laplacian.
            batch: Optional graph assignment per node. It is only used when
                ``lambda_max`` holds more than one value.
            lambda_max: Optional Laplacian scale(s); see :meth:`__norm__`.

        Returns:
            Sum of the projected basis terms, with ``out_channels`` features
            (plus bias if enabled).
        """
        edge_index, norm = self.__norm__(
            edge_index,
            x.size(self.node_dim),
            edge_weight,
            self.normalization,
            lambda_max,
            dtype=x.dtype,
            batch=batch,
        )

        # Term 0: the batch-normalized input.
        Pn_2 = self.bn(x)
        out = self.bn(self.lins[0](Pn_2))

        if len(self.lins) > 1:
            # Term 1: (alpha + 1) * x + (alpha + beta + 2) * (one propagation
            # step of x), each part batch-normalized.
            coef1 = self.alpha + 1
            coef2 = self.alpha + self.beta + 2
            Pn_1 = coef1 * self.bn(x) + coef2 * self.bn(
                self.propagate(edge_index, x=x, norm=norm))
            Pn_1 = self.bn(Pn_1)
            out = self.bn(out + self.lins[1](Pn_1))

        # Terms 2..K-1: three-term recurrence on the two previous terms.
        # (Pn_1 is always defined here: this loop only runs when K > 2.)
        for i, lin in enumerate(self.lins[2:], start=2):
            a = i + self.alpha
            b = i + self.beta
            c = a + b

            coef1 = ((c - 1) * c * (c - 2)) / (i * (c - i) * (c - 2))
            coef2 = ((c - 1) * c * (c - 2) + (c - 1) * (a - b) * (c - 2 * i)) / (2 * i * (c - i) * (c - 2))
            coef3 = (2 * (a - 1) * (b - 1)) / (2 * i * (c - i) * (c - 2))

            Pn = (coef1 * self.bn(self.propagate(edge_index, x=Pn_1, norm=norm))
                  + coef2 * self.bn(Pn_1)
                  - coef3 * self.bn(Pn_2))
            Pn = self.bn(Pn)

            # Call the module (not .forward) so registered hooks run.
            out = self.bn(out + lin(Pn))

            Pn_2, Pn_1 = Pn_1, Pn

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:
        """Scale each neighbour feature row ``x_j`` by its edge weight ``norm``."""
        return norm.view(-1, 1) * x_j


class GADGNN(nn.Module):
    """Graph-level model combining parallel JacobiConv filters with edge-type weights.

    The pipeline in :meth:`forward` runs as follows:

    * Two Linear + LeakyReLU layers are applied to the node features.
    * ``width`` parallel :class:`JacobiConv` layers (``depth`` terms each)
      run on the result, and their outputs are concatenated.
    * Two more Linear + LeakyReLU layers follow.
    * Node features are sum-pooled per graph.
    * Optional batch norm and dropout are applied.
    * A final linear layer produces ``nclass`` outputs.

    Args:
        featuredim: Number of input node features (also the feature width of
            every convolution).
        hdim: Hidden size after the convolution stage.
        nclass: Number of outputs per graph.
        netype: Number of edge types; one learnable weight per type.
        width: Number of parallel JacobiConv layers.
        depth: Number of polynomial terms ``K`` per JacobiConv.
        dropout: Dropout probability applied before the final layer.
        normalize: If truthy, apply ``BatchNorm1d`` to the pooled features.
        device: Device the convolution layers are moved to at construction;
            also stored as ``self.device``.
    """

    def __init__(self, featuredim, hdim, nclass, netype, width, depth, dropout, normalize, device):
        super().__init__()

        # One learnable logit per edge type; forward() maps it through a
        # sigmoid to get the weight of every edge of that type.
        self.eweight = Parameter(torch.rand(netype) * 2)

        # nn.ModuleList rather than a plain Python list, so the convolutions
        # are registered as submodules. Their parameters then appear in
        # self.parameters() and get optimized, they follow
        # .to()/.train()/.eval(), and they are saved in state_dict().
        self.conv = nn.ModuleList(
            [JacobiConv(featuredim, featuredim, depth).to(device) for _ in range(width)]
        )

        self.linear = nn.Linear(featuredim, featuredim)
        self.linear2 = nn.Linear(featuredim, featuredim)
        self.linear3 = nn.Linear(featuredim * len(self.conv), hdim)
        self.linear4 = nn.Linear(hdim, hdim)
        self.act = nn.LeakyReLU()

        self.device = device

        self.linear7 = nn.Linear(hdim, nclass)
        self.bn = torch.nn.BatchNorm1d(hdim)
        self.dp = nn.Dropout(p=dropout)
        self.normalize = normalize
        self.pool = global_add_pool
        self.sigmoid = nn.Sigmoid()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the dense layers.

        The convolutions, ``eweight`` and the batch norm are not reset here.
        """
        self.linear.reset_parameters()
        self.linear2.reset_parameters()
        self.linear3.reset_parameters()
        self.linear4.reset_parameters()
        self.linear7.reset_parameters()

    def forward(self, data):
        """Compute ``nclass`` outputs per graph.

        Args:
            data: Object providing ``x`` (node features), ``edge_index``,
                ``edge_type`` (index into the per-type edge weights) and
                ``batch`` (graph id per node, used for pooling).

        Returns:
            Output of the final linear layer, one row per pooled graph. No
            activation is applied.
        """
        h = self.act(self.linear(data.x.float()))
        h = self.act(self.linear2(h))

        # Per-edge weight in (0, 1), looked up from the edge's type.
        weights = self.sigmoid(self.eweight[data.edge_type])

        # Every convolution sees the same input; outputs are stacked along the
        # feature axis.
        if len(self.conv) > 0:
            h_final = torch.cat(
                [conv(h, data.edge_index, weights) for conv in self.conv], dim=-1)
        else:
            h_final = h.new_zeros((h.size(0), 0))

        h = self.act(self.linear3(h_final))
        h = self.act(self.linear4(h))

        h = self.pool(h, data.batch)

        if self.normalize:
            h = self.bn(h)

        h = self.dp(h)

        return self.linear7(h)

