from dataclasses import dataclass, field, asdict
from typing import Literal, Tuple, Union, Dict, Any, Optional
import torch
from config import DatasetCfg, get_dataset_cfg

# ---------------------------
# Common fields across models
# ---------------------------
@dataclass
class CommonConfig:
    """Base dataclass with the fields every model config shares.

    Subclasses append their own (defaulted) fields after these four, so
    ``in_channels``, ``out_channels`` and ``device`` remain the leading,
    required arguments of every concrete config.
    """

    # Number of input channels/features.
    in_channels: int
    # Number of output channels/features.
    out_channels: int
    # Device to use; annotated as torch.device but not validated here.
    device: torch.device
    # Random seed for reproducibility.
    seed: int = 0


# ---------------------------
# Model-specific configurations
# ---------------------------
@dataclass
class FNOConfig(CommonConfig):
    """Hyper-parameters for an FNO (Fourier Neural Operator) model.

    Extends CommonConfig; every field declared here has a default.
    """

    # Tag that make_config() dispatches on; keep it equal to "FNO".
    model_name: Literal["FNO"] = "FNO"
    # Backend selector: "native" or "neuralop".
    backend: Literal["native", "neuralop"] = "native"
    # Hidden feature width.
    hidden_channels: int = 64
    # Fourier modes per spatial dim, e.g. 2D -> (m1, m2); 3D -> (m1, m2, m3).
    n_modes: Tuple[int, ...] = (12, 12)
    # Number of FNO blocks/layers.
    n_layers: int = 4
    # Spatial resolution per dim; the product equals the number of grid points.
    shapelist: Tuple[int, ...] = (64, 64)

    # Input/embedding switches; their semantics live in the model code.
    time_input: bool = False
    pos_emb: bool = True
    # Number of reference points for the unified positional embedding.
    ref: int = 8

    # Activation function name.
    activation: str = "gelu"
    

@dataclass
class U_NetConfig(CommonConfig):
    """Hyper-parameters for a U-Net model.

    Extends CommonConfig; every field declared here has a default.
    """

    # Tag that make_config() dispatches on; keep it equal to "U_Net".
    model_name: Literal["U_Net"] = "U_Net"
    # Hidden feature width.
    hidden_channels: int = 64
    # A single int here (FNOConfig uses a per-dimension tuple); how the
    # U-Net consumes it is defined in baselines.U_Net — TODO confirm.
    n_modes: int = 12
    # Spatial resolution per dim; the product equals the number of grid points.
    shapelist: Tuple[int, ...] = (64, 64)

    # Input/embedding switches; their semantics live in the model code.
    time_input: bool = False
    pos_emb: bool = True
    # Number of reference points for the unified positional embedding.
    ref: int = 8

    # Activation function name.
    activation: str = "gelu"


@dataclass
class U_NOConfig(CommonConfig):
    """Hyper-parameters for a U_NO model.

    Extends CommonConfig; every field declared here has a default.
    """

    # Tag that make_config() dispatches on; keep it equal to "U_NO".
    model_name: Literal["U_NO"] = "U_NO"
    # Hidden feature width.
    hidden_channels: int = 64
    # A single int here (FNOConfig uses a per-dimension tuple); how U_NO
    # consumes it is defined in baselines.U_NO — TODO confirm.
    n_modes: int = 12
    # Spatial resolution per dim; the product equals the number of grid points.
    shapelist: Tuple[int, ...] = (64, 64)

    # Input/embedding switches; their semantics live in the model code.
    time_input: bool = False
    pos_emb: bool = True
    # Number of reference points for the unified positional embedding.
    ref: int = 8

    # Activation function name.
    activation: str = "gelu"


@dataclass
class GKTConfig(CommonConfig):
    """Configuration for a GKT (Galerkin Transformer) model.

    Extends CommonConfig; every field declared here has a default.
    """
    model_name: Literal["GKT"] = "GKT"                   # Discriminator tag read by make_config()
    hidden_channels: int = 64                            # Hidden feature width
    n_layers: int = 4                                    # Number of GKT (transformer) blocks/layers
    n_heads: int = 8                                     # Number of heads (presumably attention heads — confirm in baselines.GKT)
    dropout: float = 0.0                                 # Dropout rate
    mlp_ratio: int = 1                                   # mlp ratio for feedforward layers
    shapelist: Tuple[int, ...] = (64, 64)                # Spatial resolution per dim; product equals number of grid points

    time_input: bool = False                             # Input switch; semantics live in the model code
    pos_emb: bool = True                                 # Embedding switch; semantics live in the model code
    ref: int = 8                                         # number of reference points for unified pos embedding

    activation: str = "gelu"                             # Activation function name


@dataclass
class F_FNOConfig(CommonConfig):
    """Hyper-parameters for an F_FNO model.

    Extends CommonConfig; every field declared here has a default.
    """

    # Tag that make_config() dispatches on; keep it equal to "F_FNO".
    model_name: Literal["F_FNO"] = "F_FNO"
    # Hidden feature width.
    hidden_channels: int = 64
    # Fourier modes per spatial dim, e.g. 2D -> (m1, m2); 3D -> (m1, m2, m3).
    n_modes: Tuple[int, ...] = (12, 12)
    # Number of FNO blocks/layers.
    n_layers: int = 4
    # Spatial resolution per dim; the product equals the number of grid points.
    shapelist: Tuple[int, ...] = (64, 64)

    # Input/embedding switches; their semantics live in the model code.
    time_input: bool = False
    pos_emb: bool = True
    # Number of reference points for the unified positional embedding.
    ref: int = 8

    # Activation function name.
    activation: str = "gelu"

# -------------------------------------------------------------------------------------------------
# DINO-specific parameter dataclasses -------------------------------------------------------------
# —— submodule dataclass ——
@dataclass(slots=True)
class FieldDecoderParams:
    fourier_hidden_dim: int = 64
    code_dim: int = 64
    n_fourier_layers: int = 3
    input_scale: int = 64
    chunk_t: int = 0
    use_sigmoid: bool = False
    def to_params(self) -> Dict:
        return asdict(self)

@dataclass(slots=True)
class LatentODEParams:
    state_dim: int
    code_dim: int
    hidden_dim: int = 256
    num_layers: int = 3
    nl: Literal["relu", "gelu", "swish", "tanh"] = "swish"
    def to_params(self) -> Dict:
        return asdict(self)

@dataclass(slots=True)
class LatentProcessParams:
    state_dim: int
    code_dim: int
    latent_type: str = "neural_ode"
    solver: str = "rk4"
    def to_params(self) -> Dict:
        return asdict(self)

@dataclass(slots=True)
class SetEncoderParams:
    code_size: int
    n_cond: int
    hidden_size: int = 1024
    def to_params(self) -> Dict:
        return asdict(self)

# —— DINO parameter bundle (derives every submodule's params via from_args) ——
@dataclass(slots=True)
class DINOParamBundle:
    """Per-submodule hyper-parameters for a DINO model, plus shared sizes.

    Prefer building instances with :meth:`from_args`, which derives every
    submodule's parameters from a dataset config and a few knobs so that the
    sizes stay mutually consistent. :meth:`as_model_kwargs` exports the bundle
    as plain dicts.

    The underscore-prefixed fields are exposed read-only through the
    ``state_dim`` / ``code_dim`` / ``n_frames_cond`` properties.
    """

    field_decoder: FieldDecoderParams
    latent_ode: LatentODEParams
    latent_process: LatentProcessParams
    set_encoder: SetEncoderParams
    _state_dim: int
    _code_dim: int
    _n_frames_cond: int  # number of conditioning frames
    use_delay: bool      # from_args sets this to (n_frames_cond >= 2)

    @property
    def state_dim(self) -> int:
        """State dimension (``from_args`` reads it from ``dataset_cfg.STATE_DIM``)."""
        return self._state_dim

    @property
    def code_dim(self) -> int:
        """Per-state code size (``from_args`` also gives it to the field decoder)."""
        return self._code_dim

    @property
    def latent_dim(self) -> int:
        """Flattened latent size, ``state_dim * code_dim``."""
        return self._state_dim * self._code_dim

    @property
    def n_frames_cond(self) -> int:
        """Number of conditioning frames."""
        return self._n_frames_cond

    @property
    def dyn_code_dim(self) -> int:
        """Per-state code size of the latent dynamics (``latent_ode.code_dim``).

        ``from_args`` sets it to ``2 * code_dim`` when ``use_delay`` is true,
        otherwise to ``code_dim``.
        """
        return self.latent_ode.code_dim

    def as_model_kwargs(self, include_meta: bool = True) -> Dict[str, Any]:
        """Export the bundle as nested plain dicts.

        Args:
            include_meta: When true, also add the top-level sizes
                (``state_dim``, ``code_dim``, ``latent_dim``,
                ``n_frames_cond``, ``use_delay``, ``dyn_code_dim``).

        Returns:
            A dict with keys ``field_decoder``, ``latent_ode``,
            ``latent_process`` and ``set_encoder`` (each a
            ``{field_name: value}`` dict), plus the meta keys when
            ``include_meta`` is true.
        """
        out: Dict[str, Any] = {
            "field_decoder": self.field_decoder.to_params(),
            "latent_ode": self.latent_ode.to_params(),
            "latent_process": self.latent_process.to_params(),
            "set_encoder": self.set_encoder.to_params(),
        }
        if include_meta:
            out.update({
                "state_dim": self.state_dim,
                "code_dim": self.code_dim,
                "latent_dim": self.latent_dim,
                "n_frames_cond": self.n_frames_cond,
                "use_delay": self.use_delay,
                "dyn_code_dim": self.dyn_code_dim,
            })
        return out

    @classmethod
    def from_args(
        cls,
        *,
        dataset_cfg: DatasetCfg,
        code_dim: int,
        n_frames_cond: int,    # number of conditional frames
        # ------------------------------------------------
        fourier_hidden_dim: int = 64,
        n_fourier_layers: int = 3,
        input_scale: Optional[int] = None,  # None -> use dataset_cfg.INPUT_SCALE
        chunk_t: int = 0,
        use_sigmoid: bool = False,
        ode_hidden_dim: int = 256,
        ode_num_layers: int = 3,
        ode_nl: str = "swish",
        latent_type: str = "neural_ode",
        solver: str = "rk4",
        set_hidden_size: int = 1024,
    ) -> "DINOParamBundle":
        """Derive a mutually consistent bundle from a dataset config and knobs.

        Args:
            dataset_cfg: Dataset config; ``STATE_DIM`` and ``INPUT_SCALE`` are
                read from it.
            code_dim: Per-state code size (cast to ``int``).
            n_frames_cond: Number of conditioning frames (cast to ``int``).
                Two or more sets ``use_delay``, which doubles the code size
                handed to the latent ODE / latent process. The set encoder's
                ``n_cond`` is ``max(0, n_frames_cond - 1)``.
            fourier_hidden_dim, n_fourier_layers, chunk_t, use_sigmoid:
                Forwarded to :class:`FieldDecoderParams`.
            input_scale: Field-decoder input scale; ``None`` uses
                ``dataset_cfg.INPUT_SCALE`` as-is, otherwise ``int(input_scale)``.
            ode_hidden_dim, ode_num_layers, ode_nl: Forwarded to
                :class:`LatentODEParams` (``ode_nl`` is not validated against
                its Literal type).
            latent_type, solver: Forwarded to :class:`LatentProcessParams`.
            set_hidden_size: Forwarded to :class:`SetEncoderParams`.

        Returns:
            A new bundle.
        """
        # Normalize integer sizes once, so every sub-param object receives the
        # same (int) values; previously STATE_DIM reached latent_ode /
        # latent_process / set_encoder uncast while _state_dim was cast.
        state_dim = int(dataset_cfg.STATE_DIM)
        code_dim = int(code_dim)
        n_frames_cond = int(n_frames_cond)
        # The dataset's INPUT_SCALE is used unchanged (it is not cast).
        scale = dataset_cfg.INPUT_SCALE if input_scale is None else int(input_scale)

        use_delay = n_frames_cond >= 2
        dyn_code_dim = 2 * code_dim if use_delay else code_dim
        n_cond = max(0, n_frames_cond - 1)

        field_decoder = FieldDecoderParams(
            fourier_hidden_dim=fourier_hidden_dim,
            code_dim=code_dim,
            n_fourier_layers=n_fourier_layers,
            input_scale=scale,
            chunk_t=chunk_t,
            use_sigmoid=use_sigmoid,
        )
        latent_ode = LatentODEParams(
            state_dim=state_dim,
            code_dim=dyn_code_dim,
            hidden_dim=ode_hidden_dim,
            num_layers=ode_num_layers,
            nl=ode_nl,  # type: ignore[arg-type]
        )
        latent_process = LatentProcessParams(
            state_dim=state_dim,
            code_dim=dyn_code_dim,
            latent_type=latent_type,
            solver=solver,
        )
        set_encoder = SetEncoderParams(
            code_size=code_dim * state_dim,
            n_cond=n_cond,
            hidden_size=set_hidden_size,
        )
        return cls(
            field_decoder=field_decoder,
            latent_ode=latent_ode,
            latent_process=latent_process,
            set_encoder=set_encoder,
            _state_dim=state_dim,
            _code_dim=code_dim,
            _n_frames_cond=n_frames_cond,
            use_delay=use_delay,
        )


# -----------------------------------------------------------------------------------------------------
# Union of every config type `make_config` can return. The five model
# dataclasses carry a `model_name` tag, but DINOParamBundle has no such field,
# so code that narrows a `Config` by `model_name` must handle it separately.
Config = Union[
    FNOConfig,
    U_NetConfig,
    U_NOConfig,
    GKTConfig,
    F_FNOConfig,
    DINOParamBundle,
]

# Optional keyword arguments that make_config forwards to
# DINOParamBundle.from_args when model_name == "DINO".
_DINO_OVERRIDE_KEYS = frozenset({
    "fourier_hidden_dim", "n_fourier_layers", "input_scale", "chunk_t", "use_sigmoid",
    "ode_hidden_dim", "ode_num_layers", "ode_nl", "latent_type", "solver", "set_hidden_size",
})
# Keyword arguments make_config requires when model_name == "DINO".
_DINO_REQUIRED_KEYS = ("dataset", "code_dim", "n_frames_cond")


def make_config(**kw) -> "Config":
    """
    Factory that constructs a concrete config based on `model_name`.

    For "FNO", "U_Net", "U_NO", "GKT" and "F_FNO", all of `kw` (including
    `model_name`) is forwarded to the matching dataclass constructor, so an
    unexpected key raises ``TypeError`` there.

    For "DINO", the keys `dataset`, `code_dim` and `n_frames_cond` are
    required. `dataset` is resolved with `get_dataset_cfg`, and the bundle is
    built by `DINOParamBundle.from_args`. Only the optional keys in
    `_DINO_OVERRIDE_KEYS` are forwarded; every other key is ignored.

    Args:
        **kw: Keyword arguments for the selected config. Must include `model_name`.

    Returns:
        A FNOConfig, U_NetConfig, U_NOConfig, GKTConfig, F_FNOConfig or
        DINOParamBundle instance.

    Raises:
        ValueError: If `model_name` is missing or unknown.
        KeyError: If `model_name` is "DINO" and any required DINO key is
            missing (the message lists all of them).
    """
    name = kw.get("model_name")
    if name == "FNO":
        return FNOConfig(**kw)
    elif name == "U_Net":
        return U_NetConfig(**kw)
    elif name == "U_NO":
        return U_NOConfig(**kw)
    elif name == "GKT":
        return GKTConfig(**kw)
    elif name == "F_FNO":
        return F_FNOConfig(**kw)
    elif name == "DINO":
        # Report every missing required key at once instead of failing on
        # the first absent one.
        missing = [key for key in _DINO_REQUIRED_KEYS if key not in kw]
        if missing:
            raise KeyError(
                "make_config(model_name='DINO') requires keyword argument(s): "
                + ", ".join(missing)
            )
        dataset_cfg = get_dataset_cfg(name=kw["dataset"])
        overrides = {k: v for k, v in kw.items() if k in _DINO_OVERRIDE_KEYS}
        return DINOParamBundle.from_args(
            dataset_cfg=dataset_cfg,
            code_dim=kw["code_dim"],
            n_frames_cond=kw["n_frames_cond"],
            **overrides,
        )
    else:
        raise ValueError(f"Unknown model_name: {name}")


from baselines import FNO, U_NO, U_Net, GKT, F_FNO

def get_model(model_name: str, model_cfg: "Config"):
    """Instantiate the baseline model registered under ``model_name``.

    Looks ``model_name`` up among the baselines imported from ``baselines``
    and calls that entry's ``Model`` with ``model_cfg``. "DINO" is not
    registered here.

    Args:
        model_name: One of "FNO", "U_NO", "U_Net", "GKT", "F_FNO".
        model_cfg: Config object passed unchanged to ``Model``.

    Returns:
        Whatever ``<baseline>.Model(model_cfg)`` returns.

    Raises:
        KeyError: If ``model_name`` is not registered; the message lists the
            registered names.
    """
    registry = {
        "FNO": FNO,
        "U_NO": U_NO,
        "U_Net": U_Net,
        "GKT": GKT,
        "F_FNO": F_FNO,
    }
    try:
        baseline = registry[model_name]
    except KeyError:
        # Keep the KeyError type callers may rely on, but make it actionable.
        raise KeyError(
            f"Unknown model_name {model_name!r} for get_model; "
            f"expected one of {sorted(registry)}"
        ) from None
    return baseline.Model(model_cfg)

# -----------------------------------------------------------------------------------------------------