"""
Wrapper to use the lightweight Mamba implementation in this repo.

Requires mamba_experiments/mamba_simple.py (single-file minimal Mamba) and its
deps (`einops`).
"""

from __future__ import annotations

import sys
from pathlib import Path

import torch
import torch.nn as nn


def _import_simple_mamba():
    """Import and return the ``mamba_experiments.mamba_simple`` module.

    The import runs when this function is called rather than when the
    enclosing module is loaded.

    Returns:
        The imported ``mamba_simple`` module object.

    Raises:
        ImportError: If the import fails with any ``Exception``; the
            original exception is chained as the cause.
    """
    try:
        from mamba_experiments import mamba_simple  # type: ignore
    except Exception as exc:  # pragma: no cover
        message = "Failed to import mamba_experiments/mamba_simple.py; ensure it exists and einops is installed."
        raise ImportError(message) from exc
    else:
        return mamba_simple


class SimpleMambaModel(nn.Module):
    """Classifier wrapper around the ``Mamba`` model from ``mamba_simple``.

    The wrapped model is stored on ``self.model``. The forward pass runs it
    and keeps only the output at the last index of the second axis.
    """

    def __init__(self, vocab_size: int, d_model: int = 64, n_layers: int = 2, d_state: int = 16):
        """Build the wrapped Mamba model.

        Args:
            vocab_size: Forwarded to ``ModelArgs`` as ``vocab_size``.
            d_model: Forwarded to ``ModelArgs`` as ``d_model``.
            n_layers: Forwarded to ``ModelArgs`` as ``n_layer`` (note the
                singular spelling used by the backend).
            d_state: Forwarded to ``ModelArgs`` as ``d_state``.

        Raises:
            ImportError: If the ``mamba_simple`` backend cannot be imported.
        """
        super().__init__()
        backend = _import_simple_mamba()
        # Collect hyperparameters under the names the backend expects.
        hyperparams = {
            "d_model": d_model,
            "n_layer": n_layers,
            "vocab_size": vocab_size,
            "d_state": d_state,
        }
        config = backend.ModelArgs(**hyperparams)
        self.model = backend.Mamba(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model and return its output at the last position.

        Args:
            x: Input handed unchanged to the wrapped model (presumably token
                ids shaped (batch, seq_len) -- TODO confirm against
                ``mamba_simple``).

        Returns:
            The slice ``[:, -1, :]`` of the wrapped model's output, i.e. the
            last token's logits, used for classification.
        """
        per_position = self.model(x)
        # Only the final position is used for classification.
        last_position = per_position[:, -1, :]
        return last_position
