from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import PairNorm

from src.models.model_utils import MESSAGE_PASSING_MAPPING


def init_norm_layer(norm_type: Optional[str], input_dim: int) -> nn.Module:
    """
    Initialize the normalization layer based on the specified type.

    :param norm_type: The type of normalization layer to initialize. One of
        ``"batch"``, ``"layer"``, ``"pair"``, ``"none"`` or ``None``.
    :param input_dim: The feature dimension the normalization layer operates on.
        Only used by ``"batch"`` and ``"layer"``; ignored otherwise.
    :return: The normalization module. ``nn.Identity()`` when no normalization
        is requested (``"none"`` or ``None``).
    :raises ValueError: If ``norm_type`` is not one of the supported values.
    """
    if norm_type == "batch":
        return nn.BatchNorm1d(input_dim)
    if norm_type == "layer":
        return nn.LayerNorm(input_dim)
    if norm_type == "pair":
        # PairNorm's first positional argument is ``scale`` (default 1.0), not a
        # feature dimension. Passing input_dim here would scale every output by
        # input_dim, so construct it with its defaults.
        return PairNorm()
    if norm_type is None or norm_type == "none":
        return nn.Identity()
    raise ValueError(f"Unknown normalization type: {norm_type}")


class MessagePassingLayer(nn.Module):
    """
    A single graph convolution followed by normalization, ReLU and dropout.

    The convolution class is looked up in ``MESSAGE_PASSING_MAPPING`` by
    ``msg_passing_type``; the keyword arguments it is constructed with depend
    on that type (see ``__init__``).
    """

    def __init__(self, input_dim, output_dim, msg_passing_type, dropout=0.5, norm: str = None):
        """
        :param input_dim: Size of the incoming node features.
        :param output_dim: Size of the node features produced by the layer.
        :param msg_passing_type: One of ``"gin"``, ``"gine"``, ``"gcn"``,
            ``"gat_v2"`` or ``"sage"``.
        :param dropout: Dropout probability applied after the activation.
        :param norm: Normalization type forwarded to ``init_norm_layer``
            (built with ``output_dim`` features).
        :raises ValueError: If ``msg_passing_type`` is not supported.
        """
        super().__init__()

        if msg_passing_type in ("gine", "gin"):
            # GIN-style convolutions are parameterized by an inner MLP rather
            # than by in/out channel sizes; eps is made learnable.
            # NOTE(review): GINE-style convolutions usually expect edge features,
            # but forward() only passes (x, edge_index) -- confirm against
            # MESSAGE_PASSING_MAPPING["gine"].
            conv_kwargs = {
                "nn": nn.Sequential(nn.Linear(input_dim, output_dim), nn.ReLU()),
                "train_eps": True,
            }
        elif msg_passing_type in ("gcn", "gat_v2"):
            conv_kwargs = {
                "in_channels": input_dim,
                "out_channels": output_dim,
                "add_self_loops": True,
            }
        elif msg_passing_type == "sage":
            conv_kwargs = {"in_channels": input_dim, "out_channels": output_dim}
        else:
            raise ValueError(f"Message passing type {msg_passing_type} not supported.")

        conv_cls = MESSAGE_PASSING_MAPPING[msg_passing_type]
        self.message_passing_layer = conv_cls(**conv_kwargs)

        self.norm = init_norm_layer(norm, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """
        Run convolution -> norm -> ReLU -> dropout.

        :param x: Node feature tensor passed to the convolution.
        :param edge_index: Graph connectivity passed to the convolution.
        :return: The transformed node features.
        """
        hidden = self.norm(self.message_passing_layer(x, edge_index))
        return self.dropout(F.relu(hidden))
