"""Mamba layer."""

from mamba_ssm import Mamba2 as MambaSSM
from torch import nn

class MambaLayer(nn.Module):
    """Mamba layer for LM.

    Thin wrapper around ``mamba_ssm.Mamba2`` (imported here as ``MambaSSM``)
    exposing a ``forward(x, return_attention=False, input_ids=None)``
    signature.
    """

    def __init__(
        self,
        d_model,
        layer_idx=None,
        device=None,
        dtype=None,
        *,
        d_state=16,
        d_conv=4,
        expand=2,
    ):
        """Initialize the mamba layer.

        Args:
            d_model: Hidden dimension passed to Mamba2 and to the LayerNorm.
            layer_idx: Index of this layer in the stack; forwarded to Mamba2.
            device: Device on which parameters are created.
            dtype: Dtype of the created parameters.
            d_state: SSM state size passed to Mamba2 (default 16).
            d_conv: Convolution width passed to Mamba2 (default 4).
            expand: Inner-dimension expansion factor passed to Mamba2
                (default 2).
        """
        super().__init__()
        # layer_idx/device/dtype used to be accepted but silently dropped, so
        # the inner block was always built on the default device and dtype.
        self.inner_mamba = MambaSSM(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            layer_idx=layer_idx,
            device=device,
            dtype=dtype,
        )
        # NOTE(review): self.norm is never applied in forward(); presumably it
        # is used by the caller or kept for checkpoint compatibility — confirm.
        self.norm = nn.LayerNorm(d_model, device=device, dtype=dtype)
        self.device = device

    def forward(self, x, return_attention=False, input_ids=None):
        """Run the inner Mamba2 block on ``x``.

        Args:
            x: Input tensor; presumably (batch, seq_len, d_model) as Mamba2
                expects — TODO confirm against callers.
            return_attention: Must be falsy; this layer does not produce
                attention weights.
            input_ids: Unused; presumably accepted so the signature matches
                other layer types used by the model.

        Returns:
            The output of the inner Mamba2 block.

        Raises:
            ValueError: If ``return_attention`` is truthy.
        """
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if return_attention:
            raise ValueError("MambaLayer does not support return_attention=True")
        return self.inner_mamba(x)


