"""Minimal wrappers around Mamba variants for time-series experiments."""

from __future__ import annotations

from typing import Literal

import torch
import torch.nn as nn


class MambaSSMRegressor(nn.Module):
    """
    Per-timestep sequence regressor built from ``mamba_experiments`` blocks.

    A scalar input series is lifted to ``d_model`` features with a linear
    projection, passed through ``n_layers`` residual blocks, normalised with
    the backend's ``RMSNorm``, and projected back to one output per timestep.
    All parameters are converted to float64 at construction time.

    Args:
        block_kind: Which block implementation to stack.
            ``"simple"`` uses ``mamba_experiments.mamba_simple.ResidualBlock``;
            ``"diag_exp"`` uses
            ``mamba_experiments.mamba_SSD_diag_exp.SSDResidualBlockExp``.
        d_model: Width of the hidden representation.
        d_state: State size forwarded to the backend's ``ModelArgs``.
        n_layers: Number of residual blocks to stack.

    Raises:
        ValueError: If ``block_kind`` is not ``"simple"`` or ``"diag_exp"``.
    """

    def __init__(
        self,
        block_kind: Literal["simple", "diag_exp"],
        d_model: int = 64,
        d_state: int = 1,
        n_layers: int = 1,
    ):
        super().__init__()
        if block_kind not in {"simple", "diag_exp"}:
            raise ValueError("block_kind must be 'simple' or 'diag_exp'.")

        self.block_kind = block_kind
        self.input_proj = nn.Linear(1, d_model)

        # Imports are local to each branch, so only the selected variant's
        # backend module has to be importable.
        if block_kind == "simple":
            from mamba_experiments.mamba_simple import ModelArgs, ResidualBlock, RMSNorm

            # vocab_size is a placeholder: this wrapper builds no embedding or
            # LM head, so it is presumably unused by the blocks — TODO confirm.
            args = ModelArgs(d_model=d_model, n_layer=n_layers, vocab_size=1, d_state=d_state)
            self.blocks = nn.ModuleList([ResidualBlock(args) for _ in range(n_layers)])
            self.norm = RMSNorm(d_model)
        else:
            from mamba_experiments.mamba_SSD import ModelArgs, RMSNorm
            from mamba_experiments.mamba_SSD_diag_exp import SSDResidualBlockExp

            # Same placeholder vocab_size as above — presumably unused.
            args = ModelArgs(d_model=d_model, n_layer=n_layers, vocab_size=1, d_state=d_state)
            self.blocks = nn.ModuleList([SSDResidualBlockExp(args) for _ in range(n_layers)])
            self.norm = RMSNorm(d_model)

        self.head = nn.Linear(d_model, 1)
        # Default to running the whole model in float64.
        self.double()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict one value per timestep.

        Args:
            x: Input series of shape (batch, time, 1) or (batch, time).

        Returns:
            y_hat: Predictions of shape (batch, time, 1), in the dtype of the
            module's parameters (float64 unless the module was re-cast after
            construction).

        Raises:
            ValueError: If ``x`` is neither (batch, time) nor (batch, time, 1).
        """
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        if x.dim() != 3 or x.shape[-1] != 1:
            raise ValueError(
                "expected input of shape (batch, time) or (batch, time, 1), "
                f"got {tuple(x.shape)}"
            )
        # Follow the parameters' dtype rather than hard-coding float64, so the
        # model keeps working after e.g. ``model.float()``.
        x = x.to(dtype=self.input_proj.weight.dtype)

        h = self.input_proj(x)
        for block in self.blocks:
            h = block(h)
        h = self.norm(h)
        y_hat = self.head(h)
        return y_hat
