import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.fft

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math
import numpy as np
from models.simba_x.mamba_ssm import Mamba
from einops import rearrange, repeat, einsum


class EinFFT(nn.Module):
    """Frequency-domain channel mixer.

    The channel axis is split into ``num_blocks`` groups of ``block_size``
    channels. The input is transformed with a 2-D FFT over the token and block
    axes, passed through two block-wise complex linear layers (ReLU is applied
    separately to the real and imaginary parts after the first layer),
    soft-shrunk, transformed back with the inverse FFT, and the real part is
    returned.

    Args:
        dim: Number of input channels; must be divisible by ``num_blocks`` (4).
    """

    def __init__(self, dim):
        super().__init__()
        self.hidden_size = dim  # 768
        self.num_blocks = 4
        self.block_size = self.hidden_size // self.num_blocks
        assert self.hidden_size % self.num_blocks == 0
        self.sparsity_threshold = 0.01
        self.scale = 0.02

        # Index 0 of the leading axis holds the real part, index 1 the
        # imaginary part of each complex weight / bias.
        self.complex_weight_1 = nn.Parameter(
            torch.randn(2, self.num_blocks, self.block_size, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_weight_2 = nn.Parameter(
            torch.randn(2, self.num_blocks, self.block_size, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_bias_1 = nn.Parameter(
            torch.randn(2, self.num_blocks, self.block_size, dtype=torch.float32) * self.scale)
        self.complex_bias_2 = nn.Parameter(
            torch.randn(2, self.num_blocks, self.block_size, dtype=torch.float32) * self.scale)

    def multiply(self, input, weights):
        """Apply a separate ``(block_size, k)`` matrix to each channel block."""
        return torch.einsum('...bd,bdk->...bk', input, weights)

    def forward(self, x, H, W):
        """Mix channels of ``x`` in the frequency domain.

        Args:
            x: Tensor of shape ``(B, N, C)`` with ``C == dim``.
            H: Unused; accepted so the signature matches the other channel mixers.
            W: Unused; accepted so the signature matches the other channel mixers.

        Returns:
            Real ``float32`` tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        x = x.view(B, N, self.num_blocks, self.block_size)

        # 2-D FFT over the token axis (1) and the block axis (2).
        x = torch.fft.fft2(x, dim=(1, 2), norm='ortho')

        # First complex linear layer, written out as real arithmetic:
        # (a + ib)(w_r + i w_i) = (a w_r - b w_i) + i(a w_i + b w_r).
        x_real_1 = F.relu(
            self.multiply(x.real, self.complex_weight_1[0]) - self.multiply(x.imag, self.complex_weight_1[1]) +
            self.complex_bias_1[0])
        x_imag_1 = F.relu(
            self.multiply(x.real, self.complex_weight_1[1]) + self.multiply(x.imag, self.complex_weight_1[0]) +
            self.complex_bias_1[1])

        # Second complex linear layer (no activation).
        x_real_2 = (self.multiply(x_real_1, self.complex_weight_2[0])
                    - self.multiply(x_imag_1, self.complex_weight_2[1])
                    + self.complex_bias_2[0])
        x_imag_2 = (self.multiply(x_real_1, self.complex_weight_2[1])
                    + self.multiply(x_imag_1, self.complex_weight_2[0])
                    + self.complex_bias_2[1])

        x = torch.stack([x_real_2, x_imag_2], dim=-1).float()
        x = F.softshrink(x, lambd=self.sparsity_threshold) if self.sparsity_threshold else x
        x = torch.view_as_complex(x)

        x = torch.fft.ifft2(x, dim=(1, 2), norm="ortho")

        # Downstream layers (e.g. dropout) do not support complex tensors, so
        # only the real part is kept. Taking `.real` explicitly yields the same
        # values as casting the complex tensor to float32, without PyTorch's
        # "Casting complex values to real discards the imaginary part" warning.
        x = x.real
        x = x.reshape(B, N, C)
        return x


# # For a fast implementation use MambaLayer. The implementation below is slow and is kept only for checking GFLOPs and other parameters.
# # For more details please refer to https://github.com/johnma2006/mamba-minimal/blob/master/model.py
# class MambaBlock(nn.Module):
#     def __init__(self, d_model, d_state=64, expand=2, d_conv=4, conv_bias=True, bias=False):
#         """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
#         super().__init__()
#         self.d_model = d_model  # Model dimension d_model
#         self.d_state = d_state  # SSM state expansion factor
#         self.d_conv = d_conv  # Local convolution width
#         self.expand = expand  # Block expansion factor
#         self.conv_bias = conv_bias
#         self.bias = bias
#         self.d_inner = int(self.expand * self.d_model)
#         self.dt_rank = math.ceil(self.d_model / 16)
#
#         self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=self.bias)
#
#         self.conv1d = nn.Conv1d(
#             in_channels=self.d_inner,
#             out_channels=self.d_inner,
#             bias=self.conv_bias,
#             kernel_size=self.d_conv,
#             groups=self.d_inner,
#             padding=self.d_conv - 1,
#         )
#
#         # x_proj takes in `x` and outputs the input-specific Δ, B, C
#         self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
#
#         # dt_proj projects Δ from dt_rank to d_in
#         self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
#
#         A = repeat(torch.arange(1, self.d_state + 1), 'n -> d n', d=self.d_inner)
#         self.A_log = nn.Parameter(torch.log(A))
#         self.D = nn.Parameter(torch.ones(self.d_inner))
#         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=self.bias)
#
#     def forward(self, x):
#         """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
#
#         Args:
#             x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
#
#         Returns:
#             output: shape (b, l, d)
#
#         Official Implementation:
#             class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
#             mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
#
#         """
#         (b, l, d) = x.shape
#
#         x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
#         (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)
#
#         x = rearrange(x, 'b l d_in -> b d_in l')
#         x = self.conv1d(x)[:, :, :l]
#         x = rearrange(x, 'b d_in l -> b l d_in')
#
#         x = F.silu(x)
#
#         y = self.ssm(x)
#
#         y = y * F.silu(res)
#
#         output = self.out_proj(y)
#
#         return output
#
#     def ssm(self, x):
#         """Runs the SSM. See:
#             - Algorithm 2 in Section 3.2 in the Mamba paper [1]
#             - run_SSM(A, B, C, u) in The Annotated S4 [2]
#
#         Args:
#             x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
#
#         Returns:
#             output: shape (b, l, d_in)
#
#         Official Implementation:
#             mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
#
#         """
#         (d_in, n) = self.A_log.shape
#
#         # Compute ∆ A B C D, the state space parameters.
#         #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
#         #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
#         #                                  and is why Mamba is called **selective** state spaces)
#
#         A = -torch.exp(self.A_log.float())  # shape (d_in, n)
#         D = self.D.float()
#
#         x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
#
#         (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
#         delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
#
#         y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
#
#         return y
#
#     def selective_scan(self, u, delta, A, B, C, D):
#         """Does selective scan algorithm. See:
#             - Section 2 State Space Models in the Mamba paper [1]
#             - Algorithm 2 in Section 3.2 in the Mamba paper [1]
#             - run_SSM(A, B, C, u) in The Annotated S4 [2]
#
#         This is the classic discrete state space formula:
#             x(t + 1) = Ax(t) + Bu(t)
#             y(t)     = Cx(t) + Du(t)
#         except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
#
#         Args:
#             u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
#             delta: shape (b, l, d_in)
#             A: shape (d_in, n)
#             B: shape (b, l, n)
#             C: shape (b, l, n)
#             D: shape (d_in,)
#
#         Returns:
#             output: shape (b, l, d_in)
#
#         Official Implementation:
#             selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
#             Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
#
#         """
#         (b, l, d_in) = u.shape
#         n = A.shape[1]
#
#         # Discretize continuous parameters (A, B)
#         # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
#         # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
#         #   "A is the more important term and the performance doesn't change much with the simplification on B"
#         deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
#         deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
#
#         # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
#         # Note that the below is sequential, while the official implementation does a much faster parallel scan that
#         # is additionally hardware-aware (like FlashAttention).
#         x = torch.zeros((b, d_in, n), device=deltaA.device)
#         ys = []
#         for i in range(l):
#             x = deltaA[:, i] * x + deltaB_u[:, i]
#             y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
#             ys.append(y)
#         y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
#
#         y = y + u * D
#
#         return y


class MambaLayer(nn.Module):
    """Pre-norm wrapper around ``Mamba``: LayerNorm followed by the Mamba mixer.

    Args:
        dim: Channel width, used for both the LayerNorm and ``Mamba(d_model=...)``.
        d_state: Forwarded to ``Mamba`` as ``d_state``.
        d_conv: Forwarded to ``Mamba`` as ``d_conv``.
        expand: Forwarded to ``Mamba`` as ``expand``.
        layer_idx: Forwarded to ``Mamba`` as ``layer_idx``.
        if_constraint: Forwarded to ``Mamba`` as ``if_constraint``.
    """

    def __init__(self, dim, d_state=64, d_conv=4, expand=2,
                 layer_idx=None, if_constraint=False, ):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        mixer_options = dict(
            d_model=dim,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            layer_idx=layer_idx,
            if_constraint=if_constraint,
        )
        self.mamba = Mamba(**mixer_options)

    def forward(self, x):
        """Normalise ``x`` and run it through the Mamba mixer (no residual here)."""
        # Unpacking requires a rank-3 input; the individual sizes are not used.
        _batch, _length, _channels = x.shape
        return self.mamba(self.norm(x))


def rand_bbox(size, lam, scale=1):
    """Sample a random box on the grid obtained by dividing ``size`` by ``scale``.

    Args:
        size: Indexable of dimensions; entries 1 and 2 are the two grid axes.
        lam: Mixing ratio; the box covers roughly ``1 - lam`` of the grid area.
        scale: Integer divisor applied to ``size[1]`` and ``size[2]``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: box bounds on the down-scaled grid,
        clipped to ``[0, size[1] // scale]`` and ``[0, size[2] // scale]``.
    """
    grid_w = size[1] // scale
    grid_h = size[2] // scale

    # Side lengths scale with sqrt(1 - lam) so the area scales with (1 - lam).
    side_ratio = np.sqrt(1. - lam)
    half_w = np.int_(grid_w * side_ratio) // 2
    half_h = np.int_(grid_h * side_ratio) // 2

    # Box centre drawn uniformly; x is drawn before y.
    centre_x = np.random.randint(grid_w)
    centre_y = np.random.randint(grid_h)

    left = np.clip(centre_x - half_w, 0, grid_w)
    top = np.clip(centre_y - half_h, 0, grid_h)
    right = np.clip(centre_x + half_w, 0, grid_w)
    bottom = np.clip(centre_y + half_h, 0, grid_h)
    return left, top, right, bottom


class PVT2FFN(nn.Module):
    """Feed-forward block: Linear -> ``DWConv`` -> GELU -> Linear.

    Args:
        in_features: Input and output channel width.
        hidden_features: Width of the hidden layer.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module according to its type (used with ``apply``)."""
        if isinstance(m, nn.Conv2d):
            # Normal init with std derived from the per-group fan-out.
            kernel_h, kernel_w = m.kernel_size
            fan_out = (kernel_h * kernel_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, H, W):
        """Apply the block to ``x``; ``H`` and ``W`` are forwarded to ``DWConv``."""
        hidden = self.fc1(x)
        hidden = self.dwconv(hidden, H, W)
        return self.fc2(self.act(hidden))


class FFN(nn.Module):
    """Plain two-layer MLP: Linear -> GELU -> Linear.

    Args:
        in_features: Input and output channel width.
        hidden_features: Width of the hidden layer.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module according to its type (used with ``apply``)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            # Normal init with std derived from the per-group fan-out.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
            std = math.sqrt(2.0 / fan_out)
            m.weight.data.normal_(0, std)
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply the MLP to ``x``; ``H`` and ``W`` are accepted but not used."""
        return self.fc2(self.act(self.fc1(x)))


class ClassBlock(nn.Module):
    """Residual block that updates only the first (class) token of a sequence.

    Args:
        dim: Channel width.
        mlp_ratio: Hidden-width multiplier for the ``FFN`` channel mixer.
        norm_layer: Factory for the normalisation applied before the channel mixer.
        cm_type: ``'EinFFT'`` selects ``EinFFT``; any other value selects ``FFN``.
        layer_idx: Forwarded to ``MambaLayer``.
        if_constraint: Forwarded to ``MambaLayer``.
    """

    def __init__(self, dim, mlp_ratio, norm_layer=nn.LayerNorm, cm_type='mlp',
                 layer_idx=None, if_constraint=False):
        super().__init__()
        # No norm1 here: MambaLayer applies its own LayerNorm.
        self.norm2 = norm_layer(dim)
        self.attn = MambaLayer(dim, layer_idx=layer_idx, if_constraint=if_constraint)
        use_einfft = cm_type == 'EinFFT'
        self.mlp = EinFFT(dim) if use_einfft else FFN(dim, int(dim * mlp_ratio))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module according to its type (used with ``apply``)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init with std derived from the per-group fan-out.
            kernel_area = m.kernel_size[0] * m.kernel_size[1]
            fan_out = (kernel_area * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Update token 0 of ``x`` and return it concatenated with the untouched rest."""
        cls_token = x[:, :1]
        remaining = x[:, 1:]
        # The token mixer only sees the class token itself.
        cls_token = cls_token + self.attn(cls_token)
        cls_token = cls_token + self.mlp(self.norm2(cls_token), H, W)
        return torch.cat([cls_token, remaining], dim=1)


class Block_mamba(nn.Module):
    """Residual block: Mamba token mixer followed by a channel mixer.

    Args:
        dim: Channel width.
        mlp_ratio: Hidden-width multiplier for the ``PVT2FFN`` channel mixer.
        drop_path: Stochastic-depth rate; ``DropPath`` is used only when > 0.
        norm_layer: Factory for the normalisation applied before the channel mixer.
        sr_ratio: Accepted but not used by this block.
        cm_type: ``'EinFFT'`` selects ``EinFFT``; any other value selects ``PVT2FFN``.
        layer_idx: Forwarded to ``MambaLayer``.
        if_constraint: Forwarded to ``MambaLayer``.
    """

    def __init__(self,
                 dim,
                 mlp_ratio,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1,
                 cm_type='mlp',
                 layer_idx=None,
                 if_constraint=False,
                 ):
        super().__init__()
        # No norm1 here: MambaLayer applies its own LayerNorm.
        self.norm2 = norm_layer(dim)
        self.attn = MambaLayer(dim, layer_idx=layer_idx, if_constraint=if_constraint)
        if cm_type == 'EinFFT':
            channel_mixer = EinFFT(dim)
        else:
            channel_mixer = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio))
        self.mlp = channel_mixer
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module according to its type (used with ``apply``)."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            # Normal init with std derived from the per-group fan-out.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply both residual branches; ``H`` and ``W`` go to the channel mixer."""
        token_mixed = self.attn(x)
        x = x + self.drop_path(token_mixed)
        channel_mixed = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(channel_mixed)


class DownSamples(nn.Module):
    """Stride-2 3x3 convolution that turns a feature map into a token sequence.

    Args:
        in_channels: Channels of the incoming ``(B, C, H, W)`` feature map.
        out_channels: Channels of the produced tokens.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module according to its type (used with ``apply``)."""
        if isinstance(m, nn.Conv2d):
            # Normal init with std derived from the per-group fan-out.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out = fan_out // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(tokens, H, W)`` where tokens has shape ``(B, H * W, out_channels)``."""
        feature_map = self.proj(x)
        H, W = feature_map.shape[2], feature_map.shape[3]
        # (B, C, H, W) -> (B, H*W, C) so LayerNorm acts on the channel axis.
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.norm(tokens), H, W


class Stem(nn.Module):
    """Convolutional stem that maps an image to a token sequence.

    Three Conv-BatchNorm-ReLU stages (the first one has stride 2) are followed
    by a stride-2 3x3 projection and a LayerNorm over channels.

    Args:
        in_channels: Channels of the input image.
        stem_hidden_dim: Channel width of the three convolutional stages.
        out_channels: Channels of the produced tokens.
    """

    def __init__(self, in_channels, stem_hidden_dim, out_channels):
        super().__init__()
        hidden_dim = stem_hidden_dim
        # (input channels, kernel size, stride) per stage; padding is kernel // 2.
        stage_specs = (
            (in_channels, 7, 2),
            (hidden_dim, 3, 1),
            (hidden_dim, 3, 1),
        )
        layers = []
        for stage_in, kernel, stride in stage_specs:
            layers.append(nn.Conv2d(stage_in, hidden_dim, kernel_size=kernel, stride=stride,
                                    padding=kernel // 2, bias=False))
            layers.append(nn.BatchNorm2d(hidden_dim))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)
        self.proj = nn.Conv2d(hidden_dim, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = nn.LayerNorm(out_channels)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module according to its type (used with ``apply``)."""
        if isinstance(m, nn.Conv2d):
            # Normal init with std derived from the per-group fan-out.
            k_h, k_w = m.kernel_size
            fan_out = (k_h * k_w * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(tokens, H, W)`` where tokens has shape ``(B, H * W, out_channels)``."""
        feature_map = self.proj(self.conv(x))
        H, W = feature_map.shape[-2:]
        # (B, C, H, W) -> (B, H*W, C) so LayerNorm acts on the channel axis.
        tokens = feature_map.flatten(2).transpose(1, 2)
        return self.norm(tokens), H, W


class SiMBA(nn.Module):
    """Multi-stage SiMBA backbone with a class-token post network.

    Each stage consists of a patch embedding (``Stem`` for the first stage,
    ``DownSamples`` afterwards), a list of ``Block_mamba`` blocks and a norm
    layer. After the last stage a mean-pooled class token is prepended and
    refined by ``ClassBlock``; the class token is fed to the classification
    head.

    Args:
        in_chans: Number of input image channels.
        num_classes: Size of the classification heads; ``<= 0`` replaces them
            with ``nn.Identity``.
        stem_hidden_dim: Hidden width of the convolutional stem.
        embed_dims: Channel width of each stage.
        mlp_ratios: Channel-mixer hidden-width multiplier of each stage.
        drop_path_rate: Maximum stochastic-depth rate; rates grow linearly
            over all blocks from 0 to this value.
        norm_layer: Factory used for the per-stage norms and inside the blocks.
        depths: Number of blocks in each stage.
        sr_ratios: Forwarded to ``Block_mamba`` as ``sr_ratio``.
        num_stages: Number of stages to build.
        token_label: Enables the dense (token-labelling) forward path and the
            auxiliary head.
        constraint_layers: Global block indices that are built with
            ``if_constraint=True``. ``None`` selects every stage block.
        **kwargs: Accepted and ignored.
    """

    # NOTE: the list defaults below are kept for interface compatibility; they
    # are only read, never mutated, inside this class.
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 stem_hidden_dim=32,
                 embed_dims=[64, 128, 320, 448],
                 mlp_ratios=[8, 8, 4, 4],
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3],
                 sr_ratios=[4, 2, 1, 1],
                 num_stages=4,
                 token_label=True,
                 constraint_layers=None,
                 **kwargs
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        if constraint_layers is None:
            # Default: constrain every block of every stage.
            self.constraint_layers = list(range(sum(depths)))
        else:
            self.constraint_layers = constraint_layers
        print('model.py constraint_layers:', self.constraint_layers)

        # Stochastic depth decay rule: one rate per block, linearly increasing.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # Global index of the first block of the current stage. It indexes
        # both `dpr` and `constraint_layers`.
        depth_idx = 0
        for i in range(num_stages):
            if i == 0:
                patch_embed = Stem(in_chans, stem_hidden_dim, embed_dims[i])
            else:
                patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i])

            # Change cm_type to 'EinFFT' to run the EinFFT based channel mixer.
            block = nn.ModuleList([Block_mamba(
                dim=embed_dims[i],
                mlp_ratio=mlp_ratios[i],
                drop_path=dpr[depth_idx + j],
                norm_layer=norm_layer,
                sr_ratio=sr_ratios[i],
                cm_type='mlp',
                layer_idx=(depth_idx + j),
                if_constraint=(depth_idx + j) in self.constraint_layers,
            ) for j in range(depths[i])])

            norm = norm_layer(embed_dims[i])
            depth_idx += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        # Change cm_type to 'EinFFT' to run the EinFFT based channel mixer.
        post_layers = ['ca']
        self.post_network = nn.ModuleList([
            ClassBlock(
                dim=embed_dims[-1],
                mlp_ratio=mlp_ratios[-1],
                norm_layer=norm_layer,
                cm_type='mlp',
                layer_idx=depth_idx,
                if_constraint=depth_idx in self.constraint_layers,
            )
            for _ in range(len(post_layers))
        ])

        # classification head
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        # ---- token labelling ----
        self.return_dense = token_label
        self.mix_token = token_label
        self.beta = 1.0
        self.pooling_scale = 8
        if self.return_dense:
            self.aux_head = nn.Linear(
                embed_dims[-1],
                num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module according to its type (used with ``apply``)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init with std derived from the per-group fan-out.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def _run_stages(self, x, H, W, first_stage_embedded):
        """Run all stages and return ``(tokens, H, W)`` of the last stage.

        When ``first_stage_embedded`` is true, ``x`` is already the token
        sequence produced by ``patch_embed1`` and that embedding is skipped.
        """
        B = x.shape[0]
        last_stage = self.num_stages - 1
        for i in range(self.num_stages):
            if i != 0 or not first_stage_embedded:
                x, H, W = getattr(self, f"patch_embed{i + 1}")(x)
            for blk in getattr(self, f"block{i + 1}"):
                x = blk(x, H, W)
            if i != last_stage:
                x = getattr(self, f"norm{i + 1}")(x)
                # Back to (B, C, H, W) for the next convolutional embedding.
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        return x, H, W

    def forward_cls(self, x, H, W):
        """Prepend a mean-pooled class token and run the post network.

        Maps ``(B, N, C)`` tokens to ``(B, N + 1, C)``.
        """
        cls_tokens = x.mean(dim=1, keepdim=True)
        x = torch.cat((cls_tokens, x), dim=1)
        for block in self.post_network:
            x = block(x, H, W)
        return x

    def forward_features(self, x):
        """Return the normalised class-token feature for an image batch."""
        x, H, W = self._run_stages(x, None, None, first_stage_embedded=False)
        x = self.forward_cls(x, H, W)[:, 0]
        return getattr(self, f"norm{self.num_stages}")(x)

    def forward(self, x):
        """Classify an image batch.

        Without token labelling this is ``head(forward_features(x))``. With
        token labelling, training returns the class-token logits and
        evaluation returns the class-token logits plus half of the per-class
        maximum of the auxiliary logits over all other tokens.
        """
        if not self.return_dense:
            return self.head(self.forward_features(x))

        x, H, W = self.forward_embeddings(x)
        # Mix token, see token labeling for details.
        # NOTE(review): tokens are swapped between samples here, but the box
        # is no longer returned to the caller (the reverse-mix / dense return
        # path was disabled). Confirm that the training loss is meant to use
        # unmixed labels.
        if self.mix_token and self.training:
            lam = np.random.beta(self.beta, self.beta)
            bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
            # Scale the box from the pooled grid back to token coordinates.
            sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1
            sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2
            mixed = x.clone()
            mixed[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
            x = mixed

        x = self.forward_tokens(x, H, W)
        x_cls = self.head(x[:, 0])

        if self.training:
            # The auxiliary logits are not part of the training output, so
            # they are not computed here.
            return x_cls

        # Generate classes in all feature tokens, see token labeling.
        x_aux = self.aux_head(x[:, 1:])
        return x_cls + 0.5 * x_aux.max(1)[0]

    def forward_tokens(self, x, H, W):
        """Run stages, class block and final norm on embedded tokens.

        Args:
            x: Output of ``forward_embeddings``, shape ``(B, H, W, C)``.
            H: Token-grid height.
            W: Token-grid width.

        Returns:
            Normalised tokens with the class token at index 0 of axis 1.
        """
        x = x.view(x.shape[0], -1, x.size(-1))
        x, H, W = self._run_stages(x, H, W, first_stage_embedded=True)
        x = self.forward_cls(x, H, W)
        return getattr(self, f"norm{self.num_stages}")(x)

    def forward_embeddings(self, x):
        """Apply the first patch embedding and return ``((B, H, W, C) tokens, H, W)``."""
        x, H, W = self.patch_embed1(x)
        x = x.view(x.size(0), H, W, -1)
        return x, H, W


class DWConv(nn.Module):
    """Depth-wise 3x3 convolution applied to a token sequence.

    Args:
        dim: Number of channels (also the number of convolution groups).
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Convolve ``(B, N, C)`` tokens as a ``(B, C, H, W)`` map; ``N`` must equal ``H * W``."""
        batch, _num_tokens, channels = x.shape
        feature_map = x.transpose(1, 2).view(batch, channels, H, W)
        feature_map = self.dwconv(feature_map)
        # Back to the (B, N, C) token layout.
        return feature_map.flatten(2).transpose(1, 2)


@register_model
def simba_s_new(num_classes=1000, constraint_layers=None, pretrained=False, **kwargs):
    """Build the compact SiMBA variant (widths 32/32/32/64, depths 2/3/3/2).

    ``pretrained`` is accepted but not used; no weights are loaded here.
    Remaining keyword arguments are forwarded to ``SiMBA``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = SiMBA(
        stem_hidden_dim=16,
        embed_dims=[32, 32, 32, 64],
        mlp_ratios=[4, 4, 2, 2],
        norm_layer=layer_norm,
        depths=[2, 3, 3, 2],
        sr_ratios=[4, 2, 1, 1],
        num_classes=num_classes,
        constraint_layers=constraint_layers,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model


@register_model
def simba_s(pretrained=False, **kwargs):
    """Build SiMBA-S (widths 64/128/320/448, depths 3/4/6/3).

    ``pretrained`` is accepted but not used; no weights are loaded here.
    Remaining keyword arguments are forwarded to ``SiMBA``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = SiMBA(
        stem_hidden_dim=32,
        embed_dims=[64, 128, 320, 448],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=layer_norm,
        depths=[3, 4, 6, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model


@register_model
def simba_b(pretrained=False, **kwargs):
    """Build SiMBA-B (widths 64/128/320/512, depths 3/4/12/3).

    ``pretrained`` is accepted but not used; no weights are loaded here.
    Remaining keyword arguments are forwarded to ``SiMBA``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = SiMBA(
        stem_hidden_dim=64,
        embed_dims=[64, 128, 320, 512],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=layer_norm,
        depths=[3, 4, 12, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model


@register_model
def simba_l(pretrained=False, **kwargs):
    """Build SiMBA-L (widths 96/192/384/512, depths 3/6/18/3).

    ``pretrained`` is accepted but not used; no weights are loaded here.
    Remaining keyword arguments are forwarded to ``SiMBA``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = SiMBA(
        stem_hidden_dim=64,
        embed_dims=[96, 192, 384, 512],
        mlp_ratios=[8, 8, 4, 4],
        norm_layer=layer_norm,
        depths=[3, 6, 18, 3],
        sr_ratios=[4, 2, 1, 1],
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
