# Copied and modified from https://github.com/tk-rusch/gradientgating
# License: MIT
# Original author: Konstantin Rusch
# Description: This class implements G2 as described in [Gradient Gating for Deep Multi-Rate Learning on Graphs, 2023]
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATv2Conv, SAGEConv

from src.models.model_utils import ACTIVATION_MAPPING


class G2(nn.Module):
    """Gradient-gating (G2) rate module.

    Runs a graph convolution on the node features and turns the differences
    between the features at the two endpoints of every edge into a gating
    value per node: ``tanh(mean over edges of |h_src - h_dst| ** p)``.

    Args:
        conv: Callable invoked as ``conv(X, edge_index)``.
        p: Exponent applied to the absolute feature differences.
        conv_type: Name of the convolution type; only the value ``'GAT'``
            changes behaviour (see ``forward``).
        activation: Applied to the convolution output for every conv type
            other than ``'GAT'``. The default instance is created once, when
            the class is defined.
        msg_passing_method: When not ``None``, used in place of ``conv_type``.
    """

    def __init__(self, conv, p=2., conv_type='GCN', activation=nn.ReLU(), msg_passing_method: str = None):
        super().__init__()
        # Function-scope import: torch_scatter is only needed once a G2
        # instance is actually constructed.
        from torch_scatter import scatter

        self.conv = conv
        self.p = p
        self.activation = activation
        if msg_passing_method is not None:
            self.conv_type = msg_passing_method
        else:
            self.conv_type = conv_type
        self.scatter = scatter

    def forward(self, X, edge_index):
        """Return the gating values computed from ``X`` and ``edge_index``.

        ``edge_index[0]`` and ``edge_index[1]`` are used as the two endpoint
        index vectors of the edges; results are aggregated onto
        ``edge_index[0]``.
        """
        first, second = edge_index[0], edge_index[1]
        hidden = self.conv(X, edge_index)

        if self.conv_type == 'GAT':
            # Group the feature axis into chunks of 4 and average each chunk.
            # Presumably these are 4 concatenated attention heads - TODO confirm.
            hidden = F.elu(hidden).view(X.size(0), -1, 4).mean(dim=-1)
        else:
            hidden = self.activation(hidden)

        # Element-wise |difference| ** p for every edge.
        # NOTE(review): squeeze(-1) drops the feature axis when it has size 1,
        # so the result then has one dimension less - confirm callers expect it.
        edge_energy = ((hidden[first] - hidden[second]).abs() ** self.p).squeeze(-1)

        # Mean of the edge values, grouped by the first endpoint of each edge.
        node_energy = self.scatter(edge_energy, first, 0,
                                   dim_size=hidden.size(0), reduce='mean')
        return torch.tanh(node_energy)

class G2_GNN(nn.Module):
    """Graph neural network with gradient-gated (G2) message passing.

    Node features are encoded by a lazily-shaped linear layer, updated
    ``n_message_passings`` times with the same convolution using the gated
    rule ``X = (1 - tau) * X + tau * conv(X)``, and decoded by a linear layer.

    Args:
        hidden_dim: Width of the encoder output and of the conv layers.
        output_dim: Width of the decoder output.
        n_message_passings: Number of gated update steps (the conv layer is
            shared across steps).
        final_activation: Key into ``ACTIVATION_MAPPING``; the mapped callable
            is applied to the decoder output.
        conv_type: Which convolution to build. Accepted spellings:
            ``'GCN'``/``'gcn'``, ``'GATv2'``/``'gat_v2'``/``'gatv2'``,
            ``'SAGE'``/``'sage'``.
        msg_passing_method: When not ``None``, replaces ``conv_type`` as the
            value stored in ``self.conv_type`` (consulted in ``forward``).
            It does not influence which convolution layer is built.
        p: Exponent forwarded to the ``G2`` module.
        drop_in: Dropout probability applied to the raw input features.
        dropout: Dropout probability applied before the decoder.
        use_gg_conv: If truthy, the ``G2`` module gets its own convolution
            (``self.conv_gg``); otherwise it shares ``self.conv``.
        pooling: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If ``conv_type`` is not one of the accepted spellings.
    """

    # Accepted spellings of ``conv_type`` mapped to a canonical key.
    _CONV_ALIASES = {
        'GCN': 'gcn',
        'gcn': 'gcn',
        'GATv2': 'gatv2',
        'gat_v2': 'gatv2',
        'gatv2': 'gatv2',
        'SAGE': 'sage',
        'sage': 'sage',
    }

    def __init__(self,
                 hidden_dim,
                 output_dim,
                 n_message_passings,
                 final_activation,
                 conv_type='GCN',
                 msg_passing_method=None,
                 p=2.,
                 drop_in=0,
                 dropout=0,
                 use_gg_conv=True,
                 pooling: str = None):
        super(G2_GNN, self).__init__()

        # Fail fast with a clear message instead of printing a warning and
        # crashing later on a missing ``self.conv`` attribute.
        conv_kind = (self._CONV_ALIASES.get(conv_type)
                     if isinstance(conv_type, str) else None)
        if conv_kind is None:
            raise ValueError(
                f"Unsupported conv_type {conv_type!r}; expected one of "
                f"{sorted(self._CONV_ALIASES)}"
            )

        # NOTE(review): msg_passing_method only overrides the stored name used
        # by forward(); the layer built below depends on conv_type alone.
        # Confirm whether it should also select the layer.
        self.conv_type = conv_type if msg_passing_method is None else msg_passing_method
        self.enc = nn.LazyLinear(hidden_dim)
        self.dec = nn.Linear(hidden_dim, output_dim)
        self.drop_in = drop_in
        self.drop = dropout
        self.nlayers = n_message_passings
        self.activation = ACTIVATION_MAPPING[final_activation]

        # Build the main conv first and the gating conv second so that the
        # parameter-initialisation order stays the same as before.
        self.conv = self._make_conv(conv_kind, hidden_dim)
        if use_gg_conv:
            self.conv_gg = self._make_conv(conv_kind, hidden_dim)
            gate_conv = self.conv_gg
        else:
            gate_conv = self.conv
        self.G2 = G2(gate_conv, p, conv_type, activation=nn.ReLU())

        self.final_activation = ACTIVATION_MAPPING[final_activation]

    @staticmethod
    def _make_conv(conv_kind, hidden_dim):
        """Build one hidden_dim -> hidden_dim conv layer for a canonical kind."""
        if conv_kind == 'gcn':
            return GCNConv(hidden_dim, hidden_dim)
        if conv_kind == 'gatv2':
            return GATv2Conv(hidden_dim, hidden_dim, add_self_loops=True)
        return SAGEConv(hidden_dim, hidden_dim)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Run the gated message passing and decode the node states.

        Args:
            x: Node feature tensor; its first dimension is taken as the
                number of nodes.
            edge_index: Edge indices passed to the conv layers and to ``G2``.
            edge_attr: Accepted but unused.
            batch: Accepted but unused.

        Returns:
            In training mode, the activated decoder output. In evaluation
            mode, a tuple ``(output, x_L)`` where ``x_L`` is the node
            representation after the last gated update (before the final
            dropout and decoder).
        """
        n_nodes = x.size(0)
        X = F.dropout(x, self.drop_in, training=self.training)
        X = torch.relu(self.enc(X))

        for _ in range(self.nlayers):
            if self.conv_type == 'GAT':
                # Only reachable when msg_passing_method == 'GAT'; no accepted
                # conv_type spelling equals 'GAT'. The reshape groups features
                # in chunks of 4 (presumably 4 concatenated heads), while the
                # GATv2Conv built in __init__ sets no ``heads`` argument.
                # NOTE(review): confirm this branch is still wanted.
                X_ = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
            else:
                X_ = torch.relu(self.conv(X, edge_index))
            # tau gates, per entry, how much of the new state replaces the old.
            tau = self.G2(X, edge_index)
            X = (1 - tau) * X + tau * X_

        # Save the final node representations; they are returned in eval mode.
        x_L = X

        X = F.dropout(X, self.drop, training=self.training)
        # NOTE(review): ReLU is applied to the decoder output before
        # final_activation, so the latter only ever sees non-negative inputs.
        # Confirm this is intended.
        X = torch.relu(self.dec(X))
        X = self.final_activation(X)

        return X if self.training else (X, x_L)

