import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGEConv, SGConv, SimpleConv, Sequential
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_scatter import scatter_add
from torch_geometric.nn import MessagePassing
import math

class TransMeanAGG(MessagePassing):
    """Parameter-free propagation with symmetric degree normalisation.

    Self-loops are appended to ``edge_index``; each edge is then weighted by
    ``deg^-1/2`` of both of its endpoints (degrees are counted over the
    second row of the self-loop-augmented ``edge_index``) and the weighted
    neighbour features are summed.
    """

    def __init__(self):
        # Messages are already scaled per edge, so a plain sum is used.
        super().__init__(aggr='add')

    def forward(self, x, edge_index):
        """Propagate ``x`` along ``edge_index`` and return the aggregated features.

        Args:
            x: Node feature tensor; ``x.size(0)`` is used as the node count.
            edge_index: Two-row edge index tensor (unpacked as two rows below).
        """
        num_nodes = x.size(0)

        # Append one self-loop per node.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        first_row, second_row = edge_index

        # Per-node degree taken from the second row of the edge index.
        node_deg = degree(second_row, num_nodes, dtype=x.dtype)
        inv_sqrt_deg = node_deg.pow(-0.5)
        # A zero degree would yield inf; neutralise those entries.
        inv_sqrt_deg[inv_sqrt_deg == float('inf')] = 0

        edge_weight = inv_sqrt_deg[first_row] * inv_sqrt_deg[second_row]
        return self.propagate(edge_index, x=x, norm=edge_weight)

    def message(self, x_j, norm):
        """Scale each neighbour feature row by its per-edge weight."""
        per_edge = norm.view(-1, 1)
        return per_edge * x_j

class BaseAGG(nn.Module):
    """Common wrapper around a graph convolution with an optional extra linear term.

    Subclasses provide the convolution through :meth:`conv` and normally store
    the underlying operator in ``self.agg``.  Depending on ``version`` the
    forward pass adds a learnable projection of the *input* features:

    * ``"base"``   -- output is ``conv(x, edge_index)`` only; no extra parameter.
    * ``"simple"`` -- output is ``conv(x, edge_index) + x @ er_matrix``.
    * anything else -- output is ``conv(x, edge_index) + coeff * (x @ er_matrix)``.

    Args:
        version: Selects which of the variants above is used.
        in_features: Number of rows of ``er_matrix``.
        out_features: Number of columns of ``er_matrix``.
        is_batch: Stored on the instance for subclasses to consult.
        last: Accepted for signature parity with subclasses; unused here.
    """

    def __init__(self, version: str, in_features: int, out_features: int, is_batch: bool, last: bool = False):
        super().__init__()
        self.version = version
        self.is_batch = is_batch

        if self.version != "base":
            # Zero-initialised (same value reset_parameters() assigns) so the
            # layer is well-defined even if reset_parameters() is never
            # called; torch.empty would leave arbitrary memory contents here.
            self.er_matrix = nn.Parameter(torch.zeros(in_features, out_features))

    def reset_parameters(self):
        """Re-initialise the wrapped operator (if any) and zero ``er_matrix``."""
        # nn.Module raises AttributeError for missing attributes, so look the
        # operator up defensively: a subclass may not have assigned it.
        agg = getattr(self, 'agg', None)
        if agg is not None:
            if hasattr(agg, 'reset_parameters'):
                agg.reset_parameters()
            else:
                # Containers such as nn.Sequential expose no reset_parameters
                # of their own; reset each resettable child instead so that
                # e.g. an inner nn.Linear is not silently skipped.
                for child in agg.children():
                    if hasattr(child, 'reset_parameters'):
                        child.reset_parameters()
        if self.version != "base":
            torch.nn.init.zeros_(self.er_matrix)

    def conv(self, x, edge_index):
        """Apply the graph convolution; must be overridden by subclasses."""
        raise NotImplementedError("The 'conv' method must be implemented in subclasses.")

    def forward(self, x, edge_index, coeff):
        """Run the convolution and add the version-dependent extra term.

        Args:
            x: Input node features; multiplied with ``er_matrix`` when
                ``version`` is not ``"base"``.
            edge_index: Passed through unchanged to :meth:`conv`.
            coeff: Scale applied to ``x @ er_matrix`` for versions other than
                ``"base"`` and ``"simple"``; ignored otherwise.

        Returns:
            The convolution output, plus the extra term where applicable.
        """
        h = self.conv(x, edge_index)
        if self.version == "base":
            return h

        extra = torch.matmul(x, self.er_matrix)
        if self.version != "simple":
            extra = coeff * extra
        # Accumulated in place into the convolution output, as before.
        h += extra
        return h

class GCN_AGG(BaseAGG):
    """GCN-style aggregation layer built on :class:`BaseAGG`.

    When ``is_batch is False`` the operator is a ``GCNConv``.  Otherwise it is
    a two-stage pipeline: an ``nn.Linear`` projection followed by a
    ``SimpleConv`` configured with ``aggr='mean'`` and
    ``combine_root='self_loop'``.

    ``last`` is accepted for signature parity with the other layers and is
    not used here.
    """

    def __init__(self, version: str, in_features: int, out_features: int, is_batch: bool, last: bool = False):
        super().__init__(version, in_features, out_features, is_batch)
        if is_batch is not False:
            # Stored as a Sequential only as a container; the two stages are
            # invoked individually in conv() because the second one also
            # needs edge_index.
            self.agg = nn.Sequential(
                nn.Linear(in_features, out_features),
                SimpleConv(aggr='mean', combine_root='self_loop'),
            )
        else:
            self.agg = GCNConv(in_features, out_features)

    def conv(self, x, edge_index):
        """Apply the configured operator to ``x`` over ``edge_index``."""
        if self.is_batch is False:
            return self.agg(x, edge_index)

        project = self.agg[0]
        aggregate = self.agg[1]
        return aggregate(project(x), edge_index)


class SAGE_AGG(BaseAGG):
    """GraphSAGE aggregation layer: wraps a single ``SAGEConv``.

    ``is_batch`` is forwarded to :class:`BaseAGG`; ``last`` is accepted for
    signature parity and is not used here.
    """

    def __init__(self, version: str, in_features: int, out_features: int, is_batch: bool, last: bool = False):
        super().__init__(version, in_features, out_features, is_batch)
        sage_layer = SAGEConv(in_features, out_features)
        self.agg = sage_layer

    def conv(self, x, edge_index):
        """Run the wrapped ``SAGEConv`` on ``x`` and ``edge_index``."""
        out = self.agg(x, edge_index)
        return out
    
class GIN_AGG(BaseAGG):
    """GIN aggregation layer: a ``GINConv`` (``train_eps=True``) around a 2-layer MLP.

    The MLP maps ``in_features -> 2 * out_features -> out_features`` with a
    ReLU in between.  ``last`` is accepted for signature parity and is not
    used here.
    """

    def __init__(self, version: str, in_features: int, out_features: int, is_batch: bool, last: bool = False):
        super().__init__(version, in_features, out_features, is_batch)
        hidden_features = 2 * out_features
        update_net = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )
        self.agg = GINConv(update_net, train_eps=True)

    def conv(self, x, edge_index):
        """Run the wrapped ``GINConv`` on ``x`` and ``edge_index``."""
        out = self.agg(x, edge_index)
        return out
    
class SGC_AGG(BaseAGG):
    """SGC-style aggregation layer built on :class:`BaseAGG`.

    The propagation operator is ``TransMeanAGG`` when ``is_batch is False``
    and ``SimpleConv(aggr='mean', combine_root='self_loop')`` otherwise.
    A layer is treated as "first" when ``in_features != out_features``; in
    that case an ``nn.Linear`` projection is applied before propagation.
    Otherwise the layer only propagates.

    ``last`` is accepted for signature parity and is not used here.
    """

    def __init__(self, version: str, in_features: int, out_features: int, is_batch: bool, last: bool = False):
        super().__init__(version, in_features, out_features, is_batch)
        self.is_first = in_features != out_features

        # Pick the (parameter-free) propagation operator.
        if is_batch is False:
            propagation = TransMeanAGG()
        else:
            propagation = SimpleConv(aggr='mean', combine_root='self_loop')

        if self.is_first:
            # The Sequential is used purely as a container; conv() calls the
            # two stages separately because propagation needs edge_index.
            self.agg = nn.Sequential(
                nn.Linear(in_features, out_features),
                propagation,
            )
        else:
            self.agg = propagation

    def conv(self, x, edge_index):
        """Optionally project ``x``, then propagate it over ``edge_index``."""
        if self.is_first is not True:
            return self.agg(x, edge_index)

        projected = self.agg[0](x)
        return self.agg[1](projected, edge_index)

class GAT_AGG(BaseAGG):
    """GAT aggregation layer: wraps a ``GATConv`` with dropout 0.3.

    * ``last == True``: one attention head, ``concat=False``, output channel
      size ``out_features``.
    * otherwise: eight heads, each with ``out_features // 8`` channels.

    NOTE(review): ``out_features // 8`` floors, so the eight-head variant
    presumably expects ``out_features`` to be a multiple of 8 -- confirm
    against callers.
    """

    def __init__(self, version: str, in_features: int, out_features: int, is_batch: bool, last: bool = False):
        super().__init__(version, in_features, out_features, is_batch)
        # Equality (not truthiness) is kept deliberately so that non-bool
        # values of `last` select the same branch as before.
        if not (last == True):  # noqa: E712
            num_heads = 8
            self.agg = GATConv(in_features, out_features // num_heads, num_heads, dropout=0.3)
        else:
            self.agg = GATConv(in_features, out_features, 1, concat=False, dropout=0.3)

    def conv(self, x, edge_index):
        """Run the wrapped ``GATConv`` on ``x`` and ``edge_index``."""
        out = self.agg(x, edge_index)
        return out