import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATv2Conv, GINConv, SAGEConv
from torch_geometric.nn import PairNorm

from src.models.gnn_interface import InterfaceGNN
from src.models.model_utils import POOLING_MAPPING
from src.models.model_utils import ACTIVATION_MAPPING


# Define the GCN model
class GCN(InterfaceGNN):
    """Stack of ``GCNConv`` layers with optional pooling and a linear readout.

    All message-passing layers share the same output width ``hidden_dim``;
    their input width is passed as ``-1`` so PyG infers it lazily.

    Args:
        hidden_dim: Output width shared by every message-passing layer.
        n_message_passings: Number of ``GCNConv`` layers; must be >= 1.
        output_dim: Output width of the linear readout.
        final_activation: Key looked up in ``ACTIVATION_MAPPING``.
        pooling: Key looked up in ``POOLING_MAPPING``. ``forward`` skips
            pooling when the mapped value is ``None``.
        dropout: Dropout probability applied after each message-passing layer.

    Raises:
        ValueError: If ``n_message_passings`` is smaller than 1.
    """

    def __init__(self,
                 hidden_dim: int,
                 n_message_passings: int,
                 output_dim: int,
                 final_activation,
                 pooling: str = None,
                 dropout: float = 0.0):
        # Fail early with a clear message; without a layer the readout width
        # would be undefined.
        if n_message_passings < 1:
            raise ValueError(
                f"n_message_passings must be >= 1, got {n_message_passings}")

        super().__init__()
        self.conv_layers = nn.ModuleList(
            [GCNConv(-1, hidden_dim) for _ in range(n_message_passings)])
        self.readout = nn.Linear(hidden_dim, output_dim)
        self.dropout_rate = dropout
        self.pooling = POOLING_MAPPING[pooling]
        self.activation = ACTIVATION_MAPPING[final_activation]

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Run message passing, optional pooling and the readout.

        Args:
            x: Node features.
            edge_index: Graph connectivity handed to every ``GCNConv``.
            edge_attr: Accepted but not used by this model.
            batch: Forwarded to the pooling callable when pooling is enabled.

        Returns:
            In training mode the activated readout. In evaluation mode a tuple
            ``(activated readout, node representations after the last
            message-passing layer)``.
        """
        for conv in self.conv_layers:
            x = torch.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Save the final node representations before pooling collapses them;
        # they are only returned in evaluation mode.
        node_repr = x

        if self.pooling is not None:
            x = self.pooling(x=x, batch=batch)

        # NOTE(review): the other models in this file call the activation with
        # dim=1; this one does not. Confirm whether that is intentional.
        x = self.activation(self.readout(x))

        if self.training:
            return x
        return x, node_repr

# Define the GIN model
class GIN(InterfaceGNN):
    """``GINConv`` stack with an MLP readout head.

    Each convolution wraps an update network of ``n_update_layer`` repetitions
    of ``LazyLinear -> ReLU``. Every convolution except the last additionally
    ends with a ``BatchNorm1d``. The readout consists of ``n_readout`` blocks
    of ``LazyLinear -> ReLU`` followed by a final ``Linear`` to ``output_dim``.

    Args:
        hidden_dim: Width used by all update networks and readout blocks.
        n_message_passings: Number of ``GINConv`` layers.
        output_dim: Output width of the final linear layer.
        final_activation: Key looked up in ``ACTIVATION_MAPPING``.
        n_update_layer: Number of ``LazyLinear -> ReLU`` pairs per convolution.
        pooling: Key looked up in ``POOLING_MAPPING``.
        n_readout: Number of hidden readout blocks before the final layer.
        dropout: Dropout probability used after convolutions and between
            readout layers.
    """

    def __init__(self,
                 hidden_dim: int,
                 n_message_passings: int,
                 output_dim: int,
                 final_activation,
                 n_update_layer: int = 1,
                 pooling: str = None,
                 n_readout: int = 1,
                 dropout: float = 0.0):
        super().__init__()

        widths = [hidden_dim] * n_message_passings
        last_conv = n_message_passings - 1

        convs = []
        for idx, width in enumerate(widths):
            mlp = []
            for _ in range(n_update_layer):
                mlp += [nn.LazyLinear(width), nn.ReLU()]
            # No normalisation after the last convolution. Placing the norm at
            # the end of the update network equals normalising the conv output.
            if idx != last_conv:
                mlp.append(nn.BatchNorm1d(width))
            convs.append(GINConv(nn.Sequential(*mlp)))
        self.conv_layers = nn.ModuleList(convs)

        readout_width = widths[-1]
        head = [
            nn.Sequential(nn.LazyLinear(readout_width), nn.ReLU())
            for _ in range(n_readout)
        ]
        head.append(nn.Linear(readout_width, output_dim))
        self.readout = nn.ModuleList(head)

        self.dropout_rate = dropout
        self.pooling = POOLING_MAPPING[pooling]
        self.activation = ACTIVATION_MAPPING[final_activation]

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Run message passing, optional pooling and the readout head.

        Args:
            x: Node features.
            edge_index: Graph connectivity handed to every ``GINConv``.
            edge_attr: Accepted but not used by this model.
            batch: Forwarded to the pooling callable when pooling is enabled.

        Returns:
            In training mode the activated readout. In evaluation mode a tuple
            ``(activated readout, node representations after the last
            message-passing layer)``.
        """
        drop_p = self.dropout_rate
        is_training = self.training

        for conv in self.conv_layers:
            x = F.dropout(conv(x, edge_index), p=drop_p, training=is_training)

        # Keep the node-level representations; returned in evaluation mode.
        node_repr = x

        if self.pooling is not None:
            x = F.relu(self.pooling(x=x, batch=batch))

        # Dropout sits between readout layers, never in front of the first.
        for idx, layer in enumerate(self.readout):
            if idx:
                x = F.dropout(x, p=drop_p, training=is_training)
            x = layer(x)

        x = self.activation(x, dim=1)

        return x if is_training else (x, node_repr)

class GraphSAGE(nn.Module):
    """Stack of ``SAGEConv`` layers with optional pooling and a linear readout.

    All message-passing layers share the same output width ``hidden_dim``;
    their input width is passed as ``-1`` so PyG infers it lazily.

    Args:
        hidden_dim: Output width shared by every message-passing layer.
        n_message_passings: Number of ``SAGEConv`` layers; must be >= 1.
        output_dim: Output width of the linear readout.
        final_activation: Key looked up in ``ACTIVATION_MAPPING``.
        pooling: Key looked up in ``POOLING_MAPPING``. ``forward`` skips
            pooling when the mapped value is ``None``.
        dropout: Dropout probability applied after each message-passing layer.

    Raises:
        ValueError: If ``n_message_passings`` is smaller than 1.
    """

    def __init__(self,
                 hidden_dim: int,
                 n_message_passings: int,
                 output_dim: int,
                 final_activation,
                 pooling: str = None,
                 dropout: float = 0.0):
        # Fail early with a clear message; without a layer the readout width
        # would be undefined.
        if n_message_passings < 1:
            raise ValueError(
                f"n_message_passings must be >= 1, got {n_message_passings}")

        super().__init__()
        self.conv_layers = nn.ModuleList(
            [SAGEConv(-1, hidden_dim) for _ in range(n_message_passings)])
        self.readout = nn.Linear(hidden_dim, output_dim)
        self.dropout_rate = dropout
        self.pooling = POOLING_MAPPING[pooling]
        self.activation = ACTIVATION_MAPPING[final_activation]

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Run message passing, optional pooling and the readout.

        Args:
            x: Node features.
            edge_index: Graph connectivity handed to every ``SAGEConv``.
            edge_attr: Accepted but not used by this model.
            batch: Forwarded to the pooling callable when pooling is enabled.

        Returns:
            In training mode the activated readout. In evaluation mode a tuple
            ``(activated readout, node representations after the last
            message-passing layer)``.
        """
        for conv in self.conv_layers:
            x = torch.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Save the final node representations before pooling collapses them;
        # they are only returned in evaluation mode.
        node_repr = x

        if self.pooling is not None:
            x = self.pooling(x=x, batch=batch)

        x = self.activation(self.readout(x), dim=1)

        if self.training:
            return x
        return x, node_repr


class GATv2(nn.Module):
    """Stack of ``GATv2Conv`` layers with optional pooling and a linear readout.

    All message-passing layers share the same output width ``hidden_dim``;
    their input width is passed as ``-1`` so PyG infers it lazily.

    Args:
        hidden_dim: Output width shared by every message-passing layer.
        n_message_passings: Number of ``GATv2Conv`` layers; must be >= 1.
        output_dim: Output width of the linear readout.
        final_activation: Key looked up in ``ACTIVATION_MAPPING``.
        pooling: Key looked up in ``POOLING_MAPPING``. ``forward`` skips
            pooling when the mapped value is ``None``.
        dropout: Dropout probability applied after each message-passing layer.

    Raises:
        ValueError: If ``n_message_passings`` is smaller than 1.
    """

    def __init__(self,
                 hidden_dim: int,
                 n_message_passings: int,
                 output_dim: int,
                 final_activation,
                 pooling: str = None,
                 dropout: float = 0.0):
        # Fail early with a clear message; without a layer the readout width
        # would be undefined.
        if n_message_passings < 1:
            raise ValueError(
                f"n_message_passings must be >= 1, got {n_message_passings}")

        super().__init__()
        self.conv_layers = nn.ModuleList(
            [GATv2Conv(-1, hidden_dim) for _ in range(n_message_passings)])
        self.readout = nn.Linear(hidden_dim, output_dim)
        self.dropout_rate = dropout
        self.pooling = POOLING_MAPPING[pooling]
        self.activation = ACTIVATION_MAPPING[final_activation]

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Run message passing, optional pooling and the readout.

        Args:
            x: Node features.
            edge_index: Graph connectivity handed to every ``GATv2Conv``.
            edge_attr: Accepted but not used by this model.
            batch: Forwarded to the pooling callable when pooling is enabled.

        Returns:
            In training mode the activated readout. In evaluation mode a tuple
            ``(activated readout, node representations after the last
            message-passing layer)``.
        """
        for conv in self.conv_layers:
            x = torch.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Save the final node representations before pooling collapses them;
        # they are only returned in evaluation mode.
        node_repr = x

        if self.pooling is not None:
            x = self.pooling(x=x, batch=batch)

        x = self.activation(self.readout(x), dim=1)

        if self.training:
            return x
        return x, node_repr


class GCNPair(nn.Module):
    """Message-passing stack with ``PairNorm`` between layers.

    The convolution type is selected by ``msg_passing_method``. Every layer
    except the last is followed by ``PairNorm``; the last one is followed by
    ``nn.Identity``. ReLU and dropout are applied after the normalisation.

    Args:
        hidden_dim: Output width shared by every message-passing layer.
        n_message_passings: Number of convolution layers; must be >= 1.
        output_dim: Output width of the linear readout.
        final_activation: Key looked up in ``ACTIVATION_MAPPING``.
        msg_passing_method: ``"gcn"``, ``"gat_v2"``/``"gatv2"`` or ``"sage"``.
        pooling: Key looked up in ``POOLING_MAPPING``. ``forward`` skips
            pooling when the mapped value is ``None``.
        dropout: Dropout probability applied after each message-passing layer.

    Raises:
        ValueError: If ``n_message_passings`` is smaller than 1 or
            ``msg_passing_method`` is not one of the supported names.
    """

    def __init__(self,
                 hidden_dim: int,
                 n_message_passings: int,
                 output_dim: int,
                 final_activation,
                 msg_passing_method: str = "gcn",
                 pooling: str = None,
                 dropout: float = 0.0):
        # Fail early with a clear message; without a layer the readout width
        # would be undefined.
        if n_message_passings < 1:
            raise ValueError(
                f"n_message_passings must be >= 1, got {n_message_passings}")

        # Resolve the convolution class once. An unknown name used to be
        # ignored silently, producing a model without any convolution.
        if msg_passing_method == "gcn":
            conv_cls = GCNConv
        elif msg_passing_method in ("gat_v2", "gatv2"):
            conv_cls = GATv2Conv
        elif msg_passing_method == "sage":
            conv_cls = SAGEConv
        else:
            raise ValueError(
                f"Unknown msg_passing_method {msg_passing_method!r}; "
                "expected 'gcn', 'gat_v2', 'gatv2' or 'sage'")

        super().__init__()
        self.conv_layers = nn.ModuleList(
            [conv_cls(-1, hidden_dim) for _ in range(n_message_passings)])
        # One normaliser per convolution; the last convolution is left
        # un-normalised via an Identity placeholder.
        self.norms = nn.ModuleList(
            [PairNorm() for _ in range(n_message_passings - 1)])
        self.norms.append(nn.Identity())
        self.readout = nn.Linear(hidden_dim, output_dim)
        self.dropout_rate = dropout
        self.pooling = POOLING_MAPPING[pooling]
        self.activation = ACTIVATION_MAPPING[final_activation]

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Run normalised message passing, optional pooling and the readout.

        Args:
            x: Node features.
            edge_index: Graph connectivity handed to every convolution.
            edge_attr: Accepted but not used by this model.
            batch: Forwarded to the pooling callable when pooling is enabled.

        Returns:
            In training mode the activated readout. In evaluation mode a tuple
            ``(activated readout, node representations after the last
            message-passing layer)``.
        """
        for conv, norm in zip(self.conv_layers, self.norms):
            x = conv(x, edge_index)
            # NOTE(review): the normaliser is called without `batch`, so for
            # mini-batched graphs PairNorm sees all nodes as one graph.
            # Confirm that this is intended.
            x = norm(x)
            x = torch.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Save the final node representations before pooling collapses them;
        # they are only returned in evaluation mode.
        node_repr = x

        if self.pooling is not None:
            x = self.pooling(x=x, batch=batch)

        x = self.activation(self.readout(x), dim=1)

        if self.training:
            return x
        return x, node_repr
