"""
Model factory utilities.

This module creates the model, optimizer, and loss function based on CLI args.
"""

import torch

from .mgn import MeshGraphNet
from .mise_gnn import MiSeGNN


def create_model_optim_loss(args):
    """Build the network, its Adam optimizer, and an MSE criterion.

    The architecture is selected by ``args.mode``:

    * ``"mgn"`` or ``"mgn_tree"`` -> ``MeshGraphNet`` with one node decoder.
    * ``"mise_gnn"`` -> ``MiSeGNN``; its field decoder and its error decoder
      both take their width/depth from ``args.hidden_dim_node_decoder`` and
      ``args.num_layers_node_decoder``.

    Args:
        args: Parsed CLI arguments (``argparse.Namespace`` or any object with
            the same attributes) carrying the encoder/processor/decoder
            hyperparameters, ``mode``, ``activation``, and the learning
            rate ``lr``.

    Returns:
        tuple: ``(model, optimizer, loss_fn)`` where ``model`` is the
        constructed ``torch.nn.Module``, ``optimizer`` is
        ``torch.optim.Adam`` over ``model.parameters()`` with
        ``lr=args.lr``, and ``loss_fn`` is ``torch.nn.MSELoss()``.

    Raises:
        ValueError: If ``args.mode`` is not one of the supported modes.
    """
    # Hyperparameters forwarded verbatim from ``args`` to every architecture.
    forwarded = (
        "input_dim_nodes",
        "input_dim_edges",
        "output_dim",
        "processor_size",
        "num_layers_node_processor",
        "num_layers_edge_processor",
        "hidden_dim_node_encoder",
        "num_layers_node_encoder",
        "hidden_dim_edge_encoder",
        "num_layers_edge_encoder",
        "aggregation",
    )
    shared = {name: getattr(args, name) for name in forwarded}
    # Fixed settings: the concat trick and processor checkpointing are off.
    shared["do_concat_trick"] = False
    shared["num_processor_checkpoint_segments"] = 0
    shared["activation"] = args.activation

    mode = args.mode
    if mode not in ("mgn", "mgn_tree", "mise_gnn"):
        raise ValueError(f"Unsupported mode: {mode}")

    # Decoder attributes are read only after the mode is known to be valid.
    if mode == "mise_gnn":
        width = args.hidden_dim_node_decoder
        depth = args.num_layers_node_decoder
        model = MiSeGNN(
            **shared,
            hidden_dim_field_decoder=width,
            num_layers_field_decoder=depth,
            hidden_dim_error_decoder=width,
            num_layers_error_decoder=depth,
        )
    else:
        model = MeshGraphNet(
            **shared,
            hidden_dim_node_decoder=args.hidden_dim_node_decoder,
            num_layers_node_decoder=args.num_layers_node_decoder,
        )

    criterion = torch.nn.MSELoss()
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)
    return model, optim, criterion
