#!/usr/bin/env python3

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.


import torch
import torch.nn as nn
from timm.models.registry import register_model
import math
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
from timm.models._builder import resolve_pretrained_cfg
try:
    from timm.models._builder import _update_default_kwargs as update_args
except:
    from timm.models._builder import _update_default_model_kwargs as update_args
from timm.models.vision_transformer import Mlp, PatchEmbed
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_fn_orth, selective_scan_online_orth_fn, selective_scan_online7_fn
from einops import rearrange, repeat
from .registry import register_pip_model
from pathlib import Path


def _cfg(url='', **kwargs):
    return {'url': url,
            'num_classes': 1000,
            'input_size': (3, 224, 224),
            'pool_size': None,
            'crop_pct': 0.875,
            'interpolation': 'bicubic',
            'fixed_input_size': True,
            'mean': (0.485, 0.456, 0.406),
            'std': (0.229, 0.224, 0.225),
            **kwargs
            }


# Pretrained-weight configurations, keyed by model variant name. Each entry
# starts from the `_cfg` defaults and overrides the checkpoint URL, `crop_pct`
# and `input_size`; `crop_mode` is not one of `_cfg`'s defaults and is added
# through its **kwargs. Entries whose key ends in `_21k` point at the
# MambaVision-*-21K checkpoints on the Hugging Face hub; the 512x512 variants
# use crop_mode='squash', all others 'center'.
default_cfgs = {
    'mamba_vision_T': _cfg(url='https://huggingface.co/nvidia/MambaVision-T-1K/resolve/main/mambavision_tiny_1k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_T2': _cfg(url='https://huggingface.co/nvidia/MambaVision-T2-1K/resolve/main/mambavision_tiny2_1k.pth.tar',
                            crop_pct=0.98,
                            input_size=(3, 224, 224),
                            crop_mode='center'),
    'mamba_vision_S': _cfg(url='https://huggingface.co/nvidia/MambaVision-S-1K/resolve/main/mambavision_small_1k.pth.tar',
                           crop_pct=0.93,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_B': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-1K/resolve/main/mambavision_base_1k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_B_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-B-21K/resolve/main/mambavision_base_21k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_L': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-1K/resolve/main/mambavision_large_1k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_L_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-L-21K/resolve/main/mambavision_large_21k.pth.tar',
                           crop_pct=1.0,
                           input_size=(3, 224, 224),
                           crop_mode='center'),
    'mamba_vision_L2': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-1K/resolve/main/mambavision_large2_1k.pth.tar',
                            crop_pct=1.0,
                            input_size=(3, 224, 224),
                            crop_mode='center'),
    'mamba_vision_L2_512_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-L2-512-21K/resolve/main/mambavision_L2_21k_240m_512.pth.tar',
                            crop_pct=0.93,
                            input_size=(3, 512, 512),
                            crop_mode='squash'),
    'mamba_vision_L3_256_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-L3-256-21K/resolve/main/mambavision_L3_21k_740m_256.pth.tar',
                            crop_pct=1.0,
                            input_size=(3, 256, 256),
                            crop_mode='center'),
    'mamba_vision_L3_512_21k': _cfg(url='https://huggingface.co/nvidia/MambaVision-L3-512-21K/resolve/main/mambavision_L3_21k_740m_512.pth.tar',
                            crop_pct=0.93,
                            input_size=(3, 512, 512),
                            crop_mode='squash'),
}


def window_partition(x, window_size):
    """
    Split a channels-first feature map into non-overlapping square windows.

    Args:
        x: feature map of shape (B, C, H, W). H and W must both be divisible
            by ``window_size`` (otherwise the internal ``view`` fails).
        window_size: side length of each square window.

    Returns:
        Tensor of shape (B * num_windows, window_size * window_size, C), where
        num_windows = (H // window_size) * (W // window_size). Windows are
        ordered row-major over the window grid within each batch element, and
        the pixels inside each window are ordered row-major as well.
    """
    B, C, H, W = x.shape
    # Split H and W into (num_windows_along_axis, window_size) pairs.
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    # -> (B, H//ws, W//ws, ws, ws, C), then flatten windows and in-window pixels.
    windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size * window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Reassemble square windows into a channels-first feature map.

    This is the inverse of ``window_partition``.

    Args:
        windows: window features of shape
            (B * num_windows, window_size * window_size, C), i.e. the output
            of ``window_partition``. The equivalent 4-D layout
            (B * num_windows, window_size, window_size, C) is accepted too.
        window_size: side length of each square window.
        H: height of the reconstructed feature map (divisible by window_size).
        W: width of the reconstructed feature map (divisible by window_size).

    Returns:
        x: (B, C, H, W)
    """
    windows_per_image = (H // window_size) * (W // window_size)
    # Integer arithmetic: exact for any size (the previous float division
    # relied on the quotient being representable exactly).
    B = windows.shape[0] // windows_per_image
    # Channels are always the last axis, for both the 3-D and 4-D layouts.
    C = windows.shape[-1]
    x = windows.reshape(B, H // window_size, W // window_size, window_size, window_size, C)
    # (B, Hn, Wn, ws_h, ws_w, C) -> (B, C, Hn, ws_h, Wn, ws_w) -> (B, C, H, W)
    x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, C, H, W)
    return x


def _load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata
    
    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    
    if len(err_msg) > 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def _load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load a checkpoint file into ``model``.

    The weights are taken from ``checkpoint['state_dict']``, else
    ``checkpoint['model']``, else the checkpoint dict itself. A leading
    ``module.`` prefix (DataParallel / DistributedDataParallel) is stripped.
    If the alphabetically first key starts with ``encoder``, only keys under
    ``encoder.`` are kept and that leading prefix is removed.

    Args:
        model (Module): Module to load checkpoint.
        filename (str or file-like): Anything accepted by :func:`torch.load`
            (a local path or an open file object).
        map_location (str): Same as :func:`torch.load`.
        strict (bool): If True, raise on any key mismatch between the model
            and the checkpoint; if False, only report mismatches.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint (unmodified).

    Raises:
        RuntimeError: if the file does not contain a dict, or if ``strict``
            is True and the keys do not match.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(filename, map_location=map_location)
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    # Guard against an empty state dict (indexing its first key would raise).
    wrapper_prefix = 'module.'
    if state_dict and next(iter(state_dict)).startswith(wrapper_prefix):
        # Only strip the prefix from keys that actually carry it.
        state_dict = {
            (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v
            for k, v in state_dict.items()
        }

    encoder_prefix = 'encoder.'
    if state_dict and min(state_dict).startswith('encoder'):
        # Remove only the leading prefix; str.replace would also remove any
        # nested 'encoder.' segment inside the key.
        state_dict = {k[len(encoder_prefix):]: v for k, v in state_dict.items()
                      if k.startswith(encoder_prefix)}

    _load_state_dict(model, state_dict, strict, logger)
    return checkpoint


class Downsample(nn.Module):
    """
    Down-sampling block: one bias-free 3x3 convolution with stride 2 and
    padding 1. It halves the spatial resolution (ceil(H/2) x ceil(W/2)) and,
    unless ``keep_dim`` is set, doubles the channel count.
    """

    def __init__(self,
                 dim,
                 keep_dim=False,
                 ):
        """
        Args:
            dim: number of input channels.
            keep_dim: if True, the output keeps ``dim`` channels; otherwise it
                has ``2 * dim`` channels. The spatial resolution is halved in
                both cases.
        """
        super().__init__()
        dim_out = dim if keep_dim else 2 * dim
        # Wrapped in nn.Sequential so the parameter is named
        # ``reduction.0.weight`` (presumably what saved checkpoints expect).
        self.reduction = nn.Sequential(
            nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False),
        )

    def forward(self, x):
        """Map (B, dim, H, W) to (B, dim_out, ceil(H/2), ceil(W/2))."""
        return self.reduction(x)


class PatchEmbed(nn.Module):
    """
    Convolutional patch-embedding stem: two (3x3 stride-2 conv -> BatchNorm ->
    ReLU) stages, reducing spatial resolution by 4x overall.
    """

    def __init__(self, in_chans=3, in_dim=64, dim=96):
        """
        Args:
            in_chans: number of input image channels.
            in_dim: channel width after the first stride-2 stage.
            dim: output feature (embedding) dimension.
        """
        super().__init__()
        # Identity placeholder; keeps the ``proj`` -> ``conv_down`` layout.
        self.proj = nn.Identity()
        stages = []
        for c_in, c_out in ((in_chans, in_dim), (in_dim, dim)):
            stages += [
                nn.Conv2d(c_in, c_out, 3, 2, 1, bias=False),
                nn.BatchNorm2d(c_out, eps=1e-4),
                nn.ReLU(),
            ]
        # Same submodule indices (0..5) as a hand-written Sequential.
        self.conv_down = nn.Sequential(*stages)

    def forward(self, x):
        """Map (B, in_chans, H, W) to (B, dim, H/4, W/4) (rounded up per stage)."""
        return self.conv_down(self.proj(x))


class ConvBlock(nn.Module):
    """
    Residual convolutional block.

    The residual branch is conv -> BatchNorm -> GELU(tanh) -> conv -> BatchNorm,
    optionally scaled per channel by a learnable ``gamma`` (layer scale). It is
    added back to the input through DropPath when ``drop_path > 0``.
    """

    def __init__(self, dim,
                 drop_path=0.,
                 layer_scale=None,
                 kernel_size=3):
        """
        Args:
            dim: number of channels (input and output).
            drop_path: stochastic-depth rate; 0 disables DropPath.
            layer_scale: initial value of the per-channel ``gamma``. Layer
                scale is enabled only for a (non-bool) int or float; any other
                value (e.g. None) disables it.
            kernel_size: kernel size of both convolutions. Padding is
                ``kernel_size // 2``, so odd kernel sizes preserve H and W
                (previously padding was fixed at 1, which broke the residual
                add for any kernel_size other than 3).
        """
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=padding)
        self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
        self.act1 = nn.GELU(approximate='tanh')
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=padding)
        self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
        # bool is excluded to keep the original ``type(x) in [int, float]``
        # semantics (True/False never enabled layer scale).
        self.layer_scale = (isinstance(layer_scale, (int, float))
                            and not isinstance(layer_scale, bool))
        if self.layer_scale:
            self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Apply the residual block; the output has the same shape as ``x``."""
        shortcut = x
        x = self.act1(self.norm1(self.conv1(x)))
        x = self.norm2(self.conv2(x))
        if self.layer_scale:
            # gamma is per-channel; broadcast over (B, C, H, W).
            x = x * self.gamma.view(1, -1, 1, 1)
        return shortcut + self.drop_path(x)


class MambaVisionMixer(nn.Module):
    """
    MambaVision token mixer: a selective-scan (SSM) branch and a convolutional
    branch, concatenated and projected back to ``d_model``.

    ``in_proj`` maps the ``d_model`` features to ``d_inner = int(expand * d_model)``
    channels, which are split into two halves of ``d_inner // 2`` channels:
      - x: depthwise conv1d -> SiLU -> ``selective_scan_fn``
      - z: depthwise conv1d -> SiLU (no scan)
    The halves are concatenated along channels and mapped back to ``d_model``
    by ``out_proj``.

    Args:
        d_model: input/output feature dimension.
        d_state: SSM state size (width of B and C, and columns of A).
        d_conv: kernel size of the depthwise 1-D convolutions.
        expand: expansion factor for ``d_inner``.
        dt_rank: rank of the delta projection; "auto" -> ceil(d_model / 16).
        dt_min, dt_max: range of the log-uniformly sampled initial deltas used
            to initialize ``dt_proj.bias``.
        dt_init: "constant" or "random" (uniform) init of ``dt_proj.weight``.
        dt_scale: multiplier on the ``dt_proj.weight`` init std.
        dt_init_floor: lower clamp for the sampled initial deltas.
        conv_bias: passed as ``conv_bias // 2`` to the conv layers (see note below).
        bias: whether ``in_proj`` / ``out_proj`` have a bias.
        use_fast_path: stored on the module; not read inside this class.
        layer_idx: stored on the module; not read inside this class.
        device, dtype: forwarded via ``factory_kwargs`` to the Linear/Conv1d
            layers (``A_log`` and ``D`` receive only ``device``).
    """
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        # NOTE: module creation order below also fixes the RNG consumption
        # order of the default initializers; reordering changes random init.
        self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
        # Projects the scan input (d_inner//2 channels) to [dt | B | C].
        self.x_proj = nn.Linear(
            self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner//2, bias=True, **factory_kwargs)
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        # Sample initial deltas log-uniformly in [dt_min, dt_max], floored.
        dt = torch.exp(
            torch.rand(self.d_inner//2, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # log(exp(dt) - 1): inverse softplus, so softplus(bias) == dt.
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Marker attribute; presumably read by an external (re)init routine.
        self.dt_proj.bias._no_reinit = True
        # Each of the d_inner//2 channels gets A row [1, 2, ..., d_state];
        # stored as log so that A = -exp(A_log) stays negative in forward.
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner//2,
        ).contiguous()
        A_log = torch.log(A)
        self.A_log = nn.Parameter(A_log)
        # Marker attribute; presumably read by the optimizer param-group setup.
        self.A_log._no_weight_decay = True
        self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device))
        self.D._no_weight_decay = True
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        # NOTE(review): ``conv_bias//2`` evaluates to 0 for both True and
        # False, so both convolutions are created WITHOUT a bias. Presumably
        # kept to match the released checkpoints; confirm before changing.
        # These modules only hold the weights; forward calls F.conv1d itself.
        self.conv1d_x = nn.Conv1d(
            in_channels=self.d_inner//2,
            out_channels=self.d_inner//2,
            bias=conv_bias//2,
            kernel_size=d_conv,
            groups=self.d_inner//2,
            **factory_kwargs,
        )
        self.conv1d_z = nn.Conv1d(
            in_channels=self.d_inner//2,
            out_channels=self.d_inner//2,
            bias=conv_bias//2,
            kernel_size=d_conv,
            groups=self.d_inner//2,
            **factory_kwargs,
        )

    def forward(self, hidden_states):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states

        The x half goes through the selective scan; the z half is only
        convolved and activated, then concatenated with the scan output.
        """
        _, seqlen, _ = hidden_states.shape
        xz = self.in_proj(hidden_states)
        xz = rearrange(xz, "b l d -> b d l")
        x, z = xz.chunk(2, dim=1)
        A = -torch.exp(self.A_log.float())
        # Depthwise, non-causal ('same') convolution along the sequence axis.
        x = F.silu(F.conv1d(input=x, weight=self.conv1d_x.weight, bias=self.conv1d_x.bias, padding='same', groups=self.d_inner//2))
        z = F.silu(F.conv1d(input=z, weight=self.conv1d_z.weight, bias=self.conv1d_z.bias, padding='same', groups=self.d_inner//2))
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))
        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # NOTE(review): self.dt_proj(dt) already adds dt_proj.bias, and the
        # same bias is passed again as delta_bias below; presumably applied a
        # second time inside the kernel. Matches upstream behaviour that the
        # pretrained weights were trained with -- verify before "fixing".
        dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        # z=None: no gating inside the scan; z is concatenated afterwards.
        y = selective_scan_fn(x,
                              dt,
                              A,
                              B,
                              C,
                              self.D.float(),
                              z=None,
                              delta_bias=self.dt_proj.bias.float(),
                              delta_softplus=True,
                              return_last_state=None)

        y = torch.cat([y, z], dim=1)
        y = rearrange(y, "b d l -> b l d")
        out = self.out_proj(y)
        return out
    


class MuonLonghornVisionMixer(nn.Module):
    """
    MuonLonghorn Vision Mixer: Longhorn SSM with momentum and Newton-Schulz for vision tasks.
    
    This module keeps the MambaVisionMixer architecture (split x/z, dual conv1d, concatenation)
    but replaces the traditional Mamba SSM core with MuonLonghorn's SSM:
    - Uses Q, K projections instead of B, C (attention-like SSM)
    - Adds momentum velocity state: v_t = β * v_{t-1} + α * input
    - Optionally applies Newton-Schulz orthogonalization for stability
    
    Architecture:
        Input: (B, L, D)
        └── in_proj → (B, L, d_inner)
            └── split into x, z each (B, L, d_inner//2)
                ├── x: conv1d_x → SiLU → MuonLonghorn SSM → y
                └── z: conv1d_z → SiLU (gating branch)
            └── concat(y, z) → out_proj → (B, L, D)
    
    Args:
        d_model: Model dimension
        d_state: SSM state dimension (default: 16)
        d_conv: Convolution kernel size (default: 4)
        expand: Expansion factor for inner dimension (default: 2)
        dt_rank: Rank of delta projection, 'auto' = ceil(d_model / 16)
        dt_min: Minimum delta value for initialization
        dt_max: Maximum delta value for initialization
        dt_init: Delta initialization mode ('random' or 'constant')
        dt_scale: Delta scale factor
        dt_init_floor: Floor for delta initialization
        conv_bias: Whether to use bias in convolution
        bias: Whether to use bias in linear layers
        use_fast_path: Stored on the module; not read inside this class (placeholder)
        layer_idx: Stored on the module; not read inside this class
        
        # MuonLonghorn-specific parameters:
        beta: Velocity decay factor (momentum), β ∈ [0, 1] (default: 0.9)
              - β = 0: No momentum (equivalent to vanilla Longhorn)
              - β = 0.9: High momentum, smooth state updates
        alpha: Velocity scale factor, α > 0 (default: 0.6)
        use_newton_schulz: Whether to apply Newton-Schulz orthogonalization (default: True).
              True calls selective_scan_online_orth_fn; False calls
              selective_scan_online7_fn, which is not given ns_steps / ns_mode.
        ns_steps: Number of Newton-Schulz iterations (default: 1)
        ns_mode: Newton-Schulz mode - 'compile' or 'triton' (default: 'compile')
    
    Example:
        >>> mixer = MuonLonghornVisionMixer(
        ...     d_model=256,
        ...     d_state=16,
        ...     beta=0.9,
        ...     alpha=1.0,
        ...     use_newton_schulz=True,
        ...     device='cuda',
        ... )
        >>> x = torch.randn(2, 196, 256, device='cuda')  # (batch, seq_len, dim)
        >>> y = mixer(x)  # (2, 196, 256)
    """
    
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,
        layer_idx=None,
        # MuonLonghorn-specific parameters
        beta=0.9,  # momentum enabled by default (0.0 disables momentum)
        alpha=0.6,  # velocity scale factor
        use_newton_schulz=True,  # Newton-Schulz orthogonalization enabled by default
        ns_steps=1,
        ns_mode='compile',
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        
        # Core dimensions
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.d_ssm = self.d_inner // 2  # SSM operates on half the inner dimension
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        
        # MuonLonghorn-specific: momentum and Newton-Schulz parameters
        self.beta = beta
        self.alpha = alpha
        self.use_newton_schulz = use_newton_schulz
        self.ns_steps = ns_steps
        self.ns_mode = ns_mode
        
        # NOTE: module creation order below also fixes the RNG consumption
        # order of the default initializers; reordering changes random init.
        # Input projection: d_model -> d_inner (split into x and z)
        self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
        
        # MuonLonghorn SSM: x_proj outputs dt_rank + 2*d_state (for dt, K, Q)
        # Note: Longhorn uses K, Q instead of B, C
        self.x_proj = nn.Linear(
            self.d_ssm, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        
        # Delta (time step) projection: dt_rank -> d_ssm
        self.dt_proj = nn.Linear(self.dt_rank, self.d_ssm, bias=True, **factory_kwargs)
        
        # Initialize dt_proj weights
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError(f"dt_init must be 'constant' or 'random', got {dt_init}")
        
        # Initialize dt_proj bias: sample deltas log-uniformly in
        # [dt_min, dt_max] (floored), then store their inverse softplus
        # log(exp(dt) - 1) so that softplus(bias) == dt.
        dt = torch.exp(
            torch.rand(self.d_ssm, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Marker attribute; presumably read by an external (re)init routine.
        self.dt_proj.bias._no_reinit = True
        
        # D "skip" parameter (Longhorn-style)
        self.D = nn.Parameter(torch.ones(self.d_ssm, device=device))
        self.D._no_weight_decay = True
        
        # Output projection: d_inner -> d_model
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        
        # Dual convolutions (MambaVision-style: separate conv for x and z branches).
        # Only their weight/bias are used: forward calls F.conv1d with
        # padding='same', so the module-level `padding` set here is unused.
        self.conv1d_x = nn.Conv1d(
            in_channels=self.d_ssm,
            out_channels=self.d_ssm,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_ssm,
            padding=d_conv - 1,  # unused by forward (see note above)
            **factory_kwargs,
        )
        self.conv1d_z = nn.Conv1d(
            in_channels=self.d_ssm,
            out_channels=self.d_ssm,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_ssm,
            padding=d_conv - 1,  # unused by forward (see note above)
            **factory_kwargs,
        )

    def forward(self, hidden_states, inference_params=None):
        """
        Forward pass through MuonLonghornVisionMixer.
        
        Args:
            hidden_states: (B, L, D) input tensor
            inference_params: Accepted for API compatibility; not used here.
        
        Returns:
            (B, L, D) output tensor, cast back to the dtype of hidden_states
            before out_proj.
        """
        # Store input dtype for mixed precision compatibility
        input_dtype = hidden_states.dtype
        batch, seqlen, _ = hidden_states.shape
        
        # Input projection and split into x (SSM branch) and z (gating branch)
        xz = self.in_proj(hidden_states)
        xz = rearrange(xz, "b l d -> b d l").contiguous()
        x, z = xz.chunk(2, dim=1)  # Each: (B, d_ssm, L)
        
        # Apply depthwise convolutions with SiLU activation.
        # 'same' padding keeps the sequence length (non-causal).
        x = F.silu(F.conv1d(
            input=x.contiguous(),
            weight=self.conv1d_x.weight,
            bias=self.conv1d_x.bias,
            padding='same',
            groups=self.d_ssm
        ))
        z = F.silu(F.conv1d(
            input=z.contiguous(),
            weight=self.conv1d_z.weight,
            bias=self.conv1d_z.bias,
            padding='same',
            groups=self.d_ssm
        ))
        
        # For SSM operations, we need float32 for numerical stability
        # CUDA kernel requires all inputs to have the same dtype
        x_float = x.float().contiguous()
        
        # Project x to get dt, K, Q (Longhorn-style: K and Q instead of B and C)
        x_dbl = self.x_proj(rearrange(x_float, "b d l -> (b l) d"))  # (B*L, dt_rank + 2*d_state)
        dt, k, q = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        
        # Project dt with dt_proj's weight only; its bias is passed to the
        # scan separately as t_bias.
        dt = self.dt_proj.weight @ dt.t()  # (d_ssm, B*L)
        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen).contiguous()
        
        # Rearrange K and Q for SSM
        k = rearrange(k, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        q = rearrange(q, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        
        # CRITICAL: Ensure all tensors have same dtype (float32) for CUDA kernel
        # The CUDA kernel requires: u.dtype == Q.dtype == K.dtype == T.dtype
        dt = dt.float().contiguous()
        k = k.float().contiguous()
        q = q.float().contiguous()
        
        # MuonLonghorn SSM: select appropriate scan function based on Newton-Schulz setting
        # SSM operations in float32 for numerical stability
        if self.use_newton_schulz:
            y = selective_scan_online_orth_fn(
                x_float,
                q,
                k,
                dt,
                D=self.D.float(),
                t_bias=self.dt_proj.bias.float(),
                z=None,  # z is concatenated later (MambaVision-style)
                return_last_state=False,
                beta=self.beta,
                alpha=self.alpha,
                ns_steps=self.ns_steps,
                ns_mode=self.ns_mode,
            )
        else:
            y = selective_scan_online7_fn(
                x_float,
                q,
                k,
                dt,
                D=self.D.float(),
                t_bias=self.dt_proj.bias.float(),
                z=None,  # z is concatenated later (MambaVision-style)
                return_last_state=False,
                beta=self.beta,
                alpha=self.alpha,
            )
        
        # Convert back to input dtype for mixed precision compatibility
        y = y.to(dtype=input_dtype).contiguous()
        z = z.to(dtype=input_dtype).contiguous()
        
        # MambaVision-style: concatenate SSM output with gating branch
        y = torch.cat([y, z], dim=1)  # (B, d_inner, L)
        y = rearrange(y, "b d l -> b l d").contiguous()
        
        # Output projection
        out = self.out_proj(y)
        
        # Ensure output is contiguous and has correct dtype for downstream operations
        return out.contiguous()
    
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """
        Allocate zero-initialized inference-cache tensors.
        
        ``max_seqlen`` is accepted but not used. ``forward`` in this class
        does not read these tensors.
        
        Returns:
            ((conv_state_x, conv_state_z), ssm_state, velocity_state) where the
            conv states are (batch_size, d_ssm, d_conv) and the ssm/velocity
            states are (batch_size, d_ssm, d_state).
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d_x.weight.dtype if dtype is None else dtype
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        
        conv_state_x = torch.zeros(
            batch_size, self.d_ssm, self.d_conv, device=device, dtype=conv_dtype
        )
        conv_state_z = torch.zeros(
            batch_size, self.d_ssm, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_state = torch.zeros(
            batch_size, self.d_ssm, self.d_state, device=device, dtype=ssm_dtype
        )
        # MuonLonghorn: velocity state for momentum
        velocity_state = torch.zeros(
            batch_size, self.d_ssm, self.d_state, device=device, dtype=ssm_dtype
        )
        
        return (conv_state_x, conv_state_z), ssm_state, velocity_state


import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange
from fla.modules import FusedRMSNormSwishGate, RMSNorm
from fla.modules.activations import ACT2FN
from .gated_delta_rule_ops import chunk_gated_delta_rule
try:
    from fla.modules.l2norm import l2_norm as l2_norm_fn
except ImportError:
    from fla.modules.l2norm import l2_norm_fn



class MuonGatedDeltaNetVisionMixer(nn.Module):
    """
    MuonGatedDeltaNet Vision Mixer: gated delta-rule token mixer for vision tasks.

    This module keeps the MambaVisionMixer architecture (split x/z, dual
    conv1d, concatenation). It replaces the Mamba selective scan with a gated
    delta-rule recurrence computed by ``chunk_gated_delta_rule`` (from
    ``.gated_delta_rule_ops``):
    - Projects the x branch to multi-head Q, K, V and an output gate G.
    - Computes a per-head log-decay ``gk`` and a per-head write strength
      ``beta = sigmoid(b_proj(x))``; both are passed to the delta-rule kernel.
    - Supports QK normalization: 'l2', 'longhorn', 'softmax'.
    - Optionally uses Mamba-style gating for ``gk``:
      ``-exp(A_log) * softplus(gk + dt_bias)``. Otherwise it uses
      ``logsigmoid(gk) / gate_logit_normalizer``.
    - Stores momentum hyper-parameters as buffers. ``forward`` does not use
      them yet; they are reserved for a momentum-enhanced delta rule.

    Architecture:
        Input: (B, L, D)
        └── in_proj → (B, L, d_inner)
            └── split into x, z each (B, L, d_inner//2)
                ├── x: conv1d_x → SiLU → Q/K/V/G projections → gated delta rule
                │      → output RMSNorm + gate → ssm_out_proj → y
                └── z: conv1d_z → SiLU (gating branch)
            └── concat(y, z) → out_proj → (B, L, D)

    Args:
        d_model: Model dimension.
        d_state: Stored as ``self.d_state`` for interface parity with the other
            mixers; not used by this module's computation (default: 16).
        d_conv: Convolution kernel size (default: 4).
        expand: Expansion factor for the inner dimension (default: 2).
        expand_k: Key dimension as a multiple of d_ssm (default: 0.75).
        expand_v: Value dimension as a multiple of d_ssm (default: 1.5).
        num_heads: Number of heads. Must be positive and divide both the key
            and value dimensions (default: 4).
        qk_norm: QK normalization mode: 'l2', 'longhorn', or 'softmax'
            (default: 'l2').
        gate_fn: Output gate activation. 'swish' with fuse_norm=True uses the
            fused RMSNorm+swish gate; any other name is looked up in ACT2FN
            (default: 'swish').
        gate_logit_normalizer: Divisor applied to logsigmoid gate logits when
            use_mamba_gate=False (default: 16).
        use_mamba_gate: Whether to use Mamba-style gating for gk (default: True).
        conv_bias: Whether to use bias in the convolutions.
        bias: Whether to use bias in in_proj / out_proj.
        use_fast_path: Stored only; not read by this module.
        layer_idx: Layer index (stored only).
        momentum_beta: Momentum decay buffer, β ∈ [0, 1] (default: 0.9).
            Unused by forward.
        momentum_alpha: Momentum scale buffer, α > 0 (default: 1.0).
            Unused by forward.
        elementwise_affine: Forwarded to the output RMSNorm.
        norm_eps: Epsilon of the output RMSNorm.
        fuse_norm: Use the fused RMSNorm+swish gate when gate_fn == 'swish'.
        use_residual: Add a per-head ``D * v`` skip to the delta-rule output.
        device, dtype: Factory kwargs for the linear / conv layers.

    Raises:
        AssertionError: If qk_norm is not a supported mode.
        ValueError: If num_heads is not positive or does not divide the key
            and value dimensions.

    Example:
        >>> mixer = MuonGatedDeltaNetVisionMixer(
        ...     d_model=256,
        ...     num_heads=4,
        ...     qk_norm='l2',
        ...     use_mamba_gate=True,
        ...     device='cuda',
        ... )
        >>> x = torch.randn(2, 196, 256, device='cuda')  # (batch, seq_len, dim)
        >>> y = mixer(x)  # (2, 196, 256)
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        expand_k=0.75,
        expand_v=1.5,
        num_heads=4,
        qk_norm='l2',
        gate_fn='swish',
        gate_logit_normalizer=16,
        use_mamba_gate=True,
        conv_bias=True,
        bias=False,
        use_fast_path=True,
        layer_idx=None,
        # Momentum parameters
        momentum_beta=0.9,
        momentum_alpha=1.0,
        # Additional Gated Delta Net parameters
        elementwise_affine=True,
        norm_eps=1e-5,
        fuse_norm=True,
        use_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Validate qk_norm
        assert qk_norm in ['l2', 'longhorn', 'softmax'], f"qk_norm must be 'l2', 'longhorn', or 'softmax', got {qk_norm}"

        # Core dimensions
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.d_ssm = self.d_inner // 2  # SSM operates on half the inner dimension
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        # Gated Delta Net specific dimensions
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.qk_norm = qk_norm
        self.gate_logit_normalizer = gate_logit_normalizer
        self.use_mamba_gate = use_mamba_gate
        self.use_residual = use_residual

        # Key and value dimensions based on d_ssm (the SSM branch dimension)
        self.key_dim = int(self.d_ssm * expand_k)
        self.value_dim = int(self.d_ssm * expand_v)
        # forward() splits key_dim / value_dim into num_heads equal heads via
        # rearrange; fail here with a clear message instead of deep in forward.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if self.key_dim % num_heads or self.value_dim % num_heads:
            raise ValueError(
                f"key_dim ({self.key_dim}) and value_dim ({self.value_dim}) "
                f"must both be divisible by num_heads ({num_heads})"
            )
        self.head_qk_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        # Momentum parameters (stored as buffers; not used by forward yet)
        self.register_buffer("momentum_beta", torch.tensor(momentum_beta, dtype=torch.float32, device=device))
        self.register_buffer("momentum_alpha", torch.tensor(momentum_alpha, dtype=torch.float32, device=device))

        # Input projection: d_model -> d_inner (split into x and z)
        self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)

        # Gated Delta Net projections (Q, K, V, G from x branch)
        self.q_proj = nn.Linear(self.d_ssm, self.key_dim, bias=False, **factory_kwargs)
        self.k_proj = nn.Linear(self.d_ssm, self.key_dim, bias=False, **factory_kwargs)
        self.v_proj = nn.Linear(self.d_ssm, self.value_dim, bias=False, **factory_kwargs)
        self.g_proj = nn.Linear(self.d_ssm, self.value_dim, bias=False, **factory_kwargs)

        # Gate key projection (Mamba-style or logsigmoid)
        self.gk_proj = nn.Linear(self.d_ssm, self.num_heads, bias=not use_mamba_gate, **factory_kwargs)

        # Beta projection for delta rule
        self.b_proj = nn.Linear(self.d_ssm, self.num_heads, bias=True, **factory_kwargs)

        # Mamba-style gating parameters
        if use_mamba_gate:
            # NOTE: a sample of exactly 0 would give A_log = -inf; this mirrors
            # the upstream initialization and is left unchanged.
            A = torch.empty(self.num_heads, dtype=torch.float32, device=device).uniform_(0, 16)
            A_log = torch.log(A)
            self.A_log = nn.Parameter(A_log)
            self.A_log._no_weight_decay = True

            self.D = nn.Parameter(torch.ones(self.num_heads, device=device))
            self.D._no_weight_decay = True

            dt_min = 0.001
            dt_max = 0.1
            dt_init_floor = 1e-4
            dt = torch.exp(
                torch.rand(self.num_heads, device=device) * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            ).clamp(min=dt_init_floor)
            # Inverse of softplus, so softplus(dt_bias) == dt at initialization.
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias = nn.Parameter(inv_dt)
            self.dt_bias._no_weight_decay = True

        # Residual D parameter (if use_residual and not already created above)
        if use_residual and not use_mamba_gate:
            self.D = nn.Parameter(torch.ones(self.num_heads, device=device))
            self.D._no_weight_decay = True

        # Output normalization and gating (fla modules imported at module level)
        if gate_fn == 'swish' and fuse_norm:
            self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps)
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
            self.gate_fn = ACT2FN[gate_fn]

        # Output projection from value_dim back to d_ssm
        self.ssm_out_proj = nn.Linear(self.value_dim, self.d_ssm, bias=False, **factory_kwargs)

        # Final output projection: d_inner -> d_model
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

        # Dual convolutions (MambaVision-style: separate conv for x and z branches)
        self.conv1d_x = nn.Conv1d(
            in_channels=self.d_ssm,
            out_channels=self.d_ssm,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_ssm,
            **factory_kwargs,
        )
        self.conv1d_z = nn.Conv1d(
            in_channels=self.d_ssm,
            out_channels=self.d_ssm,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_ssm,
            **factory_kwargs,
        )

    def forward(self, hidden_states):
        """
        Forward pass through MuonGatedDeltaNetVisionMixer.

        Uses ``chunk_gated_delta_rule`` and ``l2_norm_fn`` from the module-level
        imports; there are no per-call imports.

        Args:
            hidden_states: (B, L, D) input tensor

        Returns:
            (B, L, D) output tensor
        """
        # Input projection and split into x (SSM branch) and z (gating branch)
        xz = self.in_proj(hidden_states)
        xz = rearrange(xz, "b l d -> b d l")
        x, z = xz.chunk(2, dim=1)  # Each: (B, d_ssm, L)

        # Apply convolutions with SiLU activation
        x = F.silu(F.conv1d(
            input=x,
            weight=self.conv1d_x.weight,
            bias=self.conv1d_x.bias,
            padding='same',
            groups=self.d_ssm
        ))
        z = F.silu(F.conv1d(
            input=z,
            weight=self.conv1d_z.weight,
            bias=self.conv1d_z.bias,
            padding='same',
            groups=self.d_ssm
        ))

        # Rearrange x for Gated Delta Net projections: (B, d_ssm, L) -> (B, L, d_ssm)
        x = rearrange(x, "b d l -> b l d")

        # Project to Q, K, V
        q = self.q_proj(x)  # (B, L, key_dim)
        k = self.k_proj(x)  # (B, L, key_dim)
        v = self.v_proj(x)  # (B, L, value_dim)

        # Gate key projection (computed in float32)
        gk = self.gk_proj(x).float()  # (B, L, num_heads)
        if self.use_mamba_gate:
            gk = -self.A_log.float().exp() * F.softplus(gk + self.dt_bias)
        else:
            gk = F.logsigmoid(gk) / self.gate_logit_normalizer
        gk = gk.transpose(1, 2)  # (B, num_heads, L)

        # Beta projection for delta rule
        beta = self.b_proj(x).float().sigmoid()  # (B, L, num_heads)
        beta = beta.transpose(1, 2)  # (B, num_heads, L)

        # Rearrange Q, K, V for multi-head processing
        q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads)  # (B, H, L, head_qk_dim)
        k = rearrange(k, 'b l (h d) -> b h l d', h=self.num_heads)  # (B, H, L, head_qk_dim)
        v = rearrange(v, 'b l (h d) -> b h l d', h=self.num_heads)  # (B, H, L, head_v_dim)

        # Apply QK normalization
        if self.qk_norm == 'l2':
            q = l2_norm_fn(q).to(v)
            k = l2_norm_fn(k).to(v)
        elif self.qk_norm == 'softmax':
            q = q.softmax(dim=-1).to(v)
            k = k.softmax(dim=-1).to(v)
        elif self.qk_norm == 'longhorn':
            # Longhorn-style: shrink beta by the squared key norm instead of normalizing q/k
            beta = beta / (1 + beta * (k * k).sum(-1))

        # Apply Gated Delta Rule SSM
        o, _ = chunk_gated_delta_rule(
            q, k, v, beta, gk,
            initial_state=None,
            output_final_state=False
        )  # o: (B, H, L, head_v_dim)

        # Add residual connection if enabled
        if self.use_residual:
            o = o + self.D[None, :, None, None] * v

        # Rearrange for output normalization
        o = rearrange(o, 'b h l d -> b l h d')  # (B, L, H, head_v_dim)

        # Apply output gating with normalization
        g = self.g_proj(x)  # (B, L, value_dim)
        if self.fuse_norm_and_gate:
            g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads)
            o = self.g_norm_swish_gate(o, g)
            o = rearrange(o, 'b l h d -> b l (h d)')
        else:
            o = rearrange(self.g_norm(o), 'b l h d -> b l (h d)')
            o = o * self.gate_fn(g)

        # Project back to d_ssm dimension
        y = self.ssm_out_proj(o)  # (B, L, d_ssm)

        # Rearrange y back to (B, d_ssm, L) for concatenation
        y = rearrange(y, "b l d -> b d l")

        # MambaVision-style: concatenate SSM output with gating branch
        y = torch.cat([y, z], dim=1)  # (B, d_inner, L)
        y = rearrange(y, "b d l -> b l d")

        # Output projection
        out = self.out_proj(y)
        return out

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """
        Allocate zero-filled state buffers for step-wise inference.

        ``max_seqlen`` and extra kwargs are accepted but unused.

        Returns:
            ((conv_state_x, conv_state_z), ssm_state). The conv buffers have
            shape (B, d_ssm, d_conv). ssm_state has shape
            (B, num_heads, head_qk_dim, head_v_dim).
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d_x.weight.dtype if dtype is None else dtype
        ssm_dtype = self.q_proj.weight.dtype if dtype is None else dtype

        conv_state_x = torch.zeros(
            batch_size, self.d_ssm, self.d_conv, device=device, dtype=conv_dtype
        )
        conv_state_z = torch.zeros(
            batch_size, self.d_ssm, self.d_conv, device=device, dtype=conv_dtype
        )
        # SSM state for Gated Delta Net: (B, num_heads, head_qk_dim, head_v_dim)
        ssm_state = torch.zeros(
            batch_size, self.num_heads, self.head_qk_dim, self.head_v_dim,
            device=device, dtype=ssm_dtype
        )

        return (conv_state_x, conv_state_z), ssm_state



class MuonMambaVisionMixer(nn.Module):
    """
    MuonMamba Vision Mixer: Mamba SSM with momentum / Newton-Schulz options.

    This module keeps the MambaVisionMixer architecture (split x/z, dual
    conv1d, concatenation). It routes the x branch through a selective-scan
    kernel that also receives the momentum hyper-parameters ``beta`` and
    ``alpha``:
    - Uses the usual Mamba A, B, C, D parameters and a low-rank dt projection.
    - With ``use_newton_schulz=True`` it calls ``selective_scan_fn_orth``
      (with ``ns_steps``); otherwise it calls ``selective_scan_fn``. Both get
      ``beta`` / ``alpha``.
    - The momentum recurrence itself lives in the scan kernels (a mamba_ssm
      fork). It presumably is v_t = β * v_{t-1} + α * B_t * x_t; confirm
      against the kernel.

    Architecture:
        Input: (B, L, D)
        └── in_proj → (B, L, d_inner)
            └── split into x, z each (B, L, d_inner//2)
                ├── x: conv1d_x → SiLU → selective scan (A,B,C,D + momentum) → y
                └── z: conv1d_z → SiLU (gating branch)
            └── concat(y, z) → out_proj → (B, L, D)

    Args:
        d_model: Model dimension.
        d_state: SSM state dimension (default: 16).
        d_conv: Convolution kernel size (default: 4).
        expand: Expansion factor for the inner dimension (default: 2).
        dt_rank: Rank of the delta projection; 'auto' = ceil(d_model / 16).
        dt_min: Minimum delta value for initialization.
        dt_max: Maximum delta value for initialization.
        dt_init: Delta weight initialization mode ('random' or 'constant').
        dt_scale: Delta weight init scale factor.
        dt_init_floor: Floor for the initial delta values.
        conv_bias: Whether to use bias in the convolutions.
        bias: Whether to use bias in in_proj / out_proj.
        use_fast_path: Stored as ``self.use_fast_path`` only. ``forward`` does
            not read it, so it currently has no effect.
        layer_idx: Layer index (stored only).
        beta: Momentum decay buffer, β ∈ [0, 1] (default: 0.9). β = 0 presumably
            reduces to vanilla Mamba (kernel-dependent).
        alpha: Momentum scale buffer, α > 0 (default: 1.0).
        use_newton_schulz: Select the Newton-Schulz scan kernel (default: True).
        ns_steps: Newton-Schulz iterations passed to that kernel (default: 1).
        device, dtype: Factory kwargs for the linear / conv layers.

    Raises:
        NotImplementedError: If dt_init is not 'constant' or 'random'.

    Example:
        >>> mixer = MuonMambaVisionMixer(
        ...     d_model=256,
        ...     d_state=16,
        ...     beta=0.9,
        ...     alpha=1.0,
        ...     use_newton_schulz=True,
        ...     device='cuda',
        ... )
        >>> x = torch.randn(2, 196, 256, device='cuda')  # (batch, seq_len, dim)
        >>> y = mixer(x)  # (2, 196, 256)
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,
        layer_idx=None,
        # MuonMamba-specific parameters
        beta=0.9,
        alpha=1.0,
        use_newton_schulz=True,
        ns_steps=1,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Core dimensions
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.d_ssm = self.d_inner // 2  # SSM operates on half the inner dimension
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path  # stored only; not read by forward
        self.layer_idx = layer_idx

        # MuonMamba-specific: Newton-Schulz settings
        self.use_newton_schulz = use_newton_schulz
        self.ns_steps = ns_steps

        # Momentum hyper-parameters as (non-learnable) buffers
        self.register_buffer("beta", torch.tensor(beta, dtype=torch.float32, device=device))
        self.register_buffer("alpha", torch.tensor(alpha, dtype=torch.float32, device=device))

        # Input projection: d_model -> d_inner (split into x and z)
        self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)

        # SSM projections: x_proj outputs dt_rank + 2*d_state (for dt, B, C)
        self.x_proj = nn.Linear(
            self.d_ssm, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )

        # Delta (time step) projection: dt_rank -> d_ssm
        self.dt_proj = nn.Linear(self.dt_rank, self.d_ssm, bias=True, **factory_kwargs)

        # Initialize dt_proj weights
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError(f"dt_init must be 'constant' or 'random', got {dt_init}")

        # Initialize dt_proj bias so that softplus(bias) is log-uniform in [dt_min, dt_max]
        dt = torch.exp(
            torch.rand(self.d_ssm, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        self.dt_proj.bias._no_reinit = True

        # A parameter: row i of A is [1, 2, ..., d_state]; stored as log for positivity.
        # Tensor.repeat returns a fresh contiguous (d_ssm, d_state) tensor.
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device).repeat(self.d_ssm, 1)
        A_log = torch.log(A)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_ssm, device=device))
        self.D._no_weight_decay = True

        # Output projection: d_inner -> d_model
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

        # Dual convolutions (MambaVision-style: separate conv for x and z branches)
        self.conv1d_x = nn.Conv1d(
            in_channels=self.d_ssm,
            out_channels=self.d_ssm,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_ssm,
            **factory_kwargs,
        )
        self.conv1d_z = nn.Conv1d(
            in_channels=self.d_ssm,
            out_channels=self.d_ssm,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_ssm,
            **factory_kwargs,
        )

    def forward(self, hidden_states):
        """
        Forward pass through MuonMambaVisionMixer.

        Args:
            hidden_states: (B, L, D) input tensor

        Returns:
            (B, L, D) output tensor
        """
        seqlen = hidden_states.shape[1]

        # Input projection and split into x (SSM branch) and z (gating branch)
        xz = self.in_proj(hidden_states)
        xz = rearrange(xz, "b l d -> b d l")
        x, z = xz.chunk(2, dim=1)  # Each: (B, d_ssm, L)

        # A = -exp(A_log) keeps the continuous-time decay strictly negative
        A = -torch.exp(self.A_log.float())

        # Apply convolutions with SiLU activation
        x = F.silu(F.conv1d(
            input=x,
            weight=self.conv1d_x.weight,
            bias=self.conv1d_x.bias,
            padding='same',
            groups=self.d_ssm
        ))
        z = F.silu(F.conv1d(
            input=z,
            weight=self.conv1d_z.weight,
            bias=self.conv1d_z.bias,
            padding='same',
            groups=self.d_ssm
        ))

        # Project x to get dt, B, C (traditional Mamba-style)
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (B*L, dt_rank + 2*d_state)
        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)

        # Project dt through dt_proj (bias is applied by the kernel via delta_bias)
        dt = rearrange(self.dt_proj(dt), "(b l) d -> b d l", l=seqlen)

        # Rearrange B and C for SSM
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()

        # Select the scan kernel based on the Newton-Schulz setting
        if self.use_newton_schulz:
            y = selective_scan_fn_orth(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=None,  # z is concatenated later (MambaVision-style)
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=None,
                beta=self.beta,
                alpha=self.alpha,
                ns_steps=self.ns_steps,
            )
        else:
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=None,  # z is concatenated later (MambaVision-style)
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=None,
                beta=self.beta,
                alpha=self.alpha,
            )

        # MambaVision-style: concatenate SSM output with gating branch
        y = torch.cat([y, z], dim=1)  # (B, d_inner, L)
        y = rearrange(y, "b d l -> b l d")

        # Output projection
        out = self.out_proj(y)
        return out

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """
        Allocate zero-filled state buffers for step-wise inference.

        ``max_seqlen`` and extra kwargs are accepted but unused.

        Returns:
            ((conv_state_x, conv_state_z), ssm_state, velocity_state). The conv
            buffers have shape (B, d_ssm, d_conv). The SSM and velocity buffers
            have shape (B, d_ssm, d_state).
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d_x.weight.dtype if dtype is None else dtype
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype

        conv_state_x = torch.zeros(
            batch_size, self.d_ssm, self.d_conv, device=device, dtype=conv_dtype
        )
        conv_state_z = torch.zeros(
            batch_size, self.d_ssm, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_state = torch.zeros(
            batch_size, self.d_ssm, self.d_state, device=device, dtype=ssm_dtype
        )
        # Velocity state for momentum, same layout as ssm_state
        velocity_state = torch.zeros(
            batch_size, self.d_ssm, self.d_state, device=device, dtype=ssm_dtype
        )

        return (conv_state_x, conv_state_z), ssm_state, velocity_state



class Attention(nn.Module):
    """
    Multi-head self-attention over a (B, N, C) token sequence.

    Args:
        dim: Token dimension C; must be divisible by num_heads.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias.
        qk_norm: If truthy, apply ``norm_layer(head_dim)`` to q and k.
        attn_drop: Dropout probability on the attention weights (training only).
        proj_drop: Dropout probability after the output projection.
        norm_layer: Normalization layer used for q/k when qk_norm is set.
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = True

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C); returns (B, N, C)."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each (B, num_heads, N, head_dim)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            # scaled_dot_product_attention applies dropout_p regardless of
            # module mode, so only enable it while training (matches the
            # non-fused path, where nn.Dropout is a no-op in eval).
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """
    Pre-norm residual block: token mixer followed by an MLP.

    ``x = x + drop_path(gamma_1 * mixer(norm1(x)))`` then
    ``x = x + drop_path(gamma_2 * mlp(norm2(x)))``.

    The mixer is an ``Attention`` layer when ``counter`` is in
    ``transformer_blocks``. Otherwise it is built as
    ``mixer_layer(d_model=dim, d_state=8, d_conv=3, expand=1)``, where
    ``mixer_layer`` defaults to ``MambaVisionMixer``. Pass e.g.
    ``MuonMambaVisionMixer`` to swap the SSM mixer.

    Args:
        dim: Token dimension.
        num_heads: Attention heads (used only for attention blocks).
        counter: Index of this block within its stage.
        transformer_blocks: Container of block indices that use attention.
        mlp_ratio: MLP hidden size as a multiple of dim.
        qkv_bias: QKV bias for attention blocks.
        qk_scale: Forwarded to ``Attention`` as ``qk_norm``.
        drop: Dropout for the MLP and attention output projection.
        attn_drop: Attention-weight dropout.
        drop_path: Stochastic depth rate.
        act_layer: MLP activation.
        norm_layer: Normalization layer.
        Mlp_block: MLP module class.
        layer_scale: If an int/float, learnable per-channel residual scales
            initialized to this value; otherwise residuals are unscaled.
        mixer_layer: Optional mixer class/factory for non-attention blocks.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 counter,
                 transformer_blocks,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 Mlp_block=Mlp,
                 layer_scale=None,
                 mixer_layer=None,
                 ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        if counter in transformer_blocks:
            self.mixer = Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_norm=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                norm_layer=norm_layer,
            )
        else:
            # Resolved at call time so the default does not depend on
            # definition order within the module.
            if mixer_layer is None:
                mixer_layer = MambaVisionMixer
            self.mixer = mixer_layer(d_model=dim,
                                     d_state=8,
                                     d_conv=3,
                                     expand=1
                                     )

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        # Exact type check: bools (a subclass of int) intentionally do not enable layer scale.
        use_layer_scale = layer_scale is not None and type(layer_scale) in (int, float)
        self.gamma_1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
        self.gamma_2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1

    def forward(self, x):
        """Apply the mixer and MLP residual branches to ``x``."""
        x = x + self.drop_path(self.gamma_1 * self.mixer(self.norm1(x)))
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class MambaVisionLayer(nn.Module):
    """
    One MambaVision stage: a stack of blocks plus optional downsampling.

    Conv stages apply ``ConvBlock`` modules directly to (B, C, H, W) feature
    maps. Non-conv stages work in four steps:
    1. Pad H and W up to a multiple of ``window_size``.
    2. Split the map into windows with ``window_partition``.
    3. Run the ``Block`` stack on the windows.
    4. Merge them back with ``window_reverse`` and crop away the padding.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size,
                 conv=False,
                 downsample=True,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 layer_scale=None,
                 layer_scale_conv=None,
                 transformer_blocks=None,
    ):
        """
        Args:
            dim: feature size dimension.
            depth: number of blocks in this stage.
            num_heads: number of attention heads (forwarded to Block).
            window_size: window size used to pad and partition non-conv stages.
            conv: if True, build ConvBlocks; otherwise build windowed Blocks.
            downsample: if True, append a Downsample module.
            mlp_ratio: MLP ratio.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: forwarded to Block as ``qk_scale``.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: drop path rate, either a scalar or a per-block list.
            layer_scale: layer scaling coefficient for Blocks.
            layer_scale_conv: layer scaling coefficient for ConvBlocks.
            transformer_blocks: indices of blocks that use attention instead of
                the SSM mixer (default: none).
        """
        super().__init__()
        # Avoid a shared mutable default; an empty tuple means "no attention blocks".
        if transformer_blocks is None:
            transformer_blocks = ()
        self.conv = conv
        if conv:
            self.blocks = nn.ModuleList([ConvBlock(dim=dim,
                                                   drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                                   layer_scale=layer_scale_conv)
                                                   for i in range(depth)])
            self.transformer_block = False
        else:
            self.blocks = nn.ModuleList([Block(dim=dim,
                                               counter=i,
                                               transformer_blocks=transformer_blocks,
                                               num_heads=num_heads,
                                               mlp_ratio=mlp_ratio,
                                               qkv_bias=qkv_bias,
                                               qk_scale=qk_scale,
                                               drop=drop,
                                               attn_drop=attn_drop,
                                               drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                               layer_scale=layer_scale)
                                               for i in range(depth)])
            self.transformer_block = True

        self.downsample = None if not downsample else Downsample(dim=dim)
        self.do_gt = False
        self.window_size = window_size

    def forward(self, x):
        """Run the stage on ``x`` of shape (B, C, H, W)."""
        _, _, H, W = x.shape

        if self.transformer_block:
            # Pad right/bottom so H and W are multiples of window_size.
            pad_r = (self.window_size - W % self.window_size) % self.window_size
            pad_b = (self.window_size - H % self.window_size) % self.window_size
            if pad_r > 0 or pad_b > 0:
                x = F.pad(x, (0, pad_r, 0, pad_b))
            _, _, Hp, Wp = x.shape
            x = window_partition(x, self.window_size)

        for blk in self.blocks:
            x = blk(x)

        if self.transformer_block:
            x = window_reverse(x, self.window_size, Hp, Wp)
            if pad_r > 0 or pad_b > 0:
                x = x[:, :, :H, :W].contiguous()
        if self.downsample is None:
            return x
        return self.downsample(x)


class MambaVision(nn.Module):
    """
    MambaVision hierarchical backbone with a linear classification head.

    A patch embedding is followed by ``len(depths)`` stages
    (``MambaVisionLayer``). Stage ``i`` operates at ``dim * 2**i`` channels,
    and every stage except the last downsamples. The final feature map is
    batch-normalised, globally average-pooled, flattened and fed to the head.
    """

    def __init__(self,
                 dim,
                 in_dim,
                 depths,
                 window_size,
                 mlp_ratio,
                 num_heads,
                 drop_path_rate=0.2,
                 in_chans=3,
                 num_classes=1000,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 layer_scale=None,
                 layer_scale_conv=None,
                 **kwargs):
        """
        Args:
            dim: channel width of the first stage; stage ``i`` uses ``dim * 2**i``.
            in_dim: channel width passed to ``PatchEmbed`` as ``in_dim``.
            depths: number of blocks in each stage (one entry per stage).
            window_size: window size in each stage.
            mlp_ratio: MLP hidden-dimension ratio.
            num_heads: number of attention heads in each stage.
            drop_path_rate: maximum stochastic-depth rate. Per-block rates rise
                linearly from 0 to this value across all blocks of all stages.
            in_chans: number of input channels.
            num_classes: number of classes; ``<= 0`` replaces the head with
                ``nn.Identity``.
            qkv_bias: whether the query, key, value projections have a learnable bias.
            qk_scale: optional query-key scale forwarded to the stages
                (None keeps their default).
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            layer_scale: layer scaling coefficient.
            layer_scale_conv: conv layer scaling coefficient.
            **kwargs: extra keyword arguments (e.g. ``resolution`` from the
                factory functions); accepted and ignored.
        """
        super().__init__()
        num_stages = len(depths)
        num_features = int(dim * 2 ** (num_stages - 1))
        self.num_classes = num_classes
        self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim)
        # One stochastic-depth rate per block, across all stages.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.levels = nn.ModuleList()
        for i in range(num_stages):
            depth = depths[i]
            first_block = sum(depths[:i])
            level = MambaVisionLayer(dim=int(dim * 2 ** i),
                                     depth=depth,
                                     num_heads=num_heads[i],
                                     window_size=window_size[i],
                                     mlp_ratio=mlp_ratio,
                                     qkv_bias=qkv_bias,
                                     qk_scale=qk_scale,
                                     # The first two stages are built with conv=True.
                                     conv=i < 2,
                                     drop=drop_rate,
                                     attn_drop=attn_drop_rate,
                                     drop_path=dpr[first_block:first_block + depth],
                                     # Every stage but the last downsamples, so the final
                                     # width matches num_features (assuming Downsample
                                     # doubles channels). This used to be hard-coded as
                                     # `i < 3`, which was only correct for 4 stages.
                                     downsample=i < num_stages - 1,
                                     layer_scale=layer_scale,
                                     layer_scale_conv=layer_scale_conv,
                                     # Indices of the last `depth // 2` blocks of the stage.
                                     transformer_blocks=list(range((depth + 1) // 2, depth)),
                                     )
            self.levels.append(level)
        self.norm = nn.BatchNorm2d(num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule (called for every module via ``self.apply``).

        Linear weights get a truncated normal (std 0.02) and zero bias. Norm
        layers are reset to weight 1 / bias 0. Affine-less norms (None
        weight/bias) are skipped instead of crashing.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, LayerNorm2d, nn.BatchNorm2d)):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return parameter-name keywords to exclude from weight decay."""
        return {'rpb'}

    def forward_features(self, x):
        """Return pooled, flattened features of shape (batch, num_features)."""
        x = self.patch_embed(x)
        for level in self.levels:
            x = level(x)
        x = self.norm(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        """Return classification logits (or pooled features when the head is Identity)."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

    def _load_state_dict(self,
                         pretrained,
                         strict: bool = False):
        """Load checkpoint ``pretrained`` into this model via ``_load_checkpoint``."""
        _load_checkpoint(self,
                         pretrained,
                         strict=strict)


@register_pip_model
@register_model
def mamba_vision_T(pretrained=False, **kwargs):
    """Build MambaVision-T.

    The architecture hyper-parameters (``depths``, ``num_heads``,
    ``window_size``, ``dim``, ``in_dim``, ``mlp_ratio``, ``resolution``,
    ``drop_path_rate``) and the checkpoint path ``model_path`` may be overridden
    through ``kwargs``. Whatever is left in ``kwargs`` after ``update_args`` is
    forwarded to ``MambaVision``.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_T.pth.tar``). If that file is missing, they are
            first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    checkpoint = kwargs.pop("model_path", "/tmp/mamba_vision_T.pth.tar")
    arch = dict(
        depths=kwargs.pop("depths", [1, 3, 8, 4]),
        num_heads=kwargs.pop("num_heads", [2, 4, 8, 16]),
        window_size=kwargs.pop("window_size", [8, 8, 14, 7]),
        dim=kwargs.pop("dim", 80),
        in_dim=kwargs.pop("in_dim", 32),
        mlp_ratio=kwargs.pop("mlp_ratio", 4),
        resolution=kwargs.pop("resolution", 224),
        drop_path_rate=kwargs.pop("drop_path_rate", 0.2),
    )
    cfg = resolve_pretrained_cfg('mamba_vision_T').to_dict()
    update_args(cfg, kwargs, kwargs_filter=None)

    model = MambaVision(**arch, **kwargs)
    model.pretrained_cfg = model.default_cfg = cfg

    if pretrained:
        # NOTE(review): any existing file at this path is loaded as-is; the
        # default lives in the shared /tmp directory.
        if not Path(checkpoint).is_file():
            torch.hub.download_url_to_file(url=model.default_cfg['url'], dst=checkpoint)
        model._load_state_dict(checkpoint)
    return model


@register_pip_model
@register_model
def mamba_vision_T2(pretrained=False, **kwargs):
    """Build MambaVision-T2.

    The architecture hyper-parameters (``depths``, ``num_heads``,
    ``window_size``, ``dim``, ``in_dim``, ``mlp_ratio``, ``resolution``,
    ``drop_path_rate``) and the checkpoint path ``model_path`` may be overridden
    through ``kwargs``. Whatever is left in ``kwargs`` after ``update_args`` is
    forwarded to ``MambaVision``.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_T2.pth.tar``). If that file is missing, they are
            first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    checkpoint = kwargs.pop("model_path", "/tmp/mamba_vision_T2.pth.tar")
    arch = dict(
        depths=kwargs.pop("depths", [1, 3, 11, 4]),
        num_heads=kwargs.pop("num_heads", [2, 4, 8, 16]),
        window_size=kwargs.pop("window_size", [8, 8, 14, 7]),
        dim=kwargs.pop("dim", 80),
        in_dim=kwargs.pop("in_dim", 32),
        mlp_ratio=kwargs.pop("mlp_ratio", 4),
        resolution=kwargs.pop("resolution", 224),
        drop_path_rate=kwargs.pop("drop_path_rate", 0.2),
    )
    cfg = resolve_pretrained_cfg('mamba_vision_T2').to_dict()
    update_args(cfg, kwargs, kwargs_filter=None)

    model = MambaVision(**arch, **kwargs)
    model.pretrained_cfg = model.default_cfg = cfg

    if pretrained:
        # NOTE(review): any existing file at this path is loaded as-is; the
        # default lives in the shared /tmp directory.
        if not Path(checkpoint).is_file():
            torch.hub.download_url_to_file(url=model.default_cfg['url'], dst=checkpoint)
        model._load_state_dict(checkpoint)
    return model


@register_pip_model
@register_model
def mamba_vision_S(pretrained=False, **kwargs):
    """Build MambaVision-S.

    The architecture hyper-parameters (``depths``, ``num_heads``,
    ``window_size``, ``dim``, ``in_dim``, ``mlp_ratio``, ``resolution``,
    ``drop_path_rate``) and the checkpoint path ``model_path`` may be overridden
    through ``kwargs``. Whatever is left in ``kwargs`` after ``update_args`` is
    forwarded to ``MambaVision``.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_S.pth.tar``). If that file is missing, they are
            first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    checkpoint = kwargs.pop("model_path", "/tmp/mamba_vision_S.pth.tar")
    arch = dict(
        depths=kwargs.pop("depths", [3, 3, 7, 5]),
        num_heads=kwargs.pop("num_heads", [2, 4, 8, 16]),
        window_size=kwargs.pop("window_size", [8, 8, 14, 7]),
        dim=kwargs.pop("dim", 96),
        in_dim=kwargs.pop("in_dim", 64),
        mlp_ratio=kwargs.pop("mlp_ratio", 4),
        resolution=kwargs.pop("resolution", 224),
        drop_path_rate=kwargs.pop("drop_path_rate", 0.2),
    )
    cfg = resolve_pretrained_cfg('mamba_vision_S').to_dict()
    update_args(cfg, kwargs, kwargs_filter=None)

    model = MambaVision(**arch, **kwargs)
    model.pretrained_cfg = model.default_cfg = cfg

    if pretrained:
        # NOTE(review): any existing file at this path is loaded as-is; the
        # default lives in the shared /tmp directory.
        if not Path(checkpoint).is_file():
            torch.hub.download_url_to_file(url=model.default_cfg['url'], dst=checkpoint)
        model._load_state_dict(checkpoint)
    return model


@register_pip_model
@register_model
def mamba_vision_B(pretrained=False, **kwargs):
    """Build MambaVision-B.

    Every architecture hyper-parameter popped below (including ``layer_scale``
    and ``layer_scale_conv``) and the checkpoint path ``model_path`` may be
    overridden through ``kwargs``. Whatever is left in ``kwargs`` after
    ``update_args`` is forwarded to ``MambaVision``. ``resolution`` is forwarded
    too, but ``MambaVision.__init__`` has no such parameter and ignores it.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_B.pth.tar``). If that file is missing, they are
            first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_B.pth.tar")
    depths = kwargs.pop("depths", [3, 3, 10, 5])
    num_heads = kwargs.pop("num_heads", [2, 4, 8, 16])
    window_size = kwargs.pop("window_size", [8, 8, 14, 7])
    dim = kwargs.pop("dim", 128)
    in_dim = kwargs.pop("in_dim", 64)
    mlp_ratio = kwargs.pop("mlp_ratio", 4)
    resolution = kwargs.pop("resolution", 224)
    drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
    layer_scale = kwargs.pop("layer_scale", 1e-5)
    # Popped rather than hard-coded: passing it explicitly alongside **kwargs
    # made a caller-supplied value raise "got multiple values" TypeError.
    layer_scale_conv = kwargs.pop("layer_scale_conv", None)
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_B').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=depths,
                        num_heads=num_heads,
                        window_size=window_size,
                        dim=dim,
                        in_dim=in_dim,
                        mlp_ratio=mlp_ratio,
                        resolution=resolution,
                        drop_path_rate=drop_path_rate,
                        layer_scale=layer_scale,
                        layer_scale_conv=layer_scale_conv,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        # NOTE(review): any existing file at model_path is loaded without an
        # integrity check; the default lives in the shared /tmp directory.
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_pip_model
@register_model
def mamba_vision_B_21k(pretrained=False, **kwargs):
    """Build MambaVision-B (``mamba_vision_B_21k`` pretrained config).

    Every architecture hyper-parameter popped below (including ``layer_scale``
    and ``layer_scale_conv``) and the checkpoint path ``model_path`` may be
    overridden through ``kwargs``. Whatever is left in ``kwargs`` after
    ``update_args`` is forwarded to ``MambaVision``. ``resolution`` is forwarded
    too, but ``MambaVision.__init__`` has no such parameter and ignores it.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_B_21k.pth.tar``). If that file is missing, they
            are first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_B_21k.pth.tar")
    depths = kwargs.pop("depths", [3, 3, 10, 5])
    num_heads = kwargs.pop("num_heads", [2, 4, 8, 16])
    window_size = kwargs.pop("window_size", [8, 8, 14, 7])
    dim = kwargs.pop("dim", 128)
    in_dim = kwargs.pop("in_dim", 64)
    mlp_ratio = kwargs.pop("mlp_ratio", 4)
    resolution = kwargs.pop("resolution", 224)
    drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
    layer_scale = kwargs.pop("layer_scale", 1e-5)
    # Popped rather than hard-coded: passing it explicitly alongside **kwargs
    # made a caller-supplied value raise "got multiple values" TypeError.
    layer_scale_conv = kwargs.pop("layer_scale_conv", None)
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_B_21k').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=depths,
                        num_heads=num_heads,
                        window_size=window_size,
                        dim=dim,
                        in_dim=in_dim,
                        mlp_ratio=mlp_ratio,
                        resolution=resolution,
                        drop_path_rate=drop_path_rate,
                        layer_scale=layer_scale,
                        layer_scale_conv=layer_scale_conv,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        # NOTE(review): any existing file at model_path is loaded without an
        # integrity check; the default lives in the shared /tmp directory.
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_pip_model
@register_model
def mamba_vision_L(pretrained=False, **kwargs):
    """Build MambaVision-L.

    Every architecture hyper-parameter popped below (including ``layer_scale``
    and ``layer_scale_conv``) and the checkpoint path ``model_path`` may be
    overridden through ``kwargs``. Whatever is left in ``kwargs`` after
    ``update_args`` is forwarded to ``MambaVision``. ``resolution`` is forwarded
    too, but ``MambaVision.__init__`` has no such parameter and ignores it.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_L.pth.tar``). If that file is missing, they are
            first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L.pth.tar")
    depths = kwargs.pop("depths", [3, 3, 10, 5])
    num_heads = kwargs.pop("num_heads", [4, 8, 16, 32])
    window_size = kwargs.pop("window_size", [8, 8, 14, 7])
    dim = kwargs.pop("dim", 196)
    in_dim = kwargs.pop("in_dim", 64)
    mlp_ratio = kwargs.pop("mlp_ratio", 4)
    resolution = kwargs.pop("resolution", 224)
    drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
    layer_scale = kwargs.pop("layer_scale", 1e-5)
    # Popped rather than hard-coded: passing it explicitly alongside **kwargs
    # made a caller-supplied value raise "got multiple values" TypeError.
    layer_scale_conv = kwargs.pop("layer_scale_conv", None)
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=depths,
                        num_heads=num_heads,
                        window_size=window_size,
                        dim=dim,
                        in_dim=in_dim,
                        mlp_ratio=mlp_ratio,
                        resolution=resolution,
                        drop_path_rate=drop_path_rate,
                        layer_scale=layer_scale,
                        layer_scale_conv=layer_scale_conv,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        # NOTE(review): any existing file at model_path is loaded without an
        # integrity check; the default lives in the shared /tmp directory.
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_pip_model
@register_model
def mamba_vision_L_21k(pretrained=False, **kwargs):
    """Build MambaVision-L (``mamba_vision_L_21k`` pretrained config).

    Every architecture hyper-parameter popped below (including ``layer_scale``
    and ``layer_scale_conv``) and the checkpoint path ``model_path`` may be
    overridden through ``kwargs``. Whatever is left in ``kwargs`` after
    ``update_args`` is forwarded to ``MambaVision``. ``resolution`` is forwarded
    too, but ``MambaVision.__init__`` has no such parameter and ignores it.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_L_21k.pth.tar``). If that file is missing, they
            are first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L_21k.pth.tar")
    depths = kwargs.pop("depths", [3, 3, 10, 5])
    num_heads = kwargs.pop("num_heads", [4, 8, 16, 32])
    window_size = kwargs.pop("window_size", [8, 8, 14, 7])
    dim = kwargs.pop("dim", 196)
    in_dim = kwargs.pop("in_dim", 64)
    mlp_ratio = kwargs.pop("mlp_ratio", 4)
    resolution = kwargs.pop("resolution", 224)
    drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
    layer_scale = kwargs.pop("layer_scale", 1e-5)
    # Popped rather than hard-coded: passing it explicitly alongside **kwargs
    # made a caller-supplied value raise "got multiple values" TypeError.
    layer_scale_conv = kwargs.pop("layer_scale_conv", None)
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L_21k').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=depths,
                        num_heads=num_heads,
                        window_size=window_size,
                        dim=dim,
                        in_dim=in_dim,
                        mlp_ratio=mlp_ratio,
                        resolution=resolution,
                        drop_path_rate=drop_path_rate,
                        layer_scale=layer_scale,
                        layer_scale_conv=layer_scale_conv,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        # NOTE(review): any existing file at model_path is loaded without an
        # integrity check; the default lives in the shared /tmp directory.
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_pip_model
@register_model
def mamba_vision_L2(pretrained=False, **kwargs):
    """Build MambaVision-L2.

    Every architecture hyper-parameter popped below (including ``layer_scale``
    and ``layer_scale_conv``) and the checkpoint path ``model_path`` may be
    overridden through ``kwargs``. Whatever is left in ``kwargs`` after
    ``update_args`` is forwarded to ``MambaVision``. ``resolution`` is forwarded
    too, but ``MambaVision.__init__`` has no such parameter and ignores it.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_L2.pth.tar``). If that file is missing, they are
            first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L2.pth.tar")
    depths = kwargs.pop("depths", [3, 3, 12, 5])
    num_heads = kwargs.pop("num_heads", [4, 8, 16, 32])
    window_size = kwargs.pop("window_size", [8, 8, 14, 7])
    dim = kwargs.pop("dim", 196)
    in_dim = kwargs.pop("in_dim", 64)
    mlp_ratio = kwargs.pop("mlp_ratio", 4)
    resolution = kwargs.pop("resolution", 224)
    drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
    layer_scale = kwargs.pop("layer_scale", 1e-5)
    # Popped rather than hard-coded: passing it explicitly alongside **kwargs
    # made a caller-supplied value raise "got multiple values" TypeError.
    layer_scale_conv = kwargs.pop("layer_scale_conv", None)
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L2').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=depths,
                        num_heads=num_heads,
                        window_size=window_size,
                        dim=dim,
                        in_dim=in_dim,
                        mlp_ratio=mlp_ratio,
                        resolution=resolution,
                        drop_path_rate=drop_path_rate,
                        layer_scale=layer_scale,
                        layer_scale_conv=layer_scale_conv,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        # NOTE(review): any existing file at model_path is loaded without an
        # integrity check; the default lives in the shared /tmp directory.
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_pip_model
@register_model
def mamba_vision_L2_512_21k(pretrained=False, **kwargs):
    """Build MambaVision-L2 with 512-resolution defaults (``mamba_vision_L2_512_21k``).

    Every architecture hyper-parameter popped below (including ``layer_scale``
    and ``layer_scale_conv``) and the checkpoint path ``model_path`` may be
    overridden through ``kwargs``. Whatever is left in ``kwargs`` after
    ``update_args`` is forwarded to ``MambaVision``. ``resolution`` is forwarded
    too, but ``MambaVision.__init__`` has no such parameter and ignores it.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_L2_512_21k.pth.tar``). If that file is missing,
            they are first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L2_512_21k.pth.tar")
    depths = kwargs.pop("depths", [3, 3, 12, 5])
    num_heads = kwargs.pop("num_heads", [4, 8, 16, 32])
    window_size = kwargs.pop("window_size", [8, 8, 32, 16])
    dim = kwargs.pop("dim", 196)
    in_dim = kwargs.pop("in_dim", 64)
    mlp_ratio = kwargs.pop("mlp_ratio", 4)
    resolution = kwargs.pop("resolution", 512)
    drop_path_rate = kwargs.pop("drop_path_rate", 0.3)
    layer_scale = kwargs.pop("layer_scale", 1e-5)
    # Popped rather than hard-coded: passing it explicitly alongside **kwargs
    # made a caller-supplied value raise "got multiple values" TypeError.
    layer_scale_conv = kwargs.pop("layer_scale_conv", None)
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L2_512_21k').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=depths,
                        num_heads=num_heads,
                        window_size=window_size,
                        dim=dim,
                        in_dim=in_dim,
                        mlp_ratio=mlp_ratio,
                        resolution=resolution,
                        drop_path_rate=drop_path_rate,
                        layer_scale=layer_scale,
                        layer_scale_conv=layer_scale_conv,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        # NOTE(review): any existing file at model_path is loaded without an
        # integrity check; the default lives in the shared /tmp directory.
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_pip_model
@register_model
def mamba_vision_L3_256_21k(pretrained=False, **kwargs):
    """Build MambaVision-L3 with 256-resolution defaults (``mamba_vision_L3_256_21k``).

    Every architecture hyper-parameter popped below (including ``layer_scale``
    and ``layer_scale_conv``) and the checkpoint path ``model_path`` may be
    overridden through ``kwargs``. Whatever is left in ``kwargs`` after
    ``update_args`` is forwarded to ``MambaVision``. ``resolution`` is forwarded
    too, but ``MambaVision.__init__`` has no such parameter and ignores it.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_L3_256_21k.pth.tar``). If that file is missing,
            they are first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L3_256_21k.pth.tar")
    depths = kwargs.pop("depths", [3, 3, 20, 10])
    num_heads = kwargs.pop("num_heads", [4, 8, 16, 32])
    window_size = kwargs.pop("window_size", [8, 8, 16, 8])
    dim = kwargs.pop("dim", 256)
    in_dim = kwargs.pop("in_dim", 64)
    mlp_ratio = kwargs.pop("mlp_ratio", 4)
    resolution = kwargs.pop("resolution", 256)
    drop_path_rate = kwargs.pop("drop_path_rate", 0.5)
    layer_scale = kwargs.pop("layer_scale", 1e-5)
    # Popped rather than hard-coded: passing it explicitly alongside **kwargs
    # made a caller-supplied value raise "got multiple values" TypeError.
    layer_scale_conv = kwargs.pop("layer_scale_conv", None)
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L3_256_21k').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=depths,
                        num_heads=num_heads,
                        window_size=window_size,
                        dim=dim,
                        in_dim=in_dim,
                        mlp_ratio=mlp_ratio,
                        resolution=resolution,
                        drop_path_rate=drop_path_rate,
                        layer_scale=layer_scale,
                        layer_scale_conv=layer_scale_conv,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        # NOTE(review): any existing file at model_path is loaded without an
        # integrity check; the default lives in the shared /tmp directory.
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model


@register_pip_model
@register_model
def mamba_vision_L3_512_21k(pretrained=False, **kwargs):
    """Build MambaVision-L3 with 512-resolution defaults (``mamba_vision_L3_512_21k``).

    Every architecture hyper-parameter popped below (including ``layer_scale``
    and ``layer_scale_conv``) and the checkpoint path ``model_path`` may be
    overridden through ``kwargs``. Whatever is left in ``kwargs`` after
    ``update_args`` is forwarded to ``MambaVision``. ``resolution`` is forwarded
    too, but ``MambaVision.__init__`` has no such parameter and ignores it.

    Args:
        pretrained: when True, load weights from ``model_path`` (default
            ``/tmp/mamba_vision_L3_512_21k.pth.tar``). If that file is missing,
            they are first downloaded from the pretrained config's ``url``.

    Returns:
        The model, with ``pretrained_cfg`` and ``default_cfg`` both set to the
        resolved pretrained config.
    """
    model_path = kwargs.pop("model_path", "/tmp/mamba_vision_L3_512_21k.pth.tar")
    depths = kwargs.pop("depths", [3, 3, 20, 10])
    num_heads = kwargs.pop("num_heads", [4, 8, 16, 32])
    window_size = kwargs.pop("window_size", [8, 8, 32, 16])
    dim = kwargs.pop("dim", 256)
    in_dim = kwargs.pop("in_dim", 64)
    mlp_ratio = kwargs.pop("mlp_ratio", 4)
    resolution = kwargs.pop("resolution", 512)
    drop_path_rate = kwargs.pop("drop_path_rate", 0.5)
    layer_scale = kwargs.pop("layer_scale", 1e-5)
    # Popped rather than hard-coded: passing it explicitly alongside **kwargs
    # made a caller-supplied value raise "got multiple values" TypeError.
    layer_scale_conv = kwargs.pop("layer_scale_conv", None)
    pretrained_cfg = resolve_pretrained_cfg('mamba_vision_L3_512_21k').to_dict()
    update_args(pretrained_cfg, kwargs, kwargs_filter=None)
    model = MambaVision(depths=depths,
                        num_heads=num_heads,
                        window_size=window_size,
                        dim=dim,
                        in_dim=in_dim,
                        mlp_ratio=mlp_ratio,
                        resolution=resolution,
                        drop_path_rate=drop_path_rate,
                        layer_scale=layer_scale,
                        layer_scale_conv=layer_scale_conv,
                        **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg
    if pretrained:
        # NOTE(review): any existing file at model_path is loaded without an
        # integrity check; the default lives in the shared /tmp directory.
        if not Path(model_path).is_file():
            url = model.default_cfg['url']
            torch.hub.download_url_to_file(url=url, dst=model_path)
        model._load_state_dict(model_path)
    return model
