import torch
from .datasets import MoleculesBatch

import torch.nn as nn
from torch_geometric.nn import GatedGraphConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GATConv
from torch_geometric.nn import GraphNorm


class MTransformer(nn.Module):
    """
    A transformer model for molecular regression tasks.

    Token ids are embedded, summed with a linear projection of the positional
    encodings, run through an ``nn.TransformerEncoder`` and mapped to one
    scalar per token. The per-molecule prediction is the average of those
    scalars over the valid (unmasked) positions only.
    """

    def __init__(
        self,
        num_classes,
        pe_dim,
        d_model,
        nhead,
        num_encoder_layers,
        dim_feedforward,
        dropout=0.1,
    ):
        """
        Initializes the MTransformer model.
        Args:
            num_classes (int): Number of classes for the embedding layer.
            pe_dim (int): Dimension of the positional encoding.
            d_model (int): Dimension of the model.
            nhead (int): Number of heads in the multiheadattention models.
            num_encoder_layers (int): Number of sub-encoder-layers in the encoder.
            dim_feedforward (int): Dimension of the feedforward network model.
            dropout (float, optional): Dropout value. Default is 0.1.
        """

        super().__init__()
        self.embedding = nn.Embedding(num_classes, d_model)
        self.pos_encoder = nn.Linear(pe_dim, d_model)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_encoder_layers
        )
        self.fc_out = nn.Linear(d_model, 1)

    def forward(self, batch: "MoleculesBatch") -> torch.Tensor:
        """
        Perform a forward pass through the model.
        Args:
            batch (MoleculesBatch): A batch of molecular data containing the following attributes:
                - x_padded (torch.Tensor): Padded token ids, with or without a
                  trailing singleton axis.
                - pe_padded (torch.Tensor): Padded positional encodings whose
                  last axis has size ``pe_dim``.
                - mask (torch.Tensor): Boolean mask, True at valid positions of
                  the padded sequences (it is inverted with ``~`` to build the
                  key padding mask).
        Returns:
            torch.Tensor: A flat tensor with one prediction per molecule.
        """

        tokens = batch.x_padded
        # Remove only a trailing singleton feature axis. A bare ``squeeze()``
        # would also drop a sequence axis of length 1, after which the sum with
        # the positional encodings broadcasts to a wrong shape.
        if tokens.dim() == batch.pe_padded.dim() and tokens.size(-1) == 1:
            tokens = tokens.squeeze(-1)

        valid = batch.mask
        x = self.embedding(tokens)
        x = x + self.pos_encoder(batch.pe_padded)
        # Call the module itself (not ``.forward``) so registered hooks run.
        x = self.transformer_encoder(x, src_key_padding_mask=~valid)
        x = self.fc_out(x)

        # Masked mean: zero out padded positions and divide by the number of
        # valid tokens, so the result does not depend on the padded length.
        weights = valid.unsqueeze(-1).to(x.dtype)
        totals = (x * weights).sum(dim=1)
        # Guard against a row without any valid token (avoids division by 0).
        counts = weights.sum(dim=1).clamp(min=1)
        return (totals / counts).flatten()
    


class MGatedGCN(nn.Module):
    """
    MGatedGCN is a Gated Graph Convolutional Network for molecular regression tasks.

    Node ids are embedded and summed with a linear projection of the positional
    encodings. A stack of residual ``GatedGraphConv`` layers follows, with a
    ``GraphNorm`` after every third layer and dropout after every layer. The
    per-node scalar outputs are mean-pooled per graph.
    """

    def __init__(self, num_classes, pe_dim, d_model, num_layers, dropout=0.1):
        """
        Initializes the MGatedGCN model.
        Args:
            num_classes (int): Number of classes for the embedding layer.
            pe_dim (int): Dimension of the positional encoding.
            d_model (int): Dimension of the model.
            num_layers (int): Number of GatedGraphConv layers.
            dropout (float, optional): Dropout rate. Default is 0.1.
        """

        super().__init__()
        self.embedding = nn.Embedding(num_classes, d_model)
        self.pos_encoder = nn.Linear(pe_dim, d_model)
        # Second argument of GatedGraphConv is its internal number of
        # propagation steps (fixed to 2 here).
        self.convs = nn.ModuleList(
            [GatedGraphConv(d_model, 2) for _ in range(num_layers)]
        )
        # One normalisation layer per complete group of three conv layers.
        self.norms = nn.ModuleList(
            [GraphNorm(d_model) for _ in range(num_layers // 3)]
        )
        self.fc_out = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, batch: MoleculesBatch) -> torch.Tensor:
        """
        Perform a forward pass through the model.
        Args:
            batch (MoleculesBatch): A batch of molecular graphs containing node features,
                                    positional encodings, and edge indices.
        Returns:
            torch.Tensor: A flat tensor with one prediction per graph.
        """

        # Node-to-graph assignment vector; the same vector drives both the
        # per-graph normalisation and the final pooling.
        node_to_graph = batch.x._batch

        x = self.embedding(batch.x.x().squeeze())
        x = x + self.pos_encoder(batch.pe.x())
        edge_index = batch.graphs.edge_index()
        for i, conv in enumerate(self.convs):
            x = x + torch.relu(conv(x, edge_index))
            if (i + 1) % 3 == 0:
                # GraphNorm needs the batch vector; without it all nodes of the
                # mini-batch are normalised together as if they were one graph,
                # making predictions depend on the batch composition.
                x = self.norms[i // 3](x, node_to_graph)
            x = self.dropout(x)
        x = self.fc_out(x)
        return global_mean_pool(x, node_to_graph).flatten()


class MGCN(nn.Module):
    """
    MGCN (Molecular Graph Convolutional Network) model for molecular regression tasks.

    Node ids are embedded and summed with a linear projection of the positional
    encodings. A stack of residual ``GCNConv`` layers follows, with a
    ``GraphNorm`` after every third layer and dropout after every layer. The
    per-node scalar outputs are mean-pooled per graph.
    """

    def __init__(self, num_classes, pe_dim, d_model, num_layers, dropout=0.1):
        """
        Initializes the MGCN (Molecular Graph Convolutional Network) model.
        Args:
            num_classes (int): Number of classes for the embedding layer.
            pe_dim (int): Dimension of the positional encoding.
            d_model (int): Dimension of the model.
            num_layers (int): Number of GCN layers.
            dropout (float, optional): Dropout rate. Default is 0.1.
        """

        super().__init__()
        self.embedding = nn.Embedding(num_classes, d_model)
        self.pos_encoder = nn.Linear(pe_dim, d_model)
        self.convs = nn.ModuleList(
            [GCNConv(d_model, d_model) for _ in range(num_layers)]
        )
        # One normalisation layer per complete group of three conv layers.
        self.norms = nn.ModuleList(
            [GraphNorm(d_model) for _ in range(num_layers // 3)]
        )
        self.fc_out = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, batch: MoleculesBatch) -> torch.Tensor:
        """
        Perform a forward pass through the model.
        Args:
            batch (MoleculesBatch): A batch of molecular graphs.
        Returns:
            torch.Tensor: A flat tensor with one prediction per graph.
        """

        # Node-to-graph assignment vector; the same vector drives both the
        # per-graph normalisation and the final pooling.
        node_to_graph = batch.x._batch

        x = self.embedding(batch.x.x().squeeze())
        x = x + self.pos_encoder(batch.pe.x())
        edge_index = batch.graphs.edge_index()
        for i, conv in enumerate(self.convs):
            x = x + torch.relu(conv(x, edge_index))
            if (i + 1) % 3 == 0:
                # GraphNorm needs the batch vector; without it all nodes of the
                # mini-batch are normalised together as if they were one graph,
                # making predictions depend on the batch composition.
                x = self.norms[i // 3](x, node_to_graph)
            x = self.dropout(x)
        x = self.fc_out(x)
        return global_mean_pool(x, node_to_graph).flatten()


class MGAT(nn.Module):
    """
    Multi-Head Graph Attention Network (MGAT) for molecular regression tasks.

    Node ids are embedded and summed with a linear projection of the positional
    encodings. A stack of residual ``GATConv`` layers follows, with a
    ``GraphNorm`` after every third layer and dropout after every layer. The
    per-node scalar outputs are mean-pooled per graph.
    """

    def __init__(self, num_classes, pe_dim, d_model, num_layers, heads=1, dropout=0.1):
        """
        Initializes the MGAT model.
        Args:
            num_classes (int): Number of classes for the embedding layer.
            pe_dim (int): Dimension of the positional encoding.
            d_model (int): Dimension of the model.
            num_layers (int): Number of GAT layers.
            heads (int, optional): Number of attention heads. Must be a
                positive divisor of ``d_model``. Default is 1.
            dropout (float, optional): Dropout rate. Default is 0.1.
        Raises:
            ValueError: If ``heads`` is not positive or does not divide ``d_model``.
        """

        super().__init__()
        # Each GATConv outputs heads * (d_model // heads) features (heads are
        # concatenated). The residual connection in ``forward`` requires that
        # to equal d_model, so reject configurations where the floor division
        # would silently truncate and only fail later with a shape error.
        if heads < 1 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive number "
                f"of heads (got heads={heads})"
            )
        self.embedding = nn.Embedding(num_classes, d_model)
        self.pos_encoder = nn.Linear(pe_dim, d_model)
        self.convs = nn.ModuleList(
            [GATConv(d_model, d_model // heads, heads=heads) for _ in range(num_layers)]
        )
        # One normalisation layer per complete group of three conv layers.
        self.norms = nn.ModuleList(
            [GraphNorm(d_model) for _ in range(num_layers // 3)]
        )
        self.fc_out = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, batch: MoleculesBatch) -> torch.Tensor:
        """
        Perform a forward pass through the model.
        Args:
            batch (MoleculesBatch): A batch of molecular graphs and associated data.
        Returns:
            torch.Tensor: A flat tensor with one prediction per graph.
        """

        # Node-to-graph assignment vector; the same vector drives both the
        # per-graph normalisation and the final pooling.
        node_to_graph = batch.x._batch

        x = self.embedding(batch.x.x().squeeze())
        x = x + self.pos_encoder(batch.pe.x())
        edge_index = batch.graphs.edge_index()
        for i, conv in enumerate(self.convs):
            x = x + torch.relu(conv(x, edge_index))
            if (i + 1) % 3 == 0:
                # GraphNorm needs the batch vector; without it all nodes of the
                # mini-batch are normalised together as if they were one graph,
                # making predictions depend on the batch composition.
                x = self.norms[i // 3](x, node_to_graph)
            x = self.dropout(x)
        x = self.fc_out(x)
        return global_mean_pool(x, node_to_graph).flatten()
