import torch
from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig,
)

class XLSTMModel:
    """Build an ``xLSTMBlockStack`` from a compact set of options and run it.

    The stack is configured with one mLSTM block config and one sLSTM block
    config; ``slstm_at`` selects which block indices use the sLSTM config.
    Instances are callable: ``model(x)`` is the same as ``model.forward(x)``.

    Args:
        context_length: Passed to ``xLSTMBlockStackConfig``.
        embedding_dim: Passed to ``xLSTMBlockStackConfig``.
        num_blocks: Total number of blocks in the stack.
        num_heads: Head count, shared by the mLSTM and sLSTM layer configs.
        conv1d_kernel_size: Shared by the mLSTM and sLSTM layer configs.
        qkv_proj_blocksize: Passed to ``mLSTMLayerConfig``.
        proj_factor: Projection factor of the sLSTM block's feed-forward.
        slstm_at: Block indices that use the sLSTM config. Defaults to
            ``[1]``; a fresh list is created for every instance.
        device: Torch device (string or ``torch.device``) that the stack and
            every input are moved to.
        backend: sLSTM kernel backend name. When ``None`` (default) it is
            derived from ``device``: ``"cuda"`` for CUDA devices,
            ``"vanilla"`` otherwise.
    """

    def __init__(self,
                 context_length=256,
                 embedding_dim=128,
                 num_blocks=7,
                 num_heads=4,
                 conv1d_kernel_size=4,
                 qkv_proj_blocksize=4,
                 proj_factor=1.3,
                 slstm_at=None,
                 device="cuda",
                 backend=None):
        # A list literal default would be one object shared by all instances
        # (and stored inside each config); build a fresh one per call instead.
        # The value is deliberately not copied so non-list values the library
        # accepts are passed through untouched.
        if slstm_at is None:
            slstm_at = [1]
        # The sLSTM layer config takes a kernel backend name ("cuda" or
        # "vanilla"), not a device string, so "cpu" or "cuda:0" must be mapped.
        if backend is None:
            backend = self._backend_for(device)

        self.device = device
        self.backend = backend

        self.cfg = xLSTMBlockStackConfig(
            mlstm_block=mLSTMBlockConfig(
                mlstm=mLSTMLayerConfig(
                    conv1d_kernel_size=conv1d_kernel_size,
                    qkv_proj_blocksize=qkv_proj_blocksize,
                    num_heads=num_heads
                )
            ),
            slstm_block=sLSTMBlockConfig(
                slstm=sLSTMLayerConfig(
                    backend=backend,
                    num_heads=num_heads,
                    conv1d_kernel_size=conv1d_kernel_size,
                    bias_init="powerlaw_blockdependent",
                ),
                feedforward=FeedForwardConfig(
                    proj_factor=proj_factor,
                    act_fn="gelu"
                ),
            ),
            context_length=context_length,
            num_blocks=num_blocks,
            embedding_dim=embedding_dim,
            slstm_at=slstm_at,
        )

        self.xlstm_stack = xLSTMBlockStack(self.cfg).to(self.device)

    @staticmethod
    def _backend_for(device):
        """Return the sLSTM backend name matching ``device``'s type."""
        return "cuda" if torch.device(device).type == "cuda" else "vanilla"

    def forward(self, x):
        """Move ``x`` to the model's device and run it through the stack.

        Non-tensor inputs are converted with ``torch.tensor`` first. The
        expected input layout is presumably (batch, context_length,
        embedding_dim) -- TODO confirm against ``xLSTMBlockStack``.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x)
        x = x.to(self.device)
        return self.xlstm_stack(x)

    def __call__(self, x):
        """Alias for :meth:`forward` so the wrapper is callable like a module."""
        return self.forward(x)

if __name__ == "__main__":
    # Smoke test: push one random batch of shape (4, 256, 128) through the
    # stack and report whether the output shape equals the input shape.
    demo_device = "cuda"
    demo_kwargs = dict(
        context_length=256,
        embedding_dim=128,
        num_blocks=7,
        num_heads=4,
        conv1d_kernel_size=4,
        qkv_proj_blocksize=4,
        proj_factor=1.3,
        slstm_at=[1],
        device=demo_device,
    )
    model = XLSTMModel(**demo_kwargs)

    batch = 4
    seq_len = demo_kwargs["context_length"]
    width = demo_kwargs["embedding_dim"]
    inputs = torch.randn(batch, seq_len, width).to(demo_device)
    outputs = model(inputs)

    shapes_match = outputs.shape == inputs.shape
    print(outputs.shape)
    print('Output shape matches input shape:', shapes_match)
