import time
import math
from functools import partial
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from collections import OrderedDict

from timm.models.registry import register_model

from perturb_style.ALOFT import ALOFT, ALOFT_image
from perturb_style.MixStyle import MixStyle
from perturb_style.DSU import DSU
from perturb_style.dropout_token_channel import dropout_token_channel

# from perturb_style.DWT import DWT
from perturb_style.SeqTokenAug import SeqTokenAug

from perturb_style.AttentionMask import AttentionMask

# import pywt

# Optional CUDA selective-scan kernels. Either backend may be missing; the
# names below are then simply left undefined and the corresponding
# SS2D.forward_core* path fails with NameError when it is first used.
# `except Exception` (rather than a bare `except:`) keeps the best-effort
# behaviour for import/load failures while no longer swallowing
# KeyboardInterrupt / SystemExit raised during a slow import.
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except Exception:
    pass

# an alternative for mamba_ssm (in which causal_conv1d is needed)
try:
    from selective_scan import selective_scan_fn as selective_scan_fn_v1
    from selective_scan import selective_scan_ref as selective_scan_ref_v1
except Exception:
    pass

# Monkey-patch timm's DropPath so that its repr shows the configured drop probability.
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"


def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    Estimate the multiply-accumulate count of the reference selective scan.

    Operand shapes assumed by the estimate:
        u, delta, z: r(B D L)
        A: r(D N)
        B, C: r(B N L)
        D, delta_bias: r(D)

    Not counted:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]

    Einsum costs are read back from the textual report of ``np.einsum_path``
    and halved (one multiply-add counts as one flop), so they carry the
    rounding of that report.

    Args:
        B, L, D, N: batch size, sequence length, channel count, state size.
        with_D: add the ``u * D`` skip term.
        with_Z: add the ``out * silu(z)`` gating term.
        with_Group: B/C are shared across channels (``bnl``) instead of per-channel (``bdnl``).
        with_complex: must be falsy; complex A is not supported.

    Returns:
        The estimated flop count (a float).
    """
    import numpy as np

    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        operands = [np.zeros(shape) for shape in input_shapes]
        report = np.einsum_path(equation, *operands, optimize="optimal")[1]
        for row in report.split("\n"):
            if "optimized flop" not in row.lower():
                continue
            # divided by 2 because we count MAC (multiply-add counted as one flop)
            return float(np.floor(float(row.split(":")[-1]) / 2))

    assert not with_complex

    # Input casting, softplus and state allocation in the reference
    # implementation are treated as free.
    flops = 0

    # deltaA = exp(einsum('bdl,dn->bdln', delta, A))
    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")

    # deltaB_u = einsum(delta, B, u), with B either grouped or per-channel
    if with_Group:
        shapes, equation = [[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln"
    else:
        shapes, equation = [[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln"
    flops += get_flops_einsum(shapes, equation)

    # One recurrence step: x = deltaA[:, :, i] * x + deltaB_u[:, :, i], then y = <x, C_i>
    per_step = B * D * N
    if with_Group:
        per_step += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        per_step += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * per_step

    # out = y + u * D, and out = out * silu(z): one elementwise pass each
    for enabled in (with_D, with_Z):
        if enabled:
            flops += B * D * L

    return flops


class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding (channels-last output).

    A strided convolution whose kernel and stride both equal the patch size
    projects each patch to ``embed_dim`` channels; the result is permuted
    from (B, C, H', W') to (B, H', W', C) and optionally normalized.

    Args:
        patch_size (int | tuple[int]): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
        **kwargs: Accepted and ignored.
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        window = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=window, stride=window)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        # (B, C, H, W) -> (B, H', W', embed_dim)
        tokens = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is None:
            return tokens
        return self.norm(tokens)


class PatchMerging2D(nn.Module):
    r""" Patch Merging Layer (2x spatial downsample on channels-last input).

    Each 2x2 neighbourhood is concatenated along the channel axis (C -> 4C),
    layer-normalized and linearly reduced to 2C.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape
        half_h, half_w = H // 2, W // 2

        odd_input = (W % 2 != 0) or (H % 2 != 0)
        if odd_input:
            print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)

        # (row offset, column offset) of the four interleaved sub-grids,
        # in the order they are concatenated along the channel axis.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        parts = [x[:, r::2, c::2, :] for r, c in offsets]  # each B H/2 W/2 C

        # With an odd H or W the sub-grids differ in size; crop them all to
        # the common floor(H/2) x floor(W/2) window.
        if odd_input and half_h > 0:
            parts = [p[:, :half_h, :half_w, :] for p in parts]

        merged = torch.cat(parts, -1).view(B, half_h, half_w, 4 * C)  # B H/2 W/2 4*C
        return self.reduction(self.norm(merged))

class SS2D(nn.Module):
    def __init__(
            self,
            d_model,
            d_state=16,
            # d_state="auto", # 20240109
            d_conv=3,
            expand=2,
            dt_rank="auto",
            dt_min=0.001,
            dt_max=0.1,
            dt_init="random",
            dt_scale=1.0,
            dt_init_floor=1e-4,
            dropout=0.,
            conv_bias=True,
            bias=False,
            device=None,
            dtype=None,

            spatial_aug_flag=0,
            SeqTokenAug_flag=0,
            SeqTokenAug_p=0.5,
            SeqTokenAug_token_prob=0.35,
            SeqTokenAug_batch_prob=0.5,
            SeqTokenAug_token_attention_flag=0,

            shuffle_scan_flag=0,
            slant_scan_flag=0,
            slant_scan_K=2,
            **kwargs,
    ):
        """Build the 2D selective-scan mixer.

        Args:
            d_model: Channel count of the block input/output.
            d_state: State size N of each scan.
            d_conv: Kernel size of the depthwise convolution applied before the scan.
            expand: Inner width multiplier; ``d_inner = int(expand * d_model)``.
            dt_rank: Rank of the dt projection, or ``"auto"`` for ``ceil(d_model / 16)``.
            dt_min, dt_max, dt_init, dt_scale, dt_init_floor: Forwarded to :meth:`dt_init`.
            dropout: Output dropout probability; no dropout module is created when ``<= 0``.
            conv_bias: Whether the depthwise convolution has a bias.
            bias: Whether ``in_proj`` / ``out_proj`` have a bias.
            device, dtype: Factory kwargs for the linear/conv layers.
            spatial_aug_flag: 0 = none, 1 = MixStyle, 2 = DSU, 3 = ALOFT_image.
            SeqTokenAug_flag: Non-zero creates a ``SeqTokenAug`` module; a value of 1
                additionally builds the spatial augmentation with ``p=1.0``.
            SeqTokenAug_p, SeqTokenAug_token_prob, SeqTokenAug_batch_prob,
            SeqTokenAug_token_attention_flag: Forwarded to ``SeqTokenAug``.
            shuffle_scan_flag: 1 allocates one extra projection set (``*_extra`` parameters).
            slant_scan_flag: 1 allocates extra projection sets according to ``slant_scan_K``.
            slant_scan_K: 2 or 4 extra sets when ``slant_scan_flag == 1``; any other
                value allocates none.
            **kwargs: Accepted and ignored.

        NOTE(review): ``forward_corev0`` calls ``self.spatial_aug`` whenever
        ``SeqTokenAug_flag != 0`` in training mode, but that attribute only exists
        for ``spatial_aug_flag`` in {1, 2, 3}.
        NOTE(review): the ``*_extra`` parameters are allocated here but the
        forward cores in this file only read the four base projections — confirm
        before enabling the shuffle/slant flags.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model  # 96
        self.d_state = d_state
        # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)  # 96 * 2
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        # These were previously read from `self` without ever being assigned,
        # which made every construction fail with AttributeError.
        self.shuffle_scan_flag = shuffle_scan_flag
        self.slant_scan_flag = slant_scan_flag
        self.slant_scan_K = slant_scan_K

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        def new_x_proj():
            # One per-direction projection producing (dt, B, C) for every token.
            return nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs)

        def new_dt_proj():
            # One per-direction dt projection with the special bias initialization.
            return self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                                **factory_kwargs)

        # Creation order (all x_proj first, then all dt_projs) is kept so the
        # random initialization matches a given seed.
        K_all = 4
        self.x_proj = tuple(new_x_proj() for _ in range(4))
        self.dt_projs = tuple(new_dt_proj() for _ in range(4))

        self.x_proj_extra = ()
        self.dt_projs_extra = ()
        if self.shuffle_scan_flag == 1:
            self.x_proj_extra += (new_x_proj(),)
            self.dt_projs_extra += (new_dt_proj(),)
            K_all += 1

        self.SeqTokenAug_flag = SeqTokenAug_flag
        self.spatial_aug_flag = spatial_aug_flag
        # With SeqTokenAug enabled (flag == 1) the spatial augmentation is built with p=1.0.
        always_apply = self.SeqTokenAug_flag == 1
        if self.spatial_aug_flag == 1:
            # MixStyle
            self.spatial_aug = MixStyle(p=1.0) if always_apply else MixStyle()
        elif self.spatial_aug_flag == 2:
            # DSU
            self.spatial_aug = DSU(p=1.0) if always_apply else DSU()
        elif self.spatial_aug_flag == 3:
            # ALOFT
            self.spatial_aug = ALOFT_image(p=1.0 if always_apply else 0.5, mask_size=0.5, factor=0.8,
                                           mask_or_model=1)

        self.SeqTokenAug_token_attention_flag = SeqTokenAug_token_attention_flag
        if self.SeqTokenAug_flag != 0:
            self.SeqTokenAug = SeqTokenAug(
                p=SeqTokenAug_p,
                aug_token_prob=SeqTokenAug_token_prob,
                batch_prob=SeqTokenAug_batch_prob,
                token_attention_flag=SeqTokenAug_token_attention_flag,
            )

        if self.slant_scan_flag == 1:
            slant_linears = (new_x_proj(), new_x_proj())
            slant_dts = (new_dt_proj(), new_dt_proj())
            if self.slant_scan_K == 2:
                self.x_proj_extra += slant_linears
                self.dt_projs_extra += slant_dts
                K_all += 2
            elif self.slant_scan_K == 4:
                # The same two layers are listed twice, so the four stacked
                # slant projections start as two identical pairs.
                self.x_proj_extra += slant_linears + slant_linears
                self.dt_projs_extra += slant_dts + slant_dts
                K_all += 4

        self.K_all = K_all

        # Fuse the per-direction layers into stacked parameters and drop the
        # temporary Linear modules.
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, N, inner)
        del self.x_proj

        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
        del self.dt_projs

        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K=4 * D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K=4 * D)

        if self.K_all != 4:
            n_extra = K_all - 4
            self.x_proj_extra_weight = nn.Parameter(
                torch.stack([t.weight for t in self.x_proj_extra], dim=0))  # (K_extra, N, inner)
            del self.x_proj_extra

            self.dt_projs_extra_weight = nn.Parameter(
                torch.stack([t.weight for t in self.dt_projs_extra], dim=0))  # (K_extra, inner, rank)
            self.dt_projs_extra_bias = nn.Parameter(
                torch.stack([t.bias for t in self.dt_projs_extra], dim=0))  # (K_extra, inner)
            del self.dt_projs_extra

            self.A_logs_extra = self.A_log_init(self.d_state, self.d_inner, copies=n_extra, merge=True)
            self.Ds_extra = self.D_init(self.d_inner, copies=n_extra, merge=True)

        # self.selective_scan = selective_scan_fn
        self.forward_core = self.forward_corev0

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        """Return ``log(A)`` as a parameter, with every row equal to ``log([1, ..., d_state])``.

        The base tensor is (d_inner, d_state). With ``copies > 1`` it is
        repeated along a new leading axis, which is folded into the row axis
        when ``merge`` is true. The parameter is tagged ``_no_weight_decay``.
        """
        # S4D real initialization
        state_index = torch.arange(1, d_state + 1, dtype=torch.float32, device=device)
        A = repeat(state_index, "n -> d n", d=d_inner).contiguous()
        log_A = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            log_A = repeat(log_A, "d n -> r d n", r=copies)
            log_A = log_A.flatten(0, 1) if merge else log_A
        param = nn.Parameter(log_A)
        param._no_weight_decay = True
        return param

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        """Return the all-ones "skip" parameter D of length ``d_inner``.

        With ``copies > 1`` the vector is repeated along a new leading axis,
        which is flattened away when ``merge`` is true. The parameter is
        tagged ``_no_weight_decay``.
        """
        skip = torch.ones(d_inner, device=device)
        if copies > 1:
            skip = repeat(skip, "n1 -> r n1", r=copies)
            skip = skip.flatten(0, 1) if merge else skip
        param = nn.Parameter(skip)  # Keep in fp32
        param._no_weight_decay = True
        return param

    @staticmethod
    def diagonal_gather(x):
        B, C, H, W = x.size()
        # get the elements of antidiagonal line
        shift = torch.arange(H, device=x.device).unsqueeze(1)  # [[0], [1], ..., [H-1]]]
        index = (torch.arange(W, device=x.device) + shift) % W   # [[0, 1, ..., W-1], [1, 2, ..., 0], [2, 3, ..., 1]]
        expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
        return x.gather(3, expanded_index).transpose(-1, -2).reshape(B, C, H * W)

    @staticmethod
    def diagonal_scatter(tensor_flat, original_shape):
        # tensor_flat: B, C, L
        # recover the antidiagonal line
        B, C, H, W = original_shape
        shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # [[0], [1], ..., [H-1]]]
        index = (torch.arange(W, device=tensor_flat.device) + shift) % W  # [[0, 1, ..., W-1], [1, 2, ..., 0], [2, 3, ..., 1]]
        expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
        result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
        tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
        result_tensor.scatter_(3, expanded_index, tensor_reshaped)
        return result_tensor

    @staticmethod
    def antidiagonal_gather(x):
        B, C, H, W = x.size()
        # get the elements of antidiagonal line
        shift = torch.arange(H, device=x.device).unsqueeze(1)  # [[0], [1], ..., [H-1]]]
        index = (torch.arange(W,
                              device=x.device) - shift) % W  # [[0, 1, ..., W-1], [W-1, 0, 1, ...], [W-2, W-1, 0, 1, ...]]
        expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
        return x.gather(3, expanded_index).transpose(-1, -2).reshape(B, C, H * W)

    @staticmethod
    def antidiagonal_scatter(tensor_flat, original_shape):
        # tensor_flat: B, C, L
        # recover the antidiagonal line
        B, C, H, W = original_shape
        shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # [[0], [1], ..., [H-1]]]
        index = (torch.arange(W, device=tensor_flat.device) - shift) % W  # [[0, 1, ..., W-1], [W-1, 0, 1, ...], [W-2, W-1, 0, 1, ...]]
        expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)

        result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
        tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
        result_tensor.scatter_(3, expanded_index, tensor_reshaped)
        return result_tensor

    def forward_corev0(self, x: torch.Tensor):
        """Run the selective scan over four traversals of the feature map.

        The (B, C, H, W) input is flattened row-major and column-major, and
        both sequences are also reversed, giving four directions. During
        training, and when ``SeqTokenAug_flag != 0``, the sequences are first
        mixed with a spatially augmented copy through ``self.SeqTokenAug``.

        Args:
            x: tensor unpacked as (B, C, H, W).

        Returns:
            A list of four float32 tensors, one per direction, each re-ordered
            back to row-major token order.

        NOTE(review): ``K`` is taken from ``self.K_all`` but only four
        directions are built and only the four base projections
        (``x_proj_weight`` / ``dt_projs_weight`` / ``dt_projs_bias``) are used,
        so this looks valid only for ``K_all == 4`` — confirm before enabling
        the extra scans.
        """
        # `selective_scan_fn` comes from the optional mamba_ssm import at module
        # top; if that import failed this line raises NameError.
        self.selective_scan = selective_scan_fn

        B, C, H, W = x.shape
        L = H * W
        K = self.K_all

        # BxCxHW, BxCxWH -> Bx2xCxL
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        # Bx2xCxL, Bx2xCxL -> Bx4xCxL
        # Direction order: row-major, column-major, reversed row-major, reversed column-major.
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_proj_weight_all = self.x_proj_weight
        dt_projs_weight_all = self.dt_projs_weight
        dt_projs_bias_all = self.dt_projs_bias

        # Training-only token augmentation. `self.spatial_aug` is only created
        # in __init__ when spatial_aug_flag is 1, 2 or 3.
        if self.SeqTokenAug_flag != 0 and self.training:
            if self.SeqTokenAug_token_attention_flag == 0:
                # Mix original and augmented sequences without a guidance tensor.
                xs_aug = self.spatial_aug(xs)
                xs = self.SeqTokenAug(xs, xs_aug)
            elif self.SeqTokenAug_token_attention_flag == 1:
                # Pass the raw sequences themselves as `Bx`.
                xs_aug = self.spatial_aug(xs)
                xs = self.SeqTokenAug(xs, xs_aug, Bx=xs)
            else:
                # Pass a statistic derived from the un-augmented scan inputs as `Bx`
                # (presumably a per-token importance signal — confirm against SeqTokenAug).
                x_dbl_temp = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), x_proj_weight_all)
                # _, Bs, _ = torch.split(x_dbl_temp, [self.dt_rank, self.d_state, self.d_state], dim=2)
                dts, Bs, Cs = torch.split(x_dbl_temp, [self.dt_rank, self.d_state, self.d_state], dim=2)
                Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
                Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
                # Note: no bias or softplus is applied to dts on this path.
                dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), dt_projs_weight_all)

                xs_aug = self.spatial_aug(xs)
                # xs = self.SeqTokenAug(xs, xs_aug, Bx=Bs)
                # Elementwise product C * dt * B * x per (channel, token, state), averaged over the state axis.
                deltaB_u = torch.einsum('bknl,bkdl,bknl,bkdl->bkdln', Cs, dts, Bs, xs)
                xs = self.SeqTokenAug(xs, xs_aug, Bx=deltaB_u.mean(dim=-1))
                del deltaB_u

                del x_dbl_temp, dts, Bs, Cs

        # print(xs.shape, x_proj_weight_all.shape)

        # Per-direction projection of every token to (dt, B, C), then low-rank dt -> d_inner.
        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), x_proj_weight_all)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), dt_projs_weight_all)
        # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)

        # The scan is called with float32 operands; directions are folded into the channel axis.
        xs = xs.float().view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)

        if self.K_all != 4:
            Ds = torch.cat([self.Ds, self.Ds_extra], dim=0).float().view(-1)  # (k * d)
            As = -torch.exp(torch.cat([self.A_logs, self.A_logs_extra], dim=0).float()).view(-1, self.d_state)
        else:
            Ds = self.Ds.float().view(-1)  # (k * d)
            As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)

        dt_projs_bias = dt_projs_bias_all.float().view(-1)  # (k * d)

        # The dt bias is handed to the scan via delta_bias together with delta_softplus=True.
        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        # Undo the traversal orders: un-flip directions 2 and 3, and transpose
        # the column-major results (direction 1 and un-flipped direction 3) back to row-major.
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return_y = [out_y[:, 0], inv_y[:, 0], wh_y, invwh_y]

        return return_y


    # an alternative to forward_corev0 (uses selective_scan_fn_v1 from the standalone selective_scan package)
    def forward_corev1(self, x: torch.Tensor):
        """Four-direction selective scan using ``selective_scan_fn_v1``.

        The (B, C, H, W) input is flattened row-major and column-major, both
        sequences are also reversed, and all four are scanned in one kernel
        call with the four base projection sets.

        Returns:
            A 4-tuple of float32 tensors (row-major, reversed row-major,
            column-major, reversed column-major), each re-ordered back to
            row-major token order.
        """
        self.selective_scan = selective_scan_fn_v1

        B, C, H, W = x.shape
        L = H * W
        K = 4

        row_major = x.view(B, -1, L)
        col_major = x.transpose(2, 3).contiguous().view(B, -1, L)
        x_hwwh = torch.stack([row_major, col_major], dim=1).view(B, 2, -1, L)
        reversed_hwwh = x_hwwh.flip(dims=[-1])
        xs = torch.cat([x_hwwh, reversed_hwwh], dim=1)  # (b, k, d, l)

        # Per-direction projection of every token to (dt, B, C), then low-rank dt -> d_inner.
        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        # All operands are passed as float32; directions are folded into the channel axis.
        out_y = self.selective_scan(
            xs.float().view(B, -1, L),                               # (b, k * d, l)
            dts.contiguous().float().view(B, -1, L),                 # (b, k * d, l)
            -torch.exp(self.A_logs.float()).view(-1, self.d_state),  # (k * d, d_state)
            Bs.float().view(B, K, -1, L),                            # (b, k, d_state, l)
            Cs.float().view(B, K, -1, L),                            # (b, k, d_state, l)
            self.Ds.float().view(-1),                                # (k * d)
            delta_bias=self.dt_projs_bias.float().view(-1),          # (k * d)
            delta_softplus=True,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        # Un-flip the reversed directions, then transpose column-major results back to row-major.
        inv_y = out_y[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        wh_y = out_y[:, 1].view(B, -1, W, H).transpose(2, 3).contiguous().view(B, -1, L)
        invwh_y = inv_y[:, 1].view(B, -1, W, H).transpose(2, 3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor, **kwargs):
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))  # (b, d, h, w)

        y_return = self.forward_core(x)

        assert y_return[0].dtype == torch.float32
        y = torch.sum(torch.stack(y_return, dim=0), dim=0, keepdim=False)   # BxCxL

        # y1, y2, y3, y4 = self.forward_core(x)
        # assert y1.dtype == torch.float32
        # y = y1 + y2 + y3 + y4

        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)

        return out


class VSSBlock(nn.Module):
    """Pre-norm residual block: ``x + DropPath(SS2D(norm(x)))``.

    Args:
        hidden_dim: Channel count; used for the norm and as ``d_model`` of SS2D.
        drop_path: Probability handed to ``DropPath``.
        norm_layer: Callable building the normalization layer from ``hidden_dim``.
        attn_drop_rate: Forwarded to SS2D as ``dropout``.
        d_state: Forwarded to SS2D.
        spatial_aug_flag, SeqTokenAug_*: Forwarded to SS2D unchanged.
        **kwargs: Forwarded to SS2D.
    """

    def __init__(
            self,
            hidden_dim: int = 0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            attn_drop_rate: float = 0,
            d_state: int = 16,

            spatial_aug_flag: int = 0,
            SeqTokenAug_flag=0,
            SeqTokenAug_p=0.5,
            SeqTokenAug_token_prob=0.35,
            SeqTokenAug_batch_prob=0.5,
            SeqTokenAug_token_attention_flag=0,
            **kwargs,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim)

        # Augmentation options are passed straight through to the mixer.
        aug_options = dict(
            spatial_aug_flag=spatial_aug_flag,
            SeqTokenAug_flag=SeqTokenAug_flag,
            SeqTokenAug_p=SeqTokenAug_p,
            SeqTokenAug_token_prob=SeqTokenAug_token_prob,
            SeqTokenAug_batch_prob=SeqTokenAug_batch_prob,
            SeqTokenAug_token_attention_flag=SeqTokenAug_token_attention_flag,
        )
        self.self_attention = SS2D(
            d_model=hidden_dim,
            dropout=attn_drop_rate,
            d_state=d_state,
            **aug_options,
            **kwargs,
        )
        self.drop_path = DropPath(drop_path)

    def forward(self, input: torch.Tensor):
        mixed = self.self_attention(self.ln_1(input))
        return input + self.drop_path(mixed)


class VSSLayer(nn.Module):
    """ One stage of the network: a stack of VSSBlocks plus an optional downsample.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        attn_drop (float, optional): Forwarded to every block as ``attn_drop_rate``. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate. A list or
            tuple is indexed per block; a scalar is shared by all blocks. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer,
            built as ``downsample(dim=dim, norm_layer=norm_layer)``. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        d_state (int): Forwarded to every block. Default: 16
        spatial_aug_flag, SeqTokenAug_*: Forwarded to every block unchanged.
        **kwargs: Accepted and ignored (not forwarded to the blocks).
    """

    def __init__(
            self,
            dim,
            depth,
            attn_drop=0.,
            drop_path=0.,
            norm_layer=nn.LayerNorm,
            downsample=None,
            use_checkpoint=False,
            d_state=16,

            spatial_aug_flag=0,
            SeqTokenAug_flag=0,
            SeqTokenAug_p=0.5,
            SeqTokenAug_token_prob=0.35,
            SeqTokenAug_batch_prob=0.5,
            SeqTokenAug_token_attention_flag=0,

            **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        # A tuple is documented as valid but used to be passed whole to every
        # block; treat it like a list of per-block rates.
        per_block_rates = isinstance(drop_path, (list, tuple))

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if per_block_rates else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,

                spatial_aug_flag=spatial_aug_flag,
                SeqTokenAug_flag=SeqTokenAug_flag,
                SeqTokenAug_p=SeqTokenAug_p,
                SeqTokenAug_token_prob=SeqTokenAug_token_prob,
                SeqTokenAug_batch_prob=SeqTokenAug_batch_prob,
                SeqTokenAug_token_attention_flag=SeqTokenAug_token_attention_flag,
            )
            for i in range(depth)])

        # is this really applied? Yes, but been overriden later in VSSM!
        def _init_weights(module: nn.Module):
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # fake init on a detached clone: the real weight is left
                    # untouched, only the random generator is advanced
                    # (just to keep the seed ....)
                    p = p.clone().detach_()
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))

        self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # Call the block through torch.utils.checkpoint.checkpoint.
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x


class Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` with a fixed dimension order.

    The positional arguments given at construction are stored in ``self.args``
    and unpacked into ``x.permute(...)`` on every call.
    """

    def __init__(self, *args):
        super().__init__()
        self.args = args

    def forward(self, x: torch.Tensor):
        dims = self.args
        return x.permute(*dims)

class VSSM(nn.Module):
    """Hierarchical classifier assembled from ``VSSLayer`` stages.

    The input goes through ``PatchEmbed2D``, a dropout, and one ``VSSLayer``
    per entry of ``depths``. Every stage except the last receives
    ``PatchMerging2D`` as its ``downsample``. ``forward`` then applies the
    final norm, global average pooling and the classification head.

    Args:
        patch_size: forwarded to ``PatchEmbed2D``.
        in_chans: forwarded to ``PatchEmbed2D``.
        num_classes: output size of the linear head; a value ``<= 0`` replaces
            the head with ``nn.Identity``.
        depths: number of blocks in each stage; its length sets the number of
            stages.
        dims: per-stage widths, or a single int taken as the first-stage width
            and doubled at every following stage.
        d_state: forwarded to every ``VSSLayer``; ``None`` selects
            ``ceil(dims[0] / 6)``.
        drop_rate: probability of the dropout after patch embedding; also
            passed to every stage as ``drop``.
        attn_drop_rate: passed to every stage as ``attn_drop``.
        drop_path_rate: upper end of the stochastic-depth schedule, which
            grows linearly from 0 over all blocks of all stages.
        norm_layer: constructor for the final norm; also forwarded to the
            stages and (if ``patch_norm``) to ``PatchEmbed2D``.
        patch_norm: when false, ``PatchEmbed2D`` gets ``norm_layer=None``.
        use_checkpoint: forwarded to every ``VSSLayer``.
        spatial_aug_flag, SeqTokenAug_flag, SeqTokenAug_p,
        SeqTokenAug_token_prob, SeqTokenAug_batch_prob,
        SeqTokenAug_token_attention_flag: forwarded unchanged to every
            ``VSSLayer``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2],
                 dims=[96, 192, 384, 768], d_state=16, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False,

                 spatial_aug_flag=0,
                 SeqTokenAug_flag=0,
                 SeqTokenAug_p=0.5,
                 SeqTokenAug_token_prob=0.35,
                 SeqTokenAug_batch_prob=0.5,
                 SeqTokenAug_token_attention_flag=0,

                 **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            # A scalar is the first-stage width; each later stage doubles it.
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        else:
            # Copy so that self.dims never aliases the shared default list.
            dims = list(dims)
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.dims = dims

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
                                        norm_layer=norm_layer if patch_norm else None)

        # Absolute position embedding is hard-wired off; the branch below only
        # runs if this flag is changed in code.
        self.ape = False
        if self.ape:
            self.patches_resolution = self.patch_embed.patches_resolution
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: one rate per block, rising linearly.
        dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        stage_d_state = math.ceil(dims[0] / 6) if d_state is None else d_state
        aug_kwargs = dict(
            spatial_aug_flag=spatial_aug_flag,
            SeqTokenAug_flag=SeqTokenAug_flag,
            SeqTokenAug_p=SeqTokenAug_p,
            SeqTokenAug_token_prob=SeqTokenAug_token_prob,
            SeqTokenAug_batch_prob=SeqTokenAug_batch_prob,
            SeqTokenAug_token_attention_flag=SeqTokenAug_token_attention_flag,
        )

        self.layers = nn.ModuleList()
        first_block = 0  # index into dpr of the current stage's first block
        for i_layer in range(self.num_layers):
            next_block = first_block + depths[i_layer]
            is_last_stage = i_layer == self.num_layers - 1
            self.layers.append(VSSLayer(
                dim=dims[i_layer],
                depth=depths[i_layer],
                d_state=stage_d_state,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[first_block:next_block],
                norm_layer=norm_layer,
                downsample=None if is_last_stage else PatchMerging2D,
                use_checkpoint=use_checkpoint,
                **aug_kwargs,
            ))
            first_block = next_block

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """Initialise one submodule; used as the callback of ``self.apply``.

        ``nn.Linear`` weights get ``trunc_normal_(std=.02)`` and zero bias;
        ``nn.LayerNorm`` gets unit weight and zero bias. Parameters that are
        ``None`` (e.g. a norm built without affine parameters) are skipped.
        Every other module type is left untouched.

        Notes carried over from the original author:
        - ``out_proj.weight``, previously initialised in ``VSSBlock``, is
          overwritten here because it belongs to an ``nn.Linear``, so that
          earlier initialisation has no lasting effect.
        - No ``fc.weight`` and no ``nn.Embedding`` exist among the model
          parameters.
        - Conv2d layers are not initialised here.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay."""
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the name keywords to exclude from weight decay."""
        return {'relative_position_bias_table'}

    def _forward_stem(self, x):
        """Patch-embed ``x``, add the position embedding if enabled, apply dropout."""
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        """Run the backbone and also collect the input fed to each stage.

        Returns:
            ``(out, skip_list)`` where ``out`` is the last stage's output and
            ``skip_list[i]`` is the tensor that entered stage ``i``.
        """
        skip_list = []
        x = self._forward_stem(x)

        for layer in self.layers:
            skip_list.append(x)
            x = layer(x)
        return x, skip_list

    def forward_backbone(self, x):
        """Run the stem and all stages; return the last stage's output."""
        x = self._forward_stem(x)

        for layer in self.layers:
            x = layer(x)
        return x

    def forward(self, x):
        """Return the head output for ``x``.

        The backbone output is normalised, its last dimension is moved to
        position 1 (the backbone output must therefore be 4-D with channels
        last), then it is average-pooled to 1x1, flattened and fed to the head.
        """
        x = self.forward_backbone(x)
        x = self.norm(x).permute(0, 3, 1, 2)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.head(x)
        return x


@register_model
def vmamba_tiny(pretrained=False, pretrained_path=None, **kwargs):
    """Build the tiny ``VSSM`` variant (depths [2, 2, 9, 2], drop-path 0.2).

    Args:
        pretrained: when true, load weights from ``pretrained_path``.
        pretrained_path: file readable by ``torch.load`` whose ``"model"``
            entry holds the state dict. Required when ``pretrained`` is true.
        **kwargs: forwarded to ``VSSM``.

    Returns:
        The constructed model, with weights loaded non-strictly if requested.

    Raises:
        ValueError: if ``pretrained`` is true but no ``pretrained_path`` is given.
    """
    if pretrained and pretrained_path is None:
        raise ValueError("pretrained=True requires pretrained_path")

    model = VSSM(
        patch_size=4, in_chans=3, depths=[2, 2, 9, 2], dims=[96, 192, 384, 768],
        d_state=16, drop_rate=0.,
        attn_drop_rate=0., drop_path_rate=0.2,
        norm_layer=nn.LayerNorm, patch_norm=True,
        use_checkpoint=False, **kwargs)
    # model.default_cfg = _cfg()
    if pretrained:
        # ``check_hash`` is an option of torch.hub.load_state_dict_from_url,
        # not of torch.load, which forwards unknown keywords to the unpickler
        # and fails with TypeError. It is therefore not passed here.
        # NOTE(security): torch.load unpickles data; only load trusted files.
        state = torch.load(pretrained_path, map_location="cpu")
        model.load_state_dict(state["model"], strict=False)
    return model


@register_model
def vmamba_small(pretrained=False, pretrained_path=None, **kwargs):
    """Build the small ``VSSM`` variant (depths [2, 2, 27, 2], drop-path 0.3).

    Args:
        pretrained: when true, load weights from ``pretrained_path``.
        pretrained_path: file readable by ``torch.load`` whose ``"model"``
            entry holds the state dict. Required when ``pretrained`` is true.
        **kwargs: forwarded to ``VSSM``.

    Returns:
        The constructed model, with weights loaded non-strictly if requested.

    Raises:
        ValueError: if ``pretrained`` is true but no ``pretrained_path`` is given.
    """
    if pretrained and pretrained_path is None:
        raise ValueError("pretrained=True requires pretrained_path")

    model = VSSM(
        patch_size=4, in_chans=3, depths=[2, 2, 27, 2], dims=[96, 192, 384, 768],
        d_state=16, drop_rate=0.,
        attn_drop_rate=0., drop_path_rate=0.3,
        norm_layer=nn.LayerNorm, patch_norm=True,
        use_checkpoint=False, **kwargs)
    # model.default_cfg = _cfg()
    if pretrained:
        # ``check_hash`` is an option of torch.hub.load_state_dict_from_url,
        # not of torch.load, which forwards unknown keywords to the unpickler
        # and fails with TypeError. It is therefore not passed here.
        # NOTE(security): torch.load unpickles data; only load trusted files.
        state = torch.load(pretrained_path, map_location="cpu")
        model.load_state_dict(state["model"], strict=False)
    return model


@register_model
def vmamba_base(pretrained=False, pretrained_path=None, **kwargs):
    """Build the base ``VSSM`` variant (depths [2, 2, 27, 2], drop-path 0.5).

    ``dims=128`` is passed as a single int, which ``VSSM`` treats as the
    first-stage width and doubles at each following stage.

    Args:
        pretrained: when true, load weights from ``pretrained_path``.
        pretrained_path: file readable by ``torch.load`` whose ``"model"``
            entry holds the state dict. Required when ``pretrained`` is true.
        **kwargs: forwarded to ``VSSM``.

    Returns:
        The constructed model, with weights loaded non-strictly if requested.

    Raises:
        ValueError: if ``pretrained`` is true but no ``pretrained_path`` is given.
    """
    if pretrained and pretrained_path is None:
        raise ValueError("pretrained=True requires pretrained_path")

    model = VSSM(
        patch_size=4, in_chans=3, depths=[2, 2, 27, 2], dims=128,
        d_state=16, drop_rate=0.,
        attn_drop_rate=0., drop_path_rate=0.5,
        norm_layer=nn.LayerNorm, patch_norm=True,
        use_checkpoint=False, **kwargs)
    # model.default_cfg = _cfg()
    if pretrained:
        # ``check_hash`` is an option of torch.hub.load_state_dict_from_url,
        # not of torch.load, which forwards unknown keywords to the unpickler
        # and fails with TypeError. It is therefore not passed here.
        # NOTE(security): torch.load unpickles data; only load trusted files.
        state = torch.load(pretrained_path, map_location="cpu")
        model.load_state_dict(state["model"], strict=False)
    return model
