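"""Model factory for creating neural operator models.

This module provides ``build_model``, which instantiates a neural operator model
(``PDE1DModel``) from a configuration object and optionally loads pre-trained weights.
Variant-specific settings (e.g. for FNO1d, FNO2d, or custom variants) are presumably
passed through ``config.model.model_params``.
"""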
from typing import Optional

import torch.nn as nn

from src.config import Config
from src.models.pde_model import PDE1DModel


def build_model(config: Config, state_dict: Optional[dict] = None) -> nn.Module:
    """Build a ``PDE1DModel`` from the configuration and optionally load weights.

    Parameters
    ----------
    config : Config
        Configuration object. ``config.debug``, ``config.model.model_params`` and
        ``config.training.device`` are read.

    state_dict : Optional[dict], optional
        Pre-trained weights to load into the model (default is None, which means no weights are
        loaded). A leading ``'module.'`` prefix on each key (typically added when the model was
        wrapped in ``nn.DataParallel`` / ``DistributedDataParallel`` before saving) is stripped,
        and a ``'_metadata'`` key, if present, is dropped. The caller's dict is not modified.

    Returns
    -------
    nn.Module
        Instantiated model, moved to ``config.training.device``.

    Raises
    ------
    RuntimeError
        If ``state_dict`` has missing or unexpected keys or mismatched shapes
        (raised by ``load_state_dict`` with ``strict=True``).
    """
    model = PDE1DModel(
        debug=config.debug,
        model_params=config.model.model_params,
    )

    if state_dict is not None:
        prefix = 'module.'
        # Strip only a *leading* wrapper prefix. A plain str.replace would also
        # rewrite keys that merely contain 'module.' further in (e.g.
        # 'encoder.module.weight' -> 'encoder.weight'), corrupting the mapping.
        # Building a new dict also leaves the caller's state_dict untouched.
        cleaned = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        # '_metadata' is not a parameter; if it was stored as a regular key it
        # would be reported as unexpected under strict loading.
        cleaned.pop('_metadata', None)
        model.load_state_dict(cleaned, strict=True)

    return model.to(config.training.device)
