"""
MeshGraphNet model definition.
"""

from typing import List, Union

import torch
from torch import Tensor
from dgl import DGLGraph

from physicsnemo.models.module import Module
from physicsnemo.models.gnn_layers.mesh_graph_mlp import MeshGraphMLP
from physicsnemo.models.meshgraphnet.meshgraphnet import MeshGraphNetProcessor, MetaData


class MeshGraphNet(Module):
    """MeshGraphNet-style GNN (node/edge encoders + processor + node decoder).

    The model is assembled from four sub-modules, applied in this order by
    :meth:`forward`: ``edge_encoder`` and ``node_encoder`` (``MeshGraphMLP``
    with LayerNorm), ``processor`` (``MeshGraphNetProcessor``), and
    ``node_decoder`` (``MeshGraphMLP`` without normalization).

    Args:
        input_dim_nodes: Input feature dimension for nodes.
        input_dim_edges: Input feature dimension for edges.
        output_dim: Output feature dimension per node.
        processor_size: Number of message-passing steps (processor blocks).
        num_layers_node_processor: # layers in node update MLP inside the processor.
        num_layers_edge_processor: # layers in edge update MLP inside the processor.
        hidden_dim_node_encoder: Hidden/output dim of node encoder MLP. Also
            used as the node feature dim of the processor and as the input
            dim of the node decoder.
        num_layers_node_encoder: # hidden layers of node encoder MLP.
        hidden_dim_edge_encoder: Hidden/output dim of edge encoder MLP. Also
            used as the edge feature dim of the processor.
        num_layers_edge_encoder: # hidden layers of edge encoder MLP.
        hidden_dim_node_decoder: Hidden dim of node decoder MLP.
        num_layers_node_decoder: # hidden layers of node decoder MLP.
        aggregation: Aggregation type inside message passing (e.g., "sum").
        do_concat_trick: Whether to use concat trick in the processor (library-specific).
        num_processor_checkpoint_segments: Gradient checkpointing segments (optional).
        activation: Activation name: {"relu", "elu", "leaky"}. "leaky" maps to
            ``LeakyReLU`` with a negative slope of 0.05.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.

    References:
        Pfaff et al., 2021 (ICML): "Learning Mesh-Based Simulation with Graph Networks."
    """

    def __init__(
        self,
        input_dim_nodes: int,
        input_dim_edges: int,
        output_dim: int,
        processor_size: int = 15,
        num_layers_node_processor: int = 2,
        num_layers_edge_processor: int = 2,
        hidden_dim_node_encoder: int = 128,
        num_layers_node_encoder: int = 2,
        hidden_dim_edge_encoder: int = 128,
        num_layers_edge_encoder: int = 2,
        hidden_dim_node_decoder: int = 128,
        num_layers_node_decoder: int = 2,
        aggregation: str = "sum",
        do_concat_trick: bool = False,
        num_processor_checkpoint_segments: int = 0,
        activation: str = "relu",
    ):
        super().__init__(meta=MetaData())

        # Map each supported activation name to a factory for its module.
        activation_factories = {
            "relu": torch.nn.ReLU,
            "elu": torch.nn.ELU,
            "leaky": lambda: torch.nn.LeakyReLU(negative_slope=0.05),
        }
        # The isinstance guard keeps the failure a ValueError even for
        # unhashable inputs, which would otherwise raise TypeError on lookup.
        if not isinstance(activation, str) or activation not in activation_factories:
            supported = ", ".join(repr(name) for name in sorted(activation_factories))
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of: {supported}."
            )
        # A single activation instance is shared by all sub-modules below.
        activation_fn = activation_factories[activation]()

        self.edge_encoder = MeshGraphMLP(
            input_dim_edges,
            output_dim=hidden_dim_edge_encoder,
            hidden_dim=hidden_dim_edge_encoder,
            hidden_layers=num_layers_edge_encoder,
            activation_fn=activation_fn,
            norm_type="LayerNorm",
            recompute_activation=False,
        )
        self.node_encoder = MeshGraphMLP(
            input_dim_nodes,
            output_dim=hidden_dim_node_encoder,
            hidden_dim=hidden_dim_node_encoder,
            hidden_layers=num_layers_node_encoder,
            activation_fn=activation_fn,
            norm_type="LayerNorm",
            recompute_activation=False,
        )
        # The decoder consumes processor output, hence its input width is the
        # node encoder's hidden dim; no normalization on the final output.
        self.node_decoder = MeshGraphMLP(
            hidden_dim_node_encoder,
            output_dim=output_dim,
            hidden_dim=hidden_dim_node_decoder,
            hidden_layers=num_layers_node_decoder,
            activation_fn=activation_fn,
            norm_type=None,
            recompute_activation=False,
        )
        self.processor = MeshGraphNetProcessor(
            processor_size=processor_size,
            input_dim_node=hidden_dim_node_encoder,
            input_dim_edge=hidden_dim_edge_encoder,
            num_layers_node=num_layers_node_processor,
            num_layers_edge=num_layers_edge_processor,
            aggregation=aggregation,
            norm_type="LayerNorm",
            activation_fn=activation_fn,
            do_concat_trick=do_concat_trick,
            num_processor_checkpoint_segments=num_processor_checkpoint_segments,
        )

    def forward(
        self,
        node_features: Tensor,
        edge_features: Tensor,
        graph: Union[DGLGraph, List[DGLGraph]],
    ) -> Tensor:
        """Encode node and edge features, run the processor, and decode nodes.

        Args:
            node_features: Raw node features fed to the node encoder.
                NOTE(review): presumably shaped (num_nodes, input_dim_nodes)
                -- confirm against callers.
            edge_features: Raw edge features fed to the edge encoder.
                NOTE(review): presumably shaped (num_edges, input_dim_edges)
                -- confirm against callers.
            graph: Graph (or list of graphs) passed through to the processor.

        Returns:
            The output of the node decoder applied to the processor result.
        """
        edge_features = self.edge_encoder(edge_features)
        node_features = self.node_encoder(node_features)
        x = self.processor(node_features, edge_features, graph)
        x = self.node_decoder(x)
        return x
