"""
xLSTM (mLSTM) layer wrapper for MAD-lab framework.

This module provides a wrapper around the mLSTM layer from the NX-AI xlstm package.
mLSTM (matrix LSTM) uses a matrix memory with multi-head attention-like structure
and exponential gating for improved performance on sequence modeling tasks.

Reference: https://github.com/NX-AI/xlstm
"""
import torch
import torch.nn as nn
from dataclasses import dataclass, field


# xlstm is an optional dependency. If the import fails, this module still
# imports cleanly: XLSTM_AVAILABLE is set to False and the two imported names
# are bound to None so that they always exist at module level.
# mLSTMWrapper.__init__ checks XLSTM_AVAILABLE and raises ImportError with
# install instructions when the package is missing.
try:
    from xlstm.blocks.mlstm.layer import mLSTMLayer, mLSTMLayerConfig
    XLSTM_AVAILABLE = True
except ImportError:
    XLSTM_AVAILABLE = False
    mLSTMLayer = None
    mLSTMLayerConfig = None


class mLSTMWrapper(nn.Module):
    """
    Wrapper for the mLSTM layer from the NX-AI xlstm package.

    mLSTM (matrix LSTM) uses:
    - Matrix memory with multi-head structure
    - Exponential gating for stability
    - Causal convolution for local context
    - Query-Key-Value projections similar to attention

    Args:
        dim (int): Model working dimension (input/output dimension).
            Must be positive and divisible by ``num_heads``.
        max_length (int): Maximum sequence length (context_length for xlstm)
        num_heads (int): Number of attention-like heads. Must be positive.
        conv1d_kernel_size (int): Kernel size for causal convolution
        proj_factor (float): Expansion factor for up-projection (inner_dim = dim * proj_factor)
        qkv_proj_blocksize (int): Block size for QKV projections
        dropout (float): Dropout rate
        bias (bool): Whether to use bias in linear layers
        round_proj_up_to_multiple_of (int): Round projected dimension to multiple of this
        *args, **kwargs: Additional arguments (ignored, for MAD-lab compatibility)

    Raises:
        ImportError: If the xlstm package could not be imported.
        ValueError: If ``dim`` or ``num_heads`` is not positive, or if ``dim``
            is not divisible by ``num_heads``.

    Input shape: (B, L, dim)
    Output shape: (B, L, dim)
    """

    def __init__(
        self,
        dim: int,
        max_length: int = 1024,
        num_heads: int = 4,
        conv1d_kernel_size: int = 4,
        proj_factor: float = 2.0,
        qkv_proj_blocksize: int = 4,
        dropout: float = 0.0,
        bias: bool = False,
        round_proj_up_to_multiple_of: int = 64,
        *args,
        **kwargs
    ):
        super().__init__()

        # The dependency check comes first so a missing package is always
        # reported as ImportError, regardless of the arguments passed.
        if not XLSTM_AVAILABLE:
            raise ImportError(
                "xlstm package not found. Please install it with:\n"
                "  pip install mlstm_kernels xlstm\n"
                "For more info: https://github.com/NX-AI/xlstm"
            )

        # Fail with a clear ValueError before any xlstm object is built.
        self._validate_head_layout(dim, num_heads)

        self.dim = dim
        self.max_length = max_length
        self.num_heads = num_heads

        # Create mLSTMLayerConfig
        # Note: mLSTMLayerConfig inherits from UpProjConfigMixin which has proj_factor
        config = mLSTMLayerConfig(
            embedding_dim=dim,
            num_heads=num_heads,
            context_length=max_length,
            conv1d_kernel_size=conv1d_kernel_size,
            proj_factor=proj_factor,
            qkv_proj_blocksize=qkv_proj_blocksize,
            dropout=dropout,
            bias=bias,
            round_proj_up_to_multiple_of=round_proj_up_to_multiple_of,
        )

        # Instantiate the mLSTM layer
        self.mlstm = mLSTMLayer(config)

    @staticmethod
    def _validate_head_layout(dim: int, num_heads: int) -> None:
        """
        Check that ``dim`` can be split evenly across ``num_heads`` heads.

        The positivity checks run before the modulo so that ``num_heads == 0``
        raises ValueError instead of ZeroDivisionError, and so that negative
        head counts (for which ``dim % num_heads`` can still be 0) and
        ``dim == 0`` (which is divisible by anything) are rejected.

        Raises:
            ValueError: If either value is not positive or ``dim`` is not
                divisible by ``num_heads``.
        """
        if dim < 1:
            raise ValueError(f"dim must be a positive integer, got {dim}")
        if num_heads < 1:
            raise ValueError(
                f"num_heads must be a positive integer, got {num_heads}"
            )
        # Validate that dim is divisible by num_heads for headwise operations
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the mLSTM layer.

        Args:
            x: Input tensor of shape (B, L, dim)

        Returns:
            Output tensor of shape (B, L, dim)
        """
        return self.mlstm(x)

    def step(self, x: torch.Tensor, mlstm_state=None, conv_state=None):
        """
        Single-step forward for autoregressive generation.

        Delegates directly to the wrapped layer's ``step`` method.

        Args:
            x: Input tensor of shape (B, 1, dim)
            mlstm_state: Previous mLSTM state (optional)
            conv_state: Previous convolution state (optional)

        Returns:
            Tuple of (output, state_dict) where state_dict contains
            'mlstm_state' and 'conv_state' for the next step.
        """
        return self.mlstm.step(x, mlstm_state=mlstm_state, conv_state=conv_state)


if __name__ == '__main__':
    # Smoke test: build the wrapper with one and with four heads and confirm
    # that the output shape equals the input shape. Skipped when xlstm is
    # not installed.
    if XLSTM_AVAILABLE:
        # bfloat16 on GPU, float32 on CPU.
        on_gpu = torch.cuda.is_available()
        device = 'cuda' if on_gpu else 'cpu'
        dtype = torch.bfloat16 if on_gpu else torch.float32

        # One random input shared by every configuration under test.
        batch, seq_len, width = 2, 128, 128
        sample = torch.rand(batch, seq_len, width, dtype=dtype, device=device)

        cases = (
            ("Testing single-head mLSTM...", 1),
            ("Testing multi-head mLSTM...", 4),
        )
        for banner, heads in cases:
            print(banner)
            layer = mLSTMWrapper(dim=width, num_heads=heads, max_length=seq_len)
            layer = layer.to(device).to(dtype)
            out = layer(sample)
            assert sample.shape == out.shape, f"Shape mismatch: {sample.shape} vs {out.shape}"
            print(f"  Input shape: {sample.shape}, Output shape: {out.shape}")

        print("All tests passed!")
    else:
        print("xlstm package not available, skipping tests")
