"""
MiSeGNN model definition.
"""

from typing import List, Tuple, Union

import torch
from torch import Tensor
from dgl import DGLGraph

from physicsnemo.models.module import Module
from physicsnemo.models.gnn_layers.mesh_graph_mlp import MeshGraphMLP
from physicsnemo.models.meshgraphnet.meshgraphnet import MeshGraphNetProcessor


class MiSeGNN(Module):
    """Encode-process-decode graph network with two node-level output heads.

    A shared edge encoder, node encoder and ``MeshGraphNetProcessor`` produce a
    latent node embedding, which is decoded by two independent MLP heads:

    * ``decoder_field`` -- predicts ``output_dim`` values per node.
    * ``decoder_error`` -- predicts a single value per node (the related
      error of the field prediction).

    Parameters
    ----------
    input_dim_nodes : int
        Number of input features per node.
    input_dim_edges : int
        Number of input features per edge.
    output_dim : int
        Number of values predicted per node by the field head.
    processor_size : int, optional
        Forwarded to ``MeshGraphNetProcessor`` as ``processor_size``,
        by default 15.
    num_layers_node_processor : int, optional
        Forwarded to the processor as ``num_layers_node``, by default 2.
    num_layers_edge_processor : int, optional
        Forwarded to the processor as ``num_layers_edge``, by default 2.
    hidden_dim_node_encoder : int, optional
        Output/hidden width of the node encoder; also the processor's node
        width and the input width of both decoder heads, by default 128.
    num_layers_node_encoder : int, optional
        Hidden layers of the node encoder, by default 2.
    hidden_dim_edge_encoder : int, optional
        Output/hidden width of the edge encoder; also the processor's edge
        width, by default 128.
    num_layers_edge_encoder : int, optional
        Hidden layers of the edge encoder, by default 2.
    hidden_dim_field_decoder : int, optional
        Hidden width of the field head, by default 128.
    num_layers_field_decoder : int, optional
        Hidden layers of the field head, by default 2.
    hidden_dim_error_decoder : int, optional
        Hidden width of the error head, by default 128.
    num_layers_error_decoder : int, optional
        Hidden layers of the error head, by default 2.
    aggregation : str, optional
        Forwarded unchanged to the processor, by default "sum".
    do_concat_trick : bool, optional
        Forwarded unchanged to the processor, by default False.
    num_processor_checkpoint_segments : int, optional
        Forwarded unchanged to the processor, by default 0.
    activation : str, optional
        One of ``"relu"``, ``"elu"``, ``"leaky"`` (LeakyReLU with negative
        slope 0.05) or ``"silu"``, by default ``"relu"``.

    Raises
    ------
    ValueError
        If ``activation`` is not one of the supported names.
    """

    def __init__(
        self,
        input_dim_nodes: int,
        input_dim_edges: int,
        output_dim: int,
        processor_size: int = 15,
        num_layers_node_processor: int = 2,
        num_layers_edge_processor: int = 2,
        hidden_dim_node_encoder: int = 128,
        num_layers_node_encoder: int = 2,
        hidden_dim_edge_encoder: int = 128,
        num_layers_edge_encoder: int = 2,
        hidden_dim_field_decoder: int = 128,
        num_layers_field_decoder: int = 2,
        hidden_dim_error_decoder: int = 128,
        num_layers_error_decoder: int = 2,
        aggregation: str = "sum",
        do_concat_trick: bool = False,
        num_processor_checkpoint_segments: int = 0,
        activation: str = "relu",
    ):
        """Build the shared encoder + processor and the two decoder heads."""
        super().__init__()

        # One activation instance is shared by every sub-network, as before;
        # the supported activations carry no parameters, so sharing is safe.
        activation_fn = self._build_activation(activation)

        self.edge_encoder = MeshGraphMLP(
            input_dim_edges,
            output_dim=hidden_dim_edge_encoder,
            hidden_dim=hidden_dim_edge_encoder,
            hidden_layers=num_layers_edge_encoder,
            activation_fn=activation_fn,
            norm_type="LayerNorm",
            recompute_activation=False,
        )

        self.node_encoder = MeshGraphMLP(
            input_dim_nodes,
            output_dim=hidden_dim_node_encoder,
            hidden_dim=hidden_dim_node_encoder,
            hidden_layers=num_layers_node_encoder,
            activation_fn=activation_fn,
            norm_type="LayerNorm",
            recompute_activation=False,
        )

        # The processor consumes the encoders' latent widths directly.
        self.processor = MeshGraphNetProcessor(
            processor_size=processor_size,
            input_dim_node=hidden_dim_node_encoder,
            input_dim_edge=hidden_dim_edge_encoder,
            num_layers_node=num_layers_node_processor,
            num_layers_edge=num_layers_edge_processor,
            aggregation=aggregation,
            norm_type="LayerNorm",
            activation_fn=activation_fn,
            do_concat_trick=do_concat_trick,
            num_processor_checkpoint_segments=num_processor_checkpoint_segments,
        )

        # === Dual-Head ===
        # Both heads read the same latent node embedding and use no output
        # normalization, so their outputs are unconstrained.
        self.decoder_field = MeshGraphMLP(
            hidden_dim_node_encoder,
            output_dim=output_dim,
            hidden_dim=hidden_dim_field_decoder,
            hidden_layers=num_layers_field_decoder,
            activation_fn=activation_fn,
            norm_type=None,
            recompute_activation=False,
        )

        self.decoder_error = MeshGraphMLP(
            hidden_dim_node_encoder,
            output_dim=1,
            hidden_dim=hidden_dim_error_decoder,
            hidden_layers=num_layers_error_decoder,
            activation_fn=activation_fn,
            norm_type=None,
            recompute_activation=False,
        )

    @staticmethod
    def _build_activation(name: str) -> torch.nn.Module:
        """Return a new activation module for the given ``name``.

        Raises
        ------
        ValueError
            If ``name`` is not a supported activation.
        """
        factories = {
            "relu": torch.nn.ReLU,
            "elu": torch.nn.ELU,
            "leaky": lambda: torch.nn.LeakyReLU(0.05),
            "silu": torch.nn.SiLU,
        }
        factory = factories.get(name)
        if factory is None:
            raise ValueError(f"Unknown activation: {name}")
        return factory()

    def forward(
        self,
        node_features: Tensor,
        edge_features: Tensor,
        graph: Union[DGLGraph, List[DGLGraph]],
    ) -> Tuple[Tensor, Tensor]:
        """Run the shared encoder/processor, then both decoder heads.

        Parameters
        ----------
        node_features : Tensor
            Node inputs, [num_nodes, input_dim_nodes].
        edge_features : Tensor
            Edge inputs, [num_edges, input_dim_edges].
        graph : DGLGraph or List[DGLGraph]
            Graph connectivity, forwarded unchanged to the processor.

        Returns
        -------
        Tuple[Tensor, Tensor]
            ``(y_pred, e_pred)``: the field prediction [num_nodes, output_dim]
            and the error prediction [num_nodes, 1].
        """
        edge_latent = self.edge_encoder(edge_features)
        node_latent = self.node_encoder(node_features)

        x_latent = self.processor(node_latent, edge_latent, graph)

        y_pred = self.decoder_field(x_latent)
        e_pred = self.decoder_error(x_latent)

        return y_pred, e_pred
