from typing import Dict

import torch
from neuralop.models import FNO1d, GINO


class PDE1DModel(torch.nn.Module):
    """
    Time-independent neural operator for 1D PDEs, backed by either an FNO or a GINO.

    The concrete architecture is selected by ``model_params['model_type']``:
    ``'FNO'`` builds a ``neuralop`` ``FNO1d`` from ``model_params['fno_params']``;
    ``'GINO'`` builds a ``neuralop`` ``GINO`` from ``model_params['gino_params']``
    with ``gno_coord_dim=1`` and ``gno_use_torch_scatter=False``.

    Parameters
    ----------
    model_params : dict
        Model configuration. Must contain ``'model_type'`` (``'FNO'`` or ``'GINO'``)
        and the matching ``'fno_params'`` / ``'gino_params'`` keyword-argument dict.
        The dictionary passed in (and its nested parameter dicts) is not modified.
    debug : bool, optional
        If True, the number of modes, hidden channels and layers are overwritten
        with small debug values for testing purposes. Default is False.

    Attributes
    ----------
    model_params : dict
        Copy of the configuration actually used, including any debug overrides.
    model_type : str
        Either ``'FNO'`` or ``'GINO'``.
    model : torch.nn.Module
        The wrapped ``FNO1d`` or ``GINO`` instance.

    Raises
    ------
    ValueError
        If ``model_params['model_type']`` is neither ``'FNO'`` nor ``'GINO'``.
    """

    _SUPPORTED_MODEL_TYPES = ("FNO", "GINO")

    def __init__(self, model_params: dict, debug: bool = False) -> None:
        super().__init__()
        # Shallow-copy the top level so debug overrides never leak back into the
        # caller's configuration dict.
        self.model_params = dict(model_params)
        self.model_type = self.model_params['model_type']

        # Validate before applying debug overrides: otherwise an unknown type would
        # surface as a KeyError on 'gino_params' instead of this explicit error.
        if self.model_type not in self._SUPPORTED_MODEL_TYPES:
            raise ValueError(f"Model {self.model_type} not recognized. Choose either 'FNO' or 'GINO'.")

        if debug:
            self._overwrite_with_debug_parameters()

        # Build from self.model_params (the copy that carries any debug overrides).
        if self.model_type == "FNO":
            self.model = FNO1d(**self.model_params['fno_params'])
        else:
            self.model = GINO(**self.model_params['gino_params'], gno_coord_dim=1, gno_use_torch_scatter=False)

    def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Run the wrapped operator on a batch.

        Parameters
        ----------
        data : Dict[str, torch.Tensor]
            Input dictionary. The input function values are always read from ``'x'``.
            For ``'GINO'`` it must also contain ``'input_grid'`` and ``'latent_grid'``,
            and may contain ``'output_grid'``. If ``'output_grid'`` is absent, the
            input grid is used as the output query locations.

        Returns
        -------
        torch.Tensor
            Output of the wrapped FNO1d / GINO model.
        """
        if self.model_type == "FNO":
            output = self.model(data['x'])
        else:
            # Only the first entry along the leading (presumably batch) dimension of
            # each grid is used, i.e. the geometry is assumed shared across the batch.
            # `[:1]` keeps that leading dimension; `[0]` drops it for the output
            # queries. TODO confirm against the GINO version in use.
            output_queries = data.get('output_grid', data['input_grid'])[0]
            output = self.model(input_geom=data['input_grid'][:1], latent_queries=data['latent_grid'][:1],
                                output_queries=output_queries, x=data['x'])
        return output

    def _overwrite_with_debug_parameters(self) -> None:
        """
        Shrink the selected architecture's size parameters for fast debugging.

        For ``'FNO'`` it sets ``n_modes_height=4``, ``hidden_channels=6`` and
        ``n_layers=2`` in ``'fno_params'``. For ``'GINO'`` it sets
        ``fno_n_modes=[4]``, ``fno_hidden_channels=6`` and ``fno_n_layers=2`` in
        ``'gino_params'``. The nested dict is replaced with an updated copy rather
        than mutated, so a dict shared with the caller is left untouched.
        """
        if self.model_type == "FNO":
            key = 'fno_params'
            overrides = {'n_modes_height': 4, 'hidden_channels': 6, 'n_layers': 2}
        else:
            key = 'gino_params'
            overrides = {'fno_n_modes': [4], 'fno_hidden_channels': 6, 'fno_n_layers': 2}
        self.model_params[key] = {**self.model_params[key], **overrides}
