import argparse
import torch
import torch.nn as nn

from dataclasses import dataclass, field, asdict
from typing import Literal, Dict, Any, Optional, Tuple
import torch

def parse_args():
    """Build the command-line parser and return the parsed namespace.

    No options are registered on the parser in this function, so the
    returned ``argparse.Namespace`` is empty and any argument given on the
    command line makes argparse exit with a usage error.
    """
    cli = argparse.ArgumentParser()
    # NOTE(review): no cross-parameter validation is performed here yet.
    namespace = cli.parse_args()
    return namespace


######################################################################################################################
######################################################################################################################
"""
DATASET config
"""
from dataclasses import dataclass

@dataclass(frozen=True)
class DatasetCfg:
    """Immutable per-dataset constants used to size a MemKNO model.

    Attributes:
        STATE_DIM: state channels per frame; ``MemKNOParamBundle.from_args``
            computes ``input_channels = n_frames_cond * STATE_DIM + SPATIAL_DIM``
            and uses it as the decoders' output width.
        SPATIAL_DIM: number of spatial coordinates; forwarded as the encoders'
            ``relative_emb_dim``.
        GRID_DIM: forwarded as the Fourier decoder's ``grid_dim``.
        INPUT_SCALE: default decoder ``input_scale`` when the caller does not
            override it.
        SHAPELIST: grid shape, one entry per spatial axis. This is
            variable-length, e.g. ``(128,)`` for a 1-D dataset and
            ``(180, 360)`` for a 2-D one.
        DATA_PATH: path of the dataset file.
    """

    STATE_DIM: int
    SPATIAL_DIM: int
    GRID_DIM: int
    INPUT_SCALE: int
    # Variable-length: 1-D datasets store a 1-tuple, so ``tuple[int, int]`` was wrong.
    SHAPELIST: tuple[int, ...]
    DATA_PATH: str

# Registry of known dataset configs. Keys are lower-case because
# get_dataset_cfg lower-cases the requested name before looking it up.
_CFG_TABLE = {
    # Same values and data file as "ns_1e-3" below; presumably a default alias — confirm.
    "ns": DatasetCfg(
        STATE_DIM=1, SPATIAL_DIM=2, GRID_DIM=2, INPUT_SCALE=64, SHAPELIST=(64, 64),
        DATA_PATH="./data/ns_V1e-3_N5000_T50.mat"
    ),
    "ns_1e-3": DatasetCfg(
        STATE_DIM=1, SPATIAL_DIM=2, GRID_DIM=2, INPUT_SCALE=64, SHAPELIST=(64, 64),
        DATA_PATH="./data/ns_V1e-3_N5000_T50.mat"
    ),
    "ns_1e-4": DatasetCfg(
        STATE_DIM=1, SPATIAL_DIM=2, GRID_DIM=2, INPUT_SCALE=64, SHAPELIST=(64, 64),
        DATA_PATH="./data/ns_V1e-4_N10000_T30.mat"
    ),
    # NOTE(review): INPUT_SCALE is 128 here, while the other ns_* entries on the
    # same (64, 64) grid use 64 — confirm this is intended.
    "ns_1e-5": DatasetCfg(
        STATE_DIM=1, SPATIAL_DIM=2, GRID_DIM=2, INPUT_SCALE=128, SHAPELIST=(64, 64),
        DATA_PATH="./data/ns_V1e-5_N1200_T20.mat"
    ),
    # The only entry with more than one state channel per frame (STATE_DIM=2).
    "wave": DatasetCfg(
        STATE_DIM=2, SPATIAL_DIM=2, GRID_DIM=2, INPUT_SCALE=128, SHAPELIST=(64, 64),
        DATA_PATH="./data/wave.h5"
    ),
    "shallow_water": DatasetCfg(
        STATE_DIM=1, SPATIAL_DIM=2, GRID_DIM=2, INPUT_SCALE=64, SHAPELIST=(64, 64),
        DATA_PATH="./data/shallow_water.h5"
    ),
    "sst": DatasetCfg(
        STATE_DIM=1, SPATIAL_DIM=2, GRID_DIM=2, INPUT_SCALE=128, SHAPELIST=(64, 64),
        DATA_PATH="./data/sst_T20_N1000.pt"
    ),
    # 1-D dataset: a single spatial axis of 128 points, hence a 1-tuple SHAPELIST.
    "ks": DatasetCfg(
        STATE_DIM=1, SPATIAL_DIM=1, GRID_DIM=1, INPUT_SCALE=128, SHAPELIST=(128,),
        DATA_PATH="./data/ks_dataset.npz"
    ),
    # Non-square 180 x 360 grid.
    "era5": DatasetCfg(
        STATE_DIM=1, SPATIAL_DIM=2, GRID_DIM=2, INPUT_SCALE=180, SHAPELIST=(180, 360),
        DATA_PATH="./data/ERA5_N550_T20.npz"
    )
}



def get_dataset_cfg(name: str) -> DatasetCfg:
    """Return the immutable config registered under *name*.

    The lookup is case-insensitive: *name* is lower-cased before indexing
    ``_CFG_TABLE``.

    Raises:
        ValueError: if no dataset is registered under ``name.lower()``. The
            message lists the available names.
    """
    try:
        return _CFG_TABLE[name.lower()]
    except KeyError:
        # ``from None`` drops the internal KeyError context so the traceback
        # shows one actionable error instead of "during handling of ...".
        raise ValueError(f"Unknown dataset '{name}'. "
                         f"Available: {list(_CFG_TABLE.keys())}") from None
    

######################################################################################################################
######################################################################################################################
"""
MODEL config
"""
from memKNO.encoder import LatentGlobalEncoder2D, SetEncoder2D
from memKNO.encoder import EncoderRebuttal
from memKNO.decoder import FieldDecoder, CrossFormerDecoder2D, FourierDecoder
from memKNO.latent import LatentProcess
from memKNO.latent_discrete import LatentProcessDiscrete

# —— submodule dataclass ——
@dataclass(slots=True)
class EncoderParams:
    input_channels: int
    in_emb_dim: int = 128
    token_dim: int = 64
    heads: int = 4
    spatial_depth: int = 3
    dim_head: int | None = 32
    mlp_dim: int | None = None
    attn_type: Literal["galerkin", "fourier"] = "galerkin"
    dropout: float = 0.0
    relative_emb_dim: int = 2
    min_freq: float = 1/64
    scale_spatial: int | None = None
    use_ln: bool = True
    latent_tokens: int = 4
    latent_depth: int = 2
    use_latent_ln: bool = True
    use_latent_pos: bool = True
    scale_latent: int = 8
    def to_params(self) -> Dict:
        return asdict(self)
    

@dataclass(slots=True)
class EncoderRebuParams:
    device: torch.device

    input_channels: int

    shapelist: Tuple[int, ...] = (64, 64)
    pos_emb: bool = True
    ref: int = 8
    activation: str = "gelu"
    
    in_emb_dim: int = 128
    token_dim: int = 64
    heads: int = 4
    spatial_depth: int = 3
    dim_head: int | None = 32
    mlp_dim: int | None = None
    attn_type: Literal["galerkin", "fourier"] = "galerkin"
    dropout: float = 0.0
    relative_emb_dim: int = 2
    min_freq: float = 1/64
    scale_spatial: int | None = None
    use_ln: bool = True
    latent_tokens: int = 4
    latent_depth: int = 2
    use_latent_ln: bool = True
    use_latent_pos: bool = True
    scale_latent: int = 8
    def to_params(self) -> Dict:
        return asdict(self)


@dataclass(slots=True)
class SetEncoderParams:
    input_channels: int
    pos_emb_dim: int
    pos_emb_type: str = "trainable"
    pos_hidden: int = 256
    val_hidden: int = 128
    set_dim: int = 128
    set_hidden: int = 128
    num_heads: int = 4
    num_inds: int = 64 
    token_dim: int = 64
    latent_tokens: int = 4                # K
    use_ln: bool = True
    fourier_max_freq: float = 16.0
    dropout: float = 0.1
    def to_params(self) -> Dict:
        return asdict(self)


@dataclass(slots=True)
class FieldDecoderParams:
    fourier_hidden_dim: int = 64
    code_dim: int = 64
    n_fourier_layers: int = 3
    input_scale: float = 64
    chunk_t: int = 0
    use_sigmoid: bool = False
    mlp_in: bool = False
    mlp_layers: int = 2
    mlp_act: str = "gelu"
    def to_params(self) -> Dict:
        return asdict(self)
    

@dataclass(slots=True)
class CrossDecoderParams:
    atten_blocks: int = 1
    heads: int = 4
    latent_channels: int = 64
    out_channels: int = 1
    random_fourier_scale: float = 8.0
    mlp_hidden_dim: int = 256
    dropout: float = 0.0
    def to_params(self) -> Dict:
        return asdict(self)
    

@dataclass(slots=True)
class FourierDecoderParams:
    grid_dim: int = 2
    fourier_hidden_dim: int = 256
    latent_dim: int = 128
    out_dim : int = 1
    n_fourier_layers: int = 3
    input_scale: float = 64
    modmlp_layers: int = 2
    modmlp_act: str = "gelu"
    use_freq_scale: bool = True
    use_phase: bool = True
    use_coord_scale: bool = False
    lora_rank: int = 2
    lora_scale: float = 1.0
    def to_params(self) -> Dict:
        return asdict(self)


@dataclass(slots=True)
class LatentProcessParams:
    state_dim: int 
    code_dim: int
    latent_type: str = "linear+memory"
    solver: str = "dopri5"
    # memory settings
    memory_dim: int | None = None
    # encoder/decoder MLP hyperparams
    enc_hidden_dim: int | None = None
    dec_hidden_dim: int | None = None
    enc_layers: int = 2
    dec_layers: int = 2
    nl: str = "swish"
    # linear operator parameterization (Scheme A: dense pH)
    linear_param: str = "pH_dense"              # {"free", "pH_dense"}
    ph_osc_dims: int = 0                        # oscillatory subspace dim (>=0)
    ph_epsP: float = 1e-6                       # ε for P = L^T L + ε I
    ph_min_freq: float = 1e-3                   # frequency lower bound for oscillatory modes
    return_aux: bool = False                    # return ||phi_dec(m)||^2
    def to_params(self) -> Dict:
        return asdict(self)


@dataclass(slots=True)
class LatentProcessDiscreteParams:
    state_dim: int
    code_dim: int
    memory_dim: Optional[int] = None
    memory_type: str = "leaky"          # {'leaky','gru','lstm'}
    process_type: str = "discrete"      # {'discrete','ode_rnn'}
    # dec/enc MLPs for leaky backend
    enc_hidden_dim: int = 128
    enc_layers: int = 2
    dec_hidden_dim: int = 128
    dec_layers: int = 2
    nl: str = "swish"
    # RNN config (for gru/lstm)
    rnn_layers: int = 2 
    rnn_dropout: float = 0.0 
    # gate
    gate_per_dim: bool = True
    # init options
    init_tau_steps: Optional[float] = None    # for leaky: initialize gamma from tau,
    use_layer_norm: bool = True
    context_window: int = 3
    window_pad: str = "repeat"                # "repeat" / "zero"
    augment: bool = False                    
    augment_variant: str = "history"          # "history" / "current"
    rnn_hidden: int = 256
    def to_params(self) -> Dict:
        return asdict(self)



# —— Model Factory for MemKNO ——
@dataclass(slots=True)
class MemKNOParamBundle:
    """All per-module parameter sets of a MemKNO model plus derived sizes.

    Instances are normally built with :meth:`from_args`. It derives the
    shared sizes (``input_channels``, ``latent_dim``, ``code_dim``) from a
    ``DatasetCfg`` and forwards them to every sub-config so they agree. The
    underscore-prefixed fields hold those derived sizes and are exposed
    through read-only properties of the same name without the underscore.
    """

    encoder: EncoderParams
    encoder_rebuttal: EncoderRebuParams
    set_encoder: SetEncoderParams
    field_decoder: FieldDecoderParams
    cross_decoder: CrossDecoderParams
    fourier_decoder: FourierDecoderParams
    latent_process: LatentProcessParams
    latent_process_discrete: LatentProcessDiscreteParams

    _state_dim: int
    _latent_dim: int
    _code_dim: int
    _n_frames_cond: int
    _input_channels: int
    _latent_type: str

    @property
    def state_dim(self) -> int:
        return self._state_dim

    @property
    def code_dim(self) -> int:
        return self._code_dim

    @property
    def latent_dim(self) -> int:
        return self._latent_dim

    @property
    def n_frames_cond(self) -> int:
        return self._n_frames_cond

    @property
    def input_channels(self) -> int:
        return self._input_channels

    @property
    def latent_type(self) -> str:
        return self._latent_type

    def as_model_kwargs(self, include_meta: bool = True) -> Dict[str, Any]:
        """Return per-module kwargs; optionally include top-level meta fields.

        Each sub-config is converted with its ``to_params()``. Note that the
        ``encoder_rebuttal`` sub-config is stored under the key ``"encoder_rebu"``.
        """
        out = {
            "encoder": self.encoder.to_params(),
            "set_encoder": self.set_encoder.to_params(),
            "encoder_rebu": self.encoder_rebuttal.to_params(),
            "field_decoder": self.field_decoder.to_params(),
            "cross_decoder": self.cross_decoder.to_params(),
            "fourier_decoder": self.fourier_decoder.to_params(),
            "latent_process": self.latent_process.to_params(),
            "latent_process_discrete": self.latent_process_discrete.to_params(),
        }
        if include_meta:
            out.update({
                "state_dim": self.state_dim,
                "latent_dim": self.latent_dim,
                "code_dim": self.code_dim,
                "n_frames_cond": self.n_frames_cond,
                "input_channels": self.input_channels,
                "latent_type": self.latent_type,
            })
        return out

    @classmethod
    def from_args(
        cls,
        *,
        dataset_cfg: DatasetCfg,
        n_frames_cond: int,    # number of conditioning frames
        # ------------------------------------------------
        fourier_hidden_dim: int = 64,
        n_fourier_layers: int = 3,
        input_scale: Optional[int] = None,  # None -> use dataset_cfg.INPUT_SCALE
        chunk_t: int = 0,
        use_sigmoid: bool = False,
        mlp_in: bool = False,
        in_mlp_layers: int = 2,
        in_mlp_act: str = "gelu",
        # ------------------------------------------------
        shapelist: Tuple[int, ...] = (64, 64),
        pos_emb: bool = True,
        ref: int = 8,
        activation: str = "gelu",
        device: torch.device | str = "cpu",
        # ------------------------------------------------
        in_emb_dim: int = 128,
        token_dim: int = 64,
        enc_heads: int = 4,
        spatial_depth: int = 3,
        dim_head: int | None = 32,
        min_freq: float = 1/64,
        latent_tokens: int = 4,
        latent_depth: int = 2,
        # ------------------------------------------------
        pos_emb_dim: int = 64,
        pos_emb_type: str = "trainable",
        pos_hidden: int = 128,
        val_hidden: int = 128,
        set_dim: int = 128,
        set_hidden: int = 128,
        num_inds: int = 64,
        use_ln: bool = True,
        fourier_max_freq: float = 16.0,
        dropout: float = 0.1,
        # ------------------------------------------------
        dec_atten_blocks: int = 1,
        dec_heads: int = 4,
        random_fourier_scale: float = 8.0,
        dec_mlp_dim: int = 256,
        # ------------------------------------------------
        modmlp_layers: int = 2,
        modmlp_act: str = "gelu",
        use_freq_scale: bool = False,
        use_phase: bool = False,
        use_coord_scale: bool = False,
        lora_rank: int = 0,
        lora_scale: float = 1.0,
        # ------------------------------------------------
        latent_type: str = "linear+memory",
        solver: str = "rk4",
        # memory settings
        memory_dim: int | None = None,
        # encoder/decoder MLP hyperparams
        ode_enc_hidden_dim: int | None = None,
        ode_dec_hidden_dim: int | None = None,
        ode_enc_layers: int = 2,
        ode_dec_layers: int = 2,
        ode_nl: str = "swish",
        # linear operator parameterization (Scheme A: dense pH)
        linear_param: str = "pH_dense",              # {"free", "pH_dense"}
        ph_osc_dims: int = 0,                        # oscillatory subspace dim (>=0)
        ph_epsP: float = 1e-6,                       # ε for P = L^T L + ε I
        ph_min_freq: float = 1e-3,                   # frequency lower bound for oscillatory modes
        return_aux: bool = True,                     # return ||phi_dec(m)||^2
        # ------------------------------------------------
        memory_type: str = "leaky",           # {'leaky','gru','lstm'}
        process_type: str = "discrete",
        rnn_layers: int = 2,
        rnn_dropout: float = 0.0,
        # gate
        gate_per_dim: bool = True,
        # init options
        init_tau_steps: Optional[float] = None,  # for leaky: initialize gamma from tau
        latent_ln: bool = True,
        context_window: int = 4,                    # d
        window_pad: str = "repeat",                 # {'repeat','zero'} padding for the initial window
        augment: bool = False,                      # channel-wise augmentation
        augment_variant: str = "history",           # {'history','current'}
        rnn_hidden: int = 256,
    ) -> "MemKNOParamBundle":
        """Build a mutually consistent parameter bundle for *dataset_cfg*.

        Derived sizes shared across sub-configs:

            input_channels = n_frames_cond * STATE_DIM + SPATIAL_DIM
            latent_dim     = token_dim * latent_tokens
            code_dim       = latent_dim // STATE_DIM

        All other keyword arguments are forwarded to the matching sub-config.
        The ``ode_*`` MLP options are shared by the continuous and discrete
        latent processes.

        Raises:
            ValueError: if ``token_dim * latent_tokens`` is not divisible by
                ``dataset_cfg.STATE_DIM``.
        """
        STATE_DIM = dataset_cfg.STATE_DIM
        SPATIAL_DIM = dataset_cfg.SPATIAL_DIM
        GRID_DIM = dataset_cfg.GRID_DIM
        INPUT_SCALE = dataset_cfg.INPUT_SCALE if input_scale is None else int(input_scale)

        input_channels = int(n_frames_cond * STATE_DIM + SPATIAL_DIM)
        latent_dim = int(token_dim * latent_tokens)
        # A real check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let code_dim be silently truncated.
        if latent_dim % STATE_DIM != 0:
            raise ValueError(
                f"token_dim * latent_tokens ({latent_dim}) must be divisible "
                f"by the dataset STATE_DIM ({STATE_DIM})"
            )
        code_dim = latent_dim // STATE_DIM

        encoder = EncoderParams(
            input_channels=input_channels,
            in_emb_dim=in_emb_dim,
            token_dim=token_dim,
            heads=enc_heads,
            spatial_depth=spatial_depth,
            dim_head=dim_head,
            min_freq=min_freq,
            latent_tokens=latent_tokens,
            latent_depth=latent_depth,
            relative_emb_dim=SPATIAL_DIM,
        )

        # NOTE(review): unlike input_scale, shapelist does not fall back to
        # dataset_cfg.SHAPELIST; the (64, 64) default is used unless the
        # caller overrides it — confirm for non-64x64 datasets (e.g. "ks", "era5").
        encoder_rebuttal = EncoderRebuParams(
            shapelist=shapelist,
            pos_emb=pos_emb,
            ref=ref,
            activation=activation,
            device=device,
            input_channels=input_channels,
            in_emb_dim=in_emb_dim,
            token_dim=token_dim,
            heads=enc_heads,
            spatial_depth=spatial_depth,
            dim_head=dim_head,
            min_freq=min_freq,
            latent_tokens=latent_tokens,
            latent_depth=latent_depth,
            relative_emb_dim=SPATIAL_DIM,
        )

        set_encoder = SetEncoderParams(
            input_channels=input_channels,
            pos_emb_dim=pos_emb_dim,
            pos_emb_type=pos_emb_type,
            pos_hidden=pos_hidden,
            val_hidden=val_hidden,
            set_dim=set_dim,
            set_hidden=set_hidden,
            num_heads=enc_heads,
            num_inds=num_inds,
            token_dim=token_dim,
            latent_tokens=latent_tokens,
            use_ln=use_ln,
            fourier_max_freq=fourier_max_freq,
            dropout=dropout,
        )

        field_decoder = FieldDecoderParams(
            fourier_hidden_dim=fourier_hidden_dim,
            code_dim=int(code_dim),
            n_fourier_layers=n_fourier_layers,
            input_scale=INPUT_SCALE,
            chunk_t=chunk_t,
            use_sigmoid=use_sigmoid,
            mlp_in=mlp_in,
            mlp_layers=in_mlp_layers,
            mlp_act=in_mlp_act,
        )

        cross_decoder = CrossDecoderParams(
            atten_blocks=dec_atten_blocks,
            heads=dec_heads,
            latent_channels=token_dim,
            out_channels=STATE_DIM,
            random_fourier_scale=random_fourier_scale,
            mlp_hidden_dim=dec_mlp_dim,
        )

        fourier_decoder = FourierDecoderParams(
            grid_dim=GRID_DIM,
            fourier_hidden_dim=fourier_hidden_dim,
            latent_dim=latent_dim,
            out_dim=STATE_DIM,
            n_fourier_layers=n_fourier_layers,
            input_scale=INPUT_SCALE,
            modmlp_layers=modmlp_layers,
            modmlp_act=modmlp_act,
            use_freq_scale=use_freq_scale,
            use_phase=use_phase,
            use_coord_scale=use_coord_scale,
            lora_rank=lora_rank,
            lora_scale=lora_scale,
        )

        latent_process = LatentProcessParams(
            state_dim=STATE_DIM,
            code_dim=int(code_dim),
            latent_type=latent_type,
            solver=solver,
            memory_dim=memory_dim,
            enc_hidden_dim=ode_enc_hidden_dim,
            dec_hidden_dim=ode_dec_hidden_dim,
            enc_layers=ode_enc_layers,
            dec_layers=ode_dec_layers,
            nl=ode_nl,
            linear_param=linear_param,
            ph_osc_dims=ph_osc_dims,
            ph_epsP=ph_epsP,
            ph_min_freq=ph_min_freq,
            return_aux=return_aux,
        )

        latent_process_discrete = LatentProcessDiscreteParams(
            state_dim=STATE_DIM,
            code_dim=int(code_dim),
            memory_dim=memory_dim,
            memory_type=memory_type,
            process_type=process_type,
            enc_hidden_dim=ode_enc_hidden_dim,
            enc_layers=ode_enc_layers,
            dec_hidden_dim=ode_dec_hidden_dim,
            dec_layers=ode_dec_layers,
            nl=ode_nl,
            rnn_layers=rnn_layers,
            rnn_dropout=rnn_dropout,
            gate_per_dim=gate_per_dim,
            init_tau_steps=init_tau_steps,
            use_layer_norm=latent_ln,
            context_window=context_window,
            window_pad=window_pad,
            augment=augment,
            augment_variant=augment_variant,
            rnn_hidden=rnn_hidden,
        )

        return cls(
            encoder=encoder,
            encoder_rebuttal=encoder_rebuttal,
            set_encoder=set_encoder,
            field_decoder=field_decoder,
            cross_decoder=cross_decoder,
            fourier_decoder=fourier_decoder,
            latent_process=latent_process,
            latent_process_discrete=latent_process_discrete,
            _state_dim=int(STATE_DIM),
            _latent_dim=int(latent_dim),
            _code_dim=int(code_dim),
            _n_frames_cond=int(n_frames_cond),
            _input_channels=int(input_channels),
            _latent_type=latent_type,
        )


# Keyword arguments make_config forwards to MemKNOParamBundle.from_args.
# Built once at import time instead of on every call. Any other key in
# ``kw`` (e.g. unrelated CLI options) is silently ignored.
MEMKNO_OVERRIDE_KEYS = frozenset({
    "fourier_hidden_dim", "n_fourier_layers", "input_scale", "chunk_t", "use_sigmoid",

    "mlp_in", "in_mlp_layers", "in_mlp_act",
    "in_emb_dim", "token_dim", "enc_heads", "spatial_depth", "dim_head", "min_freq", "latent_tokens", "latent_depth",

    "shapelist", "pos_emb", "ref", "device", "activation",

    "pos_emb_dim", "pos_emb_type", "pos_hidden", "val_hidden", "set_dim", "set_hidden", "num_inds",
    "use_ln", "dropout", "fourier_max_freq",

    "dec_atten_blocks", "dec_heads", "random_fourier_scale", "dec_mlp_dim",
    "modmlp_layers", "modmlp_act", "use_freq_scale", "use_phase", "use_coord_scale",
    "lora_rank", "lora_scale",

    "latent_type", "solver", "memory_dim",
    "ode_enc_hidden_dim", "ode_dec_hidden_dim", "ode_enc_layers", "ode_dec_layers", "ode_nl",
    "linear_param", "ph_osc_dims", "ph_epsP", "ph_min_freq", "return_aux",

    "memory_type", "process_type", "rnn_layers", "rnn_dropout",
    "gate_per_dim", "init_tau_steps", "latent_ln",
    "context_window", "window_pad", "augment", "augment_variant", "rnn_hidden",
})


def make_config(**kw) -> "MemKNOParamBundle":
    """Build a model parameter bundle from a flat set of keyword options.

    Recognised keys:
        model_name: selects the model; only ``"memKNO"`` is supported.
        dataset: (required) dataset name, resolved with ``get_dataset_cfg``.
        n_frames_cond: (required) number of conditioning frames.
        Any key in ``MEMKNO_OVERRIDE_KEYS`` is forwarded to
        ``MemKNOParamBundle.from_args``; all other keys are ignored.

    Raises:
        ValueError: for an unknown ``model_name`` or dataset name.
        KeyError: if ``dataset`` or ``n_frames_cond`` is missing.
    """
    name = kw.get("model_name")
    if name != "memKNO":
        raise ValueError(f"Unknown model_name: {name}")

    dataset = kw.pop("dataset")
    n_frames_cond = kw.pop("n_frames_cond")
    dataset_cfg = get_dataset_cfg(name=dataset)

    overrides = {k: v for k, v in kw.items() if k in MEMKNO_OVERRIDE_KEYS}
    return MemKNOParamBundle.from_args(
        dataset_cfg=dataset_cfg,
        n_frames_cond=n_frames_cond,
        **overrides,
    )


from memKNO.encoder import LatentGlobalEncoder2D
def build_encoder(model_cfg: EncoderParams | Dict[str, Any]):
    """Instantiate ``LatentGlobalEncoder2D`` from its parameter set.

    Args:
        model_cfg: an ``EncoderParams`` instance, or the plain dict returned
            by its ``to_params()`` (e.g. ``as_model_kwargs()["encoder"]``).
    """
    # A slots dataclass is not a mapping, so ``**model_cfg`` would raise
    # TypeError for the annotated type; convert it to a dict first.
    kwargs = model_cfg.to_params() if hasattr(model_cfg, "to_params") else model_cfg
    encoder = LatentGlobalEncoder2D(**kwargs)
    return encoder

from memKNO.encoder import SetEncoder2D
def build_set_encoder(model_cfg: SetEncoderParams | Dict[str, Any]):
    """Instantiate ``SetEncoder2D`` from its parameter set.

    Args:
        model_cfg: a ``SetEncoderParams`` instance, or the plain dict returned
            by its ``to_params()`` (e.g. ``as_model_kwargs()["set_encoder"]``).
    """
    # A slots dataclass is not a mapping; convert before ``**`` expansion.
    kwargs = model_cfg.to_params() if hasattr(model_cfg, "to_params") else model_cfg
    set_encoder = SetEncoder2D(**kwargs)
    return set_encoder

from memKNO.encoder import EncoderRebuttal
def build_encoder_rebuttal(model_cfg: EncoderRebuParams | Dict[str, Any]):
    """Instantiate ``EncoderRebuttal`` from its parameter set.

    Args:
        model_cfg: an ``EncoderRebuParams`` instance, or the plain dict
            returned by its ``to_params()`` (e.g.
            ``as_model_kwargs()["encoder_rebu"]``).
    """
    # A slots dataclass is not a mapping; convert before ``**`` expansion.
    kwargs = model_cfg.to_params() if hasattr(model_cfg, "to_params") else model_cfg
    encoder = EncoderRebuttal(**kwargs)
    return encoder

from memKNO.decoder import FieldDecoder, CrossFormerDecoder2D, FourierDecoder
def build_field_decoder(x_grid: torch.Tensor, model_cfg: FieldDecoderParams | Dict[str, Any]):
    """Instantiate ``FieldDecoder`` on the coordinate grid *x_grid*.

    Args:
        x_grid: grid tensor passed through as ``FieldDecoder(x_grid=...)``.
        model_cfg: a ``FieldDecoderParams`` instance, or the plain dict
            returned by its ``to_params()`` (e.g.
            ``as_model_kwargs()["field_decoder"]``).
    """
    # A slots dataclass is not a mapping; convert before ``**`` expansion.
    kwargs = model_cfg.to_params() if hasattr(model_cfg, "to_params") else model_cfg
    field_decoder = FieldDecoder(x_grid=x_grid, **kwargs)
    return field_decoder

def build_cross_decoder(x_grid: torch.Tensor, model_cfg: CrossDecoderParams | Dict[str, Any]):
    """Instantiate ``CrossFormerDecoder2D`` on the coordinate grid *x_grid*.

    Args:
        x_grid: grid tensor passed through as ``CrossFormerDecoder2D(x_grid=...)``.
        model_cfg: a ``CrossDecoderParams`` instance, or the plain dict
            returned by its ``to_params()`` (e.g.
            ``as_model_kwargs()["cross_decoder"]``). The previous annotation
            named the module class itself, which was incorrect.
    """
    # A slots dataclass is not a mapping; convert before ``**`` expansion.
    kwargs = model_cfg.to_params() if hasattr(model_cfg, "to_params") else model_cfg
    cross_decoder = CrossFormerDecoder2D(x_grid=x_grid, **kwargs)
    return cross_decoder

def build_fourier_decoder(model_cfg: FourierDecoderParams | Dict[str, Any]):
    """Instantiate ``FourierDecoder`` from its parameter set.

    Args:
        model_cfg: a ``FourierDecoderParams`` instance, or the plain dict
            returned by its ``to_params()`` (e.g.
            ``as_model_kwargs()["fourier_decoder"]``).
    """
    # A slots dataclass is not a mapping; convert before ``**`` expansion.
    kwargs = model_cfg.to_params() if hasattr(model_cfg, "to_params") else model_cfg
    fourier_decoder = FourierDecoder(**kwargs)
    return fourier_decoder

from memKNO.latent import LatentProcess
def build_latent_process(model_cfg: LatentProcessParams | Dict[str, Any]):
    """Instantiate ``LatentProcess`` from its parameter set.

    Args:
        model_cfg: a ``LatentProcessParams`` instance, or the plain dict
            returned by its ``to_params()`` (e.g.
            ``as_model_kwargs()["latent_process"]``).
    """
    # A slots dataclass is not a mapping; convert before ``**`` expansion.
    kwargs = model_cfg.to_params() if hasattr(model_cfg, "to_params") else model_cfg
    latent_process = LatentProcess(**kwargs)
    return latent_process

from memKNO.latent_discrete import LatentProcessDiscrete
def build_latent_process_discrete(model_cfg: LatentProcessDiscreteParams | Dict[str, Any]):
    """Instantiate ``LatentProcessDiscrete`` from its parameter set.

    Args:
        model_cfg: a ``LatentProcessDiscreteParams`` instance, or the plain
            dict returned by its ``to_params()`` (e.g.
            ``as_model_kwargs()["latent_process_discrete"]``).
    """
    # A slots dataclass is not a mapping; convert before ``**`` expansion.
    kwargs = model_cfg.to_params() if hasattr(model_cfg, "to_params") else model_cfg
    latent_process = LatentProcessDiscrete(**kwargs)
    return latent_process