import torch
import torch.nn as nn
from .utils import DropPath, PeriodicPad2d, Mlp, Linear
from .Attention import SD_attn, WindowAttention, HiLo, \
    Conv_attn, SD_attn_parallel, SPO_attn, Rentention_attn, MHA_withflash, \
        Vitattn_withflash, Swin_Attention_v2, Lora_Attention, Airsea_CrossAttention,\
        flash_SD_attn, Lora_Attention3d, Lora_Attention_withmask,SD_attn_withmoe

from .mlp import DWMlp, Lora_Mlp, GluMlp, Lora_GluMlp, Mlp_withmoe, Mlp_parallel



class Layer(nn.Module):
    """Sequential stack of ``depth`` blocks of a single type.

    Args:
        dim (int): Channel dimension passed to every block.
        depth (int): Number of blocks in the stack.
        window_size (list[int]): Passed as ``kernel_size`` to ``Convnet_block``
            and as ``window_size`` to ``Windowattn_block``.
        num_heads (int): Attention heads (window/swin blocks only).
        mlp_ratio (float): MLP hidden-size ratio (window/swin blocks only).
        qkv_bias (bool): Bias on the q/k/v projections (window/swin blocks only).
        drop (float): Dropout rate (window/swin blocks only).
        attn_drop (float): Attention dropout rate (window/swin blocks only).
        drop_path (float | list | tuple): Stochastic-depth rate. A list/tuple
            supplies one rate per block, indexed by block position.
        norm_layer: Normalization layer class passed to every block.
        layer_type (str): ``"convnet_block"``, ``"window_block"`` or
            ``"swin_block"``. ``"swin_block"`` builds ``Windowattn_block``s
            whose odd-indexed members use a shift of ``window_size // 2``.
        use_checkpoint (bool): Accepted but currently ignored (activation
            checkpointing is not wired up).
        pre_norm (bool): Passed to ``Windowattn_block`` (window/swin only).

    Raises:
        ValueError: If ``layer_type`` is not one of the supported values.
    """

    _LAYER_TYPES = ("convnet_block", "window_block", "swin_block")

    def __init__(self, dim, depth, window_size,
                 num_heads=1, mlp_ratio=4., qkv_bias=True,
                 drop=0., attn_drop=0., drop_path=0.,
                 norm_layer=nn.LayerNorm, layer_type="convnet_block",
                 use_checkpoint=False, pre_norm=True) -> None:
        super().__init__()
        # Validate up front: an unknown type would otherwise leave ``block``
        # unbound inside the loop and fail with an UnboundLocalError.
        if layer_type not in self._LAYER_TYPES:
            raise ValueError(
                f"unknown layer_type {layer_type!r}; expected one of {self._LAYER_TYPES}"
            )
        self.dim = dim
        self.depth = depth

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block_drop_path = drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path
            if layer_type == "convnet_block":
                block = Convnet_block(
                    dim=dim,
                    kernel_size=window_size,
                    drop_path=block_drop_path,
                    layer_scale_init_value=0,
                    norm_layer=norm_layer,
                )
            else:
                extra = {}
                if layer_type == "swin_block":
                    # Alternate unshifted and half-window-shifted blocks.
                    # NOTE(review): the unshifted value is always 2-D ([0, 0]);
                    # confirm this is intended when window_size has 3 entries.
                    extra["shift_size"] = (
                        [0, 0] if i % 2 == 0 else [w // 2 for w in window_size]
                    )
                block = Windowattn_block(
                    dim=dim,
                    window_size=window_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=block_drop_path,
                    norm_layer=norm_layer,
                    pre_norm=pre_norm,
                    **extra,
                )
            self.blocks.append(block)

    def forward(self, x):
        """Apply every block in order and return the final output."""
        for blk in self.blocks:
            x = blk(x)
        return x


class Convnet_block(nn.Module):
    r"""ConvNeXt-style block operating on channels-last input.

    Pipeline: periodic pad -> grouped ``Conv2d`` -> norm (channels last)
    -> Linear(dim, 4*dim) -> GELU -> Linear(4*dim, dim) -> optional layer
    scale -> stochastic depth -> residual add. The pointwise (1x1) convs are
    implemented as ``nn.Linear`` on channels-last data.

    Args:
        dim (int): Number of input channels (last axis of the input).
        kernel_size (sequence of int): Spatial conv kernel size. Default: (4, 8).
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; values
            <= 0 disable layer scale. Default: 1e-6.
        norm_layer: Normalization applied over the channel axis.
            Default: nn.LayerNorm
        groups (int): ``groups`` of the spatial conv. The default of 12 is the
            historical hard-coded value; changing it changes the conv weight
            shape, so checkpoints must be loaded with the same value. The conv
            is truly depthwise only when ``groups == dim``, and ``dim`` must be
            divisible by ``groups``.
    """
    def __init__(self, dim, kernel_size=(4, 8), drop_path=0., layer_scale_init_value=1e-6,
                 norm_layer=nn.LayerNorm, groups=12):
        super().__init__()
        # NOTE(review): each spatial dim is padded with k // 2 and the conv uses
        # padding=0. With an even kernel (e.g. the default (4, 8)) the output
        # matches the input size only if PeriodicPad2d's padding semantics make
        # it so; otherwise the residual add fails on shape. Confirm.
        padding_size = [k // 2 for k in kernel_size]
        self.padding = PeriodicPad2d(padding_size)
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=0, groups=groups)
        self.norm = norm_layer(dim)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 conv as a linear layer
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Apply the block to a channels-last tensor ``(N, H, W, dim)``.

        Returns a tensor of the same layout: ``x + drop_path(branch(x))``.
        """
        shortcut = x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W) for Conv2d
        x = self.padding(x)
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x

        return shortcut + self.drop_path(x)


class Originattn_block(nn.Module):
    """Transformer block using ``nn.MultiheadAttention`` self-attention + MLP.

    Only ``attn_type="origin_attn"`` is supported.

    Args:
        dim (int): Embedding dimension.
        window_size: Stored on the instance; not used by this block.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of the MLP hidden dim to ``dim``.
        qkv_bias (bool): Whether the attention projections use a bias.
        drop (float): MLP dropout rate.
        attn_drop (float): Attention dropout rate.
        drop_path (float): Stochastic depth rate.
        act_layer: Activation class for the MLP.
        norm_layer: Normalization layer class.
        attn_type (str): Must be ``"origin_attn"``.
        pre_norm (bool): Pre-norm (``x + f(norm(x))``) if True, else
            post-norm (``norm(x + f(x))``).

    Raises:
        ValueError: If ``attn_type`` is not ``"origin_attn"``. Without this
            check no attention module is created and ``forward`` would fail.
    """
    def __init__(self, dim, window_size, num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attn_type="windowattn", pre_norm=True, **kwargs):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.pre_norm = pre_norm

        self.norm = norm_layer(dim)
        if attn_type != "origin_attn":
            raise ValueError(
                f"Originattn_block only supports attn_type='origin_attn', got {attn_type!r}"
            )
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, dropout=attn_drop,
                                          bias=qkv_bias, batch_first=True)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def _attend(self, y):
        """Self-attention on ``y`` (batch_first); attention weights are not requested."""
        return self.attn(y, y, y, need_weights=False)[0]

    def forward(self, x):
        """Attention sub-layer then MLP sub-layer, each with a residual."""
        shortcut = x
        if self.pre_norm:
            x = shortcut + self.drop_path(self._attend(self.norm(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = self.norm(shortcut + self.drop_path(self._attend(x)))
            x = self.norm2(x + self.drop_path(self.mlp(x)))
        return x





class Windowattn_block(nn.Module):
    """Transformer block with a selectable attention sub-layer and an MLP.

    Args:
        dim (int): Channel dimension.
        window_size: Window size forwarded to the attention module.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of the MLP hidden dim to ``dim``.
        qkv_bias (bool): Whether the q/k/v projections use a bias.
        drop (float): Projection / MLP dropout rate.
        attn_drop (float): Attention dropout rate.
        drop_path (float): Stochastic depth rate.
        act_layer: Activation class for the MLP.
        norm_layer: Normalization layer class.
        attn_type (str): ``"windowattn"`` (SD_attn), ``"retenattn"``
            (Rentention_attn), ``"spoattn"`` (SPO_attn), ``"convattn"``
            (Conv_attn), ``"origin_attn"`` (nn.MultiheadAttention,
            batch_first, used as self-attention) or ``"mha_withflash"``
            (MHA_withflash).
        pre_norm (bool): Pre-norm (``x + f(norm(x))``) if True, else
            post-norm (``norm(x + f(x))``).
        **kwargs: ``shift_size`` (default ``[0, 0, 0]``) and ``dilated_size``
            (default ``[1, 1, 1]``) for windowattn/retenattn/spoattn;
            ``save_attn`` (default False): for ``"windowattn"``, also return
            the second value produced by SD_attn.

    Raises:
        ValueError: If ``attn_type`` is not one of the supported values.
    """
    def __init__(self, dim, window_size, num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attn_type="windowattn", pre_norm=True, **kwargs):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.pre_norm = pre_norm
        self.attn_type = attn_type
        self.save_attn = kwargs.get("save_attn", False)

        self.norm = norm_layer(dim)
        if attn_type in ("windowattn", "retenattn", "spoattn"):
            attn_cls = {
                "windowattn": SD_attn,
                "retenattn": Rentention_attn,
                "spoattn": SPO_attn,
            }[attn_type]
            self.attn = attn_cls(
                dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias,
                attn_drop=attn_drop, proj_drop=drop,
                shift_size=kwargs.get("shift_size", [0, 0, 0]),
                dilated_size=kwargs.get("dilated_size", [1, 1, 1]))
        elif attn_type == "convattn":
            self.attn = Conv_attn(dim, window_size, num_heads, qkv_bias=qkv_bias,
                                  attn_drop=attn_drop, proj_drop=drop)
        elif attn_type == "origin_attn":
            self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, dropout=attn_drop,
                                              bias=qkv_bias, batch_first=True)
        elif attn_type == "mha_withflash":
            self.attn = MHA_withflash(dim, window_size=window_size, num_heads=num_heads,
                                      qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        else:
            raise ValueError(f"unknown attn_type {attn_type!r}")

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def _attend(self, y):
        """Run the attention module; return ``(output, extra)``.

        ``extra`` is SD_attn's second return value for ``"windowattn"`` and
        None for every other type.
        """
        if self.attn_type == "windowattn":
            out, extra = self.attn(y)
            return out, extra
        if self.attn_type == "origin_attn":
            # nn.MultiheadAttention needs (query, key, value); use it as
            # self-attention and discard the (unrequested) weights.
            return self.attn(y, y, y, need_weights=False)[0], None
        return self.attn(y), None

    def forward(self, x):
        """Attention sub-layer then MLP sub-layer, each with a residual.

        Returns ``x``, or ``(x, extra)`` when ``attn_type == "windowattn"``
        and ``save_attn`` is set.
        """
        shortcut = x
        if self.pre_norm:
            out, extra = self._attend(self.norm(x))
            x = shortcut + self.drop_path(out)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            out, extra = self._attend(x)
            x = self.norm(shortcut + self.drop_path(out))
            x = self.norm2(x + self.drop_path(self.mlp(x)))

        if self.attn_type == "windowattn" and self.save_attn:
            return x, extra
        return x


class Hilo_Block(nn.Module):
    """HiLo attention followed by a depthwise-conv MLP (``DWMlp``).

    Args:
        dim (int): Channel dimension.
        window_size: Forwarded to HiLo and stored on the instance.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of the MLP hidden dim to ``dim``.
        qkv_bias (bool): Whether HiLo's q/k/v projections use a bias.
        drop (float): Projection / MLP dropout rate.
        attn_drop (float): Attention dropout rate.
        drop_path (float): Stochastic depth rate shared by both sub-layers.
        act_layer: MLP activation class. Default: nn.ReLU
        norm_layer: Normalization layer class.
        pre_norm (bool): Pre-norm (``x + f(norm(x))``) if True, else
            post-norm (``norm(x + f(x))``).
        alpha (float): Forwarded to HiLo.
    """
    def __init__(self, dim, window_size, num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.ReLU, norm_layer=nn.LayerNorm, pre_norm=True,
                 alpha=0.9) -> None:
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.pre_norm = pre_norm

        # Sub-modules are created in a fixed order (attention first, then MLP)
        # so parameter initialisation consumes the RNG in that order.
        self.norm1 = norm_layer(dim)
        self.attn = HiLo(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop,
                         proj_drop=drop, window_size=window_size, alpha=alpha)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.convffn = DWMlp(dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Attention sub-layer then conv-MLP sub-layer, each with a residual."""
        if not self.pre_norm:
            x = self.norm1(x + self.drop_path(self.attn(x)))
            return self.norm2(x + self.drop_path(self.convffn(x)))
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.convffn(self.norm2(x)))

class ConvFFNBlock(nn.Module):
    """Pre-norm residual block containing only a depthwise-conv MLP (``DWMlp``).

    Computes ``x + drop_path(mlp(norm2(x)))``; there is no attention sub-layer.

    Args:
        dim (int): Number of input channels.
        window_size, num_heads, qkv_bias, qk_scale, attn_drop, alpha:
            Accepted but unused by this block.
        mlp_ratio (float): Ratio of the MLP hidden dim to ``dim``.
        drop (float): Dropout rate passed to the MLP. Default: 0.0
        drop_path (float): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module): MLP activation class. Default: nn.GELU
        norm_layer (nn.Module): Normalization layer class. Default: nn.LayerNorm
    """

    def __init__(self, dim, window_size=7, num_heads=1,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, alpha=0.5):
        super().__init__()
        self.dim = dim
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = DWMlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Add the (stochastically dropped) MLP branch to ``x``."""
        branch = self.mlp(self.norm2(x))
        return x + self.drop_path(branch)


class Windowattn_block_withmoe(nn.Module):
    """Window-attention block whose attention and/or MLP may be mixture-of-experts.

    Args:
        dim (int): Channel dimension.
        attr_len, attr_hidden_size: Forwarded to the MoE sub-layers.
        window_size: Window size for the attention module.
        num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path,
        act_layer, norm_layer: As in ``Windowattn_block``.
        attn_type (str): Only ``"windowattn"`` is supported.
        pre_norm (bool): Pre-norm if True, else post-norm.
        attn_use_moe (bool): Use ``SD_attn_withmoe`` instead of ``SD_attn``.
        mlp_use_moe (bool): Use ``Mlp_withmoe`` instead of ``Mlp``.
        num_experts, expert_capacity, router_bias, router_noise,
        is_scale_prob, drop_tokens: Forwarded to the MoE sub-layers.
        **kwargs: ``shift_size`` (default ``[0, 0, 0]``) and ``dilated_size``
            (default ``[1, 1, 1]``) for the attention module.

    ``forward(x, attr)`` returns ``(x, [z_attn, z_mlp], [bal_attn, bal_mlp])``;
    losses of non-MoE sub-layers are reported as 0.

    Raises:
        NotImplementedError: For ``attn_type == "convattn"``.
        ValueError: For any other unsupported ``attn_type``.
    """
    def __init__(self, dim, attr_len, window_size, attr_hidden_size, num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attn_type="windowattn", pre_norm=True, attn_use_moe=True,
                 mlp_use_moe=True, num_experts=1, expert_capacity=1., router_bias=True,
                 router_noise=1e-2, is_scale_prob=True, drop_tokens=True, **kwargs):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.pre_norm = pre_norm
        self.attn_use_moe = attn_use_moe
        self.mlp_use_moe = mlp_use_moe

        self.norm = norm_layer(dim)
        if attn_type == "windowattn":
            shift_size = kwargs.get("shift_size", [0, 0, 0])
            dilated_size = kwargs.get("dilated_size", [1, 1, 1])
            if attn_use_moe:
                self.attn = SD_attn_withmoe(
                    dim, attr_len=attr_len, attr_hidden_size=attr_hidden_size, window_size=self.window_size,
                    num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
                    shift_size=shift_size, dilated_size=dilated_size,
                    num_experts=num_experts, expert_capacity=expert_capacity, router_bias=router_bias,
                    router_noise=router_noise, is_scale_prob=is_scale_prob, drop_tokens=drop_tokens
                )
            else:
                self.attn = SD_attn(
                    dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias,
                    attn_drop=attn_drop, proj_drop=drop, shift_size=shift_size, dilated_size=dilated_size)
        elif attn_type == "convattn":
            raise NotImplementedError('moe convattn')
        else:
            # Without this, self.attn would never be set and forward would fail.
            raise ValueError(f"unsupported attn_type {attn_type!r}; expected 'windowattn'")

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if mlp_use_moe:
            self.mlp = Mlp_withmoe(
                in_features=dim, attr_len=attr_len, attr_hidden_size=attr_hidden_size,
                hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
                num_experts=num_experts, expert_capacity=expert_capacity, router_bias=router_bias,
                router_noise=router_noise, is_scale_prob=is_scale_prob, drop_tokens=drop_tokens)
        else:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def _attend(self, x, attr):
        """Run the attention sub-layer; return ``(output, z_loss, balance_loss)``."""
        if self.attn_use_moe:
            out, z_loss, balance_loss = self.attn(x, attr)
            return out, z_loss, balance_loss
        out = self.attn(x)
        # NOTE(review): elsewhere in this file SD_attn's result is unpacked as
        # a pair; if a tuple comes back keep only the output. A plain tensor
        # passes through unchanged.
        if isinstance(out, tuple):
            out = out[0]
        return out, 0, 0

    def _ffn(self, x, attr):
        """Run the MLP sub-layer; return ``(output, z_loss, balance_loss)``."""
        if self.mlp_use_moe:
            out, z_loss, balance_loss = self.mlp(x, attr)
            return out, z_loss, balance_loss
        return self.mlp(x), 0, 0

    def forward(self, x, attr=None):
        """Attention then MLP, each with a residual; also return MoE losses."""
        shortcut = x
        if self.pre_norm:
            out, z_loss1, balance_loss1 = self._attend(self.norm(x), attr)
            x = shortcut + self.drop_path(out)
            out, z_loss2, balance_loss2 = self._ffn(self.norm2(x), attr)
            x = x + self.drop_path(out)
        else:
            out, z_loss1, balance_loss1 = self._attend(x, attr)
            x = self.norm(shortcut + self.drop_path(out))
            out, z_loss2, balance_loss2 = self._ffn(x, attr)
            x = self.norm2(x + self.drop_path(out))

        return x, [z_loss1, z_loss2], [balance_loss1, balance_loss2]




class Windowattn_parallelblock(nn.Module):
    """Window-attention block built from ``SD_attn_parallel`` and ``Mlp_parallel``.

    Args:
        dim (int): Channel dimension.
        window_size: Window size for the attention module.
        num_heads, mlp_ratio, qkv_bias, drop, attn_drop, drop_path,
        act_layer, norm_layer: As in ``Windowattn_block``.
        attn_type (str): Only ``"windowattn"`` is supported.
        pre_norm (bool): Pre-norm (``x + f(norm(x))``) if True, else
            post-norm (``norm(x + f(x))``).
        use_cpu_initialization (bool): Forwarded to both parallel sub-layers.
        **kwargs: ``shift_size`` (default ``[0, 0, 0]``) and ``dilated_size``
            (default ``[1, 1, 1]``) for the attention module.

    Raises:
        ValueError: If ``attn_type`` is not ``"windowattn"``.
    """
    def __init__(self, dim, window_size, num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attn_type="windowattn", pre_norm=True, use_cpu_initialization=True,
                 **kwargs):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.pre_norm = pre_norm

        self.norm = norm_layer(dim)
        if attn_type != "windowattn":
            # Without this, self.attn would never be set and forward would fail.
            raise ValueError(f"unsupported attn_type {attn_type!r}; expected 'windowattn'")
        self.attn = SD_attn_parallel(
            dim, window_size=self.window_size, num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop,
            shift_size=kwargs.get("shift_size", [0, 0, 0]),
            dilated_size=kwargs.get("dilated_size", [1, 1, 1]),
            use_cpu_initialization=use_cpu_initialization)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp_parallel(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
                                use_cpu_initialization=use_cpu_initialization)

    def forward(self, x):
        """Attention sub-layer then MLP sub-layer, each with a residual."""
        if self.pre_norm:
            x = x + self.drop_path(self.attn(self.norm(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = self.norm(x + self.drop_path(self.attn(x)))
            x = self.norm2(x + self.drop_path(self.mlp(x)))
        return x



class LayerScale(nn.Module):
    """Scale the input by a learnable per-channel gain ``gamma``.

    Args:
        dim: Shape of ``gamma`` (passed to ``torch.ones``); it broadcasts
            against the trailing dimension(s) of the input.
        init_values (float): Initial value of every entry of ``gamma``.
        inplace (bool): If True, multiply ``x`` in place (``mul_``) and return it.
    """
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim).mul(init_values))

    def forward(self, x):
        """Return ``x * gamma`` (in place when ``inplace`` is set)."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma



class Vit_ResPostblock(nn.Module):
    """ViT block with selectable placement of the residual-branch norm.

    Args:
        dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of the MLP hidden dim to ``dim``.
        qkv_bias (bool): Whether the q/k/v projections use a bias.
        drop (float): Projection / MLP dropout rate.
        attn_drop (float): Attention dropout rate.
        drop_path (float): Stochastic depth rate for both branches.
        act_layer: MLP activation class.
        norm_layer: Normalization layer class (also passed to the attention).
        qk_norm (bool): Forwarded to ``Vitattn_withflash``.
        init_values (float | None): If truthy, enables LayerScale with this
            initial value; otherwise LayerScale is an identity.
        respost_norm (bool): If True, each branch is ``x + ls(norm(f(x)))``
            (norm after the sub-layer); otherwise pre-norm
            ``x + ls(f(norm(x)))``.
    """
    def __init__(self, dim, num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 qk_norm=False, init_values=None,
                 respost_norm=False, **kwargs):
        super().__init__()

        self.respost_norm = respost_norm

        self.norm1 = norm_layer(dim)
        self.attn = Vitattn_withflash(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self._init_weights()

    def _init_weights(self):
        """Initialise the affine parameters of both norm layers.

        With ``respost_norm`` the norm weight and bias start at 0, so (for an
        affine norm such as nn.LayerNorm) each residual branch outputs zero at
        init. In pre-norm mode the norms feed the sub-layers, so they start as
        the identity affine map (weight 1, bias 0) instead of zeroing the
        sub-layer inputs.
        """
        gain = 0 if self.respost_norm else 1
        for norm in (self.norm1, self.norm2):
            nn.init.constant_(norm.bias, 0)
            nn.init.constant_(norm.weight, gain)

    def forward(self, x):
        """Attention branch then MLP branch, each added residually."""
        if self.respost_norm:
            x = x + self.drop_path1(self.ls1(self.norm1(self.attn(x))))
            x = x + self.drop_path2(self.ls2(self.norm2(self.mlp(x))))
        else:
            x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
            x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))

        return x



class Scale_Swin_block(nn.Module):
    """Swin-v2-attention block with selectable norm placement and MLP type.

    Args:
        dim (int): Embedding dimension.
        window_size, pretrained_window_size, shift_size, use_flash,
        posembed_type: Forwarded to ``Swin_Attention_v2``.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of the MLP hidden dim to ``dim``.
        qkv_bias (bool): Whether the q/k/v projections use a bias.
        drop (float): Projection / MLP dropout rate.
        attn_drop (float): Attention dropout rate.
        drop_path (float): Stochastic depth rate for both branches.
        act_layer: MLP activation class (only used when ``mlp_type == "mlp"``).
        norm_layer: Normalization layer class.
        qk_norm (bool): Forwarded as ``use_qknorm``.
        init_values (float | None): If truthy, enables LayerScale.
        resnorm_type (str):
            * ``"prepost_norm"``: ``x + ls(norm(f(x)))``; norms start zeroed.
            * ``"pre_norm"``: ``x + ls(f(norm(x)))``.
            * anything else, including the default ``"prenorm"``: post-norm
              ``norm(x + ls(f(x)))``.
              NOTE(review): the default ``"prenorm"`` does not match
              ``"pre_norm"``; confirm that post-norm is the intended default.
        sample_sub (bool): In training mode, with probability 0.5 per forward
            call both branches are multiplied by 0 instead of going through
            drop path. They are still computed, presumably so every parameter
            stays in the autograd graph.
        mlp_type (str): ``"mlp"`` for ``Mlp``; any other value uses ``GluMlp``.
    """
    def __init__(self, dim, window_size, num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 qk_norm=False, init_values=None,
                 resnorm_type="prenorm", pretrained_window_size=[0,0],
                 shift_size=[0,0], use_flash=False, posembed_type=None,
                 sample_sub=False, mlp_type="mlp", **kwargs):
        super().__init__()

        self.resnorm_type = resnorm_type
        self.sample_sub = sample_sub

        self.norm1 = norm_layer(dim)
        self.attn = Swin_Attention_v2(
            dim,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=pretrained_window_size,
            shift_size=shift_size,
            use_qknorm=qk_norm,
            use_flash=use_flash,
            posembed_type=posembed_type
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        if mlp_type == "mlp":
            self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)
        else:
            self.mlp = GluMlp(in_features=dim, hidden_features=hidden, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self._init_weights()

    @staticmethod
    def _reset_norm_affine(norm, weight_value):
        """Best-effort: set a norm's bias to 0 and weight to ``weight_value``.

        Looks at ``norm`` itself first, then at a wrapped ``norm.norm``
        sub-module. Parameters that are missing or None are skipped, so norm
        layers without affine parameters are left untouched.
        """
        for module in (norm, getattr(norm, "norm", None)):
            weight = getattr(module, "weight", None)
            bias = getattr(module, "bias", None)
            if weight is None and bias is None:
                continue
            if bias is not None:
                nn.init.constant_(bias, 0)
            if weight is not None:
                nn.init.constant_(weight, weight_value)
            return

    def _init_weights(self):
        """Zero-init the norms for ``prepost_norm``; identity affine otherwise."""
        weight_value = 0 if self.resnorm_type == "prepost_norm" else 1
        # Each norm is handled independently so one non-affine norm does not
        # prevent the other from being initialised.
        self._reset_norm_affine(self.norm1, weight_value)
        self._reset_norm_affine(self.norm2, weight_value)

    @staticmethod
    def _residual(branch, drop_path, skip):
        """Residual contribution: ``0 * branch`` when skipping, else ``drop_path(branch)``."""
        return 0. * branch if skip else drop_path(branch)

    def forward(self, x):
        """Attention branch then MLP branch, combined per ``resnorm_type``."""
        # Draws one random number per call only when sample_sub is set.
        skip = self.sample_sub and torch.rand(1)[0] < 0.5 and self.training
        if self.resnorm_type == "prepost_norm":
            x = x + self._residual(self.ls1(self.norm1(self.attn(x))), self.drop_path1, skip)
            x = x + self._residual(self.ls2(self.norm2(self.mlp(x))), self.drop_path2, skip)
        elif self.resnorm_type == "pre_norm":
            x = x + self._residual(self.ls1(self.attn(self.norm1(x))), self.drop_path1, skip)
            x = x + self._residual(self.ls2(self.mlp(self.norm2(x))), self.drop_path2, skip)
        else:
            x = self.norm1(x + self._residual(self.ls1(self.attn(x)), self.drop_path1, skip))
            x = self.norm2(x + self._residual(self.ls2(self.mlp(x)), self.drop_path2, skip))
        return x


class Swin_block(nn.Module):
    """Swin-v2-attention block with res-post-norm or pre-norm residuals.

    Args:
        dim (int): Embedding dimension.
        window_size, pretrained_window_size, shift_size, use_flash,
        posembed_type: Forwarded to ``Swin_Attention_v2``.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of the MLP hidden dim to ``dim``.
        qkv_bias (bool): Whether the q/k/v projections use a bias.
        drop (float): Projection / MLP dropout rate.
        attn_drop (float): Attention dropout rate.
        drop_path (float): Stochastic depth rate for both branches.
        act_layer: MLP activation class.
        norm_layer: Normalization layer class.
        qk_norm (bool): Forwarded as ``use_qknorm``.
        init_values (float | None): If truthy, enables LayerScale.
        respost_norm (bool): ``x + ls(norm(f(x)))`` if True, else
            ``x + ls(f(norm(x)))``.
        sample_sub (bool): In training mode, with probability 0.5 per forward
            call both branches are multiplied by 0 instead of going through
            drop path (they are still computed).
    """
    def __init__(self, dim, window_size, num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 qk_norm=False, init_values=None,
                 respost_norm=True, pretrained_window_size=[0,0],
                 shift_size=[0,0], use_flash=False, posembed_type=None,
                 sample_sub=False, **kwargs):
        super().__init__()

        self.respost_norm = respost_norm
        self.sample_sub = sample_sub

        def make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        def make_layer_scale():
            return LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = Swin_Attention_v2(
            dim,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=pretrained_window_size,
            shift_size=shift_size,
            use_qknorm=qk_norm,
            use_flash=use_flash,
            posembed_type=posembed_type
        )
        self.ls1 = make_layer_scale()
        self.drop_path1 = make_drop_path()

        # MLP branch.
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)
        self.ls2 = make_layer_scale()
        self.drop_path2 = make_drop_path()

        self._init_weights()

    def _init_weights(self):
        """Zero the norm affine params for res-post-norm; identity affine otherwise.

        With a zero weight and bias the norm output, and hence each res-post
        branch, is zero at init.
        """
        gain = 0 if self.respost_norm else 1
        for norm in (self.norm1, self.norm2):
            nn.init.constant_(norm.bias, 0)
            nn.init.constant_(norm.weight, gain)

    def forward(self, x):
        """Attention branch then MLP branch, each added residually."""
        zero_out = self.sample_sub and torch.rand(1)[0] < 0.5 and self.training
        if self.respost_norm:
            branch = self.ls1(self.norm1(self.attn(x)))
            x = x + (0. * branch if zero_out else self.drop_path1(branch))
            branch = self.ls2(self.norm2(self.mlp(x)))
            x = x + (0. * branch if zero_out else self.drop_path2(branch))
        else:
            branch = self.ls1(self.attn(self.norm1(x)))
            x = x + (0. * branch if zero_out else self.drop_path1(branch))
            branch = self.ls2(self.mlp(self.norm2(x)))
            x = x + (0. * branch if zero_out else self.drop_path2(branch))
        return x


class Swin_block_v2(nn.Module):
    """Swin-v2-attention block with res-post-norm or pre-norm residuals.

    Args:
        dim (int): Embedding dimension.
        window_size, pretrained_window_size, shift_size: Forwarded to
            ``Swin_Attention_v2``.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of the MLP hidden dim to ``dim``.
        qkv_bias (bool): Whether the q/k/v projections use a bias.
        drop (float): Projection / MLP dropout rate.
        attn_drop (float): Attention dropout rate.
        drop_path (float): Stochastic depth rate for both branches.
        act_layer: MLP activation class.
        norm_layer: Normalization layer class.
        qk_norm (bool): Accepted but not forwarded to the attention module.
            NOTE(review): ``Swin_block`` forwards it as ``use_qknorm``;
            confirm whether this block should too.
        init_values (float | None): If truthy, enables LayerScale.
        respost_norm (bool): ``x + ls(norm(f(x)))`` if True, else
            ``x + ls(f(norm(x)))``.
    """
    def __init__(self, dim, window_size, num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 qk_norm=False, init_values=None,
                 respost_norm=True, pretrained_window_size=[0,0],
                 shift_size=[0,0], **kwargs):
        super().__init__()

        self.respost_norm = respost_norm

        self.norm1 = norm_layer(dim)
        self.attn = Swin_Attention_v2(
            dim,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=pretrained_window_size,
            shift_size=shift_size,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self._init_weights()

    def _init_weights(self):
        """Initialise the affine parameters of both norm layers.

        With ``respost_norm`` weight and bias start at 0, so each residual
        branch outputs zero at init. In pre-norm mode the norms feed the
        sub-layers, so they start as the identity affine map (weight 1,
        bias 0) rather than zeroing the sub-layer inputs.
        """
        gain = 0 if self.respost_norm else 1
        for norm in (self.norm1, self.norm2):
            nn.init.constant_(norm.bias, 0)
            nn.init.constant_(norm.weight, gain)

    def forward(self, x):
        """Attention branch then MLP branch, each added residually."""
        if self.respost_norm:
            x = x + self.drop_path1(self.ls1(self.norm1(self.attn(x))))
            x = x + self.drop_path2(self.ls2(self.norm2(self.mlp(x))))
        else:
            x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
            x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))

        return x



class Lora_block(nn.Module):
    """Window-attention + MLP transformer block with LoRA-adapted layers.

    Attention is ``Lora_Attention`` when ``len(window_size) == 2`` and
    ``Lora_Attention3d`` otherwise; the MLP is ``Lora_Mlp``.  The LoRA
    arguments (``r``, ``lora_alpha``, ``lora_dropout``, ``fan_in_fan_out``,
    ``merge_weights``) are forwarded to both.  When ``r > 0`` the affine
    parameters of both norms are frozen (``requires_grad = False``).

    ``resnorm_type`` selects the residual layout used in ``forward``:

    * ``"prepost_norm"``: ``x + norm(branch(x))``; the norms are
      zero-initialised, so the normalised branch output starts at 0.
    * ``"pre_norm"``: ``x + branch(norm(x))``.
    * anything else: ``norm(x + branch(x))`` (post-norm).

    NOTE(review): the default ``"prenorm"`` (no underscore) matches neither
    of the first two strings and therefore selects the post-norm layout.
    It is kept unchanged for backward compatibility; confirm that callers
    relying on the default really want post-norm.

    ``origin_network=True`` stores the first norm as ``self.norm`` instead of
    ``self.norm1``.  Presumably this keeps parameter names compatible with an
    existing checkpoint; verify.

    When ``sample_sub`` is set, each training forward pass has a 50% chance of
    multiplying both branch outputs by 0.  The branches are still evaluated,
    so their parameters stay in the autograd graph.
    """

    def __init__(self, dim, window_size, num_heads=1, mlp_ratio=4., 
                qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 
                act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                qk_norm=False, init_values=None,
                resnorm_type="prenorm", pretrained_window_size=[0,0],
                shift_size=[0,0], use_flash=False, posembed_type=None,
                sample_sub=False, r=0, lora_alpha=1, lora_dropout=0,
                fan_in_fan_out=False, merge_weights=True, origin_network=False, 
                **kwargs):
        super().__init__()

        self.resnorm_type = resnorm_type
        self.sample_sub = sample_sub
        self.origin_network = origin_network
        if origin_network:
            self.norm = norm_layer(dim)
            if r > 0:
                self.norm.weight.requires_grad = False
                self.norm.bias.requires_grad = False
        else:
            self.norm1 = norm_layer(dim)
            if r > 0:
                self.norm1.weight.requires_grad = False
                self.norm1.bias.requires_grad = False
        if len(window_size) == 2:
            attn_fn = Lora_Attention
        else:
            attn_fn = Lora_Attention3d
        self.attn = attn_fn(
            dim,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=pretrained_window_size,
            shift_size=shift_size,
            use_qknorm=qk_norm,
            use_flash=use_flash,
            posembed_type=posembed_type, 
            r=r, 
            lora_alpha=lora_alpha, 
            lora_dropout=lora_dropout, 
            fan_in_fan_out=fan_in_fan_out, 
            merge_weights=merge_weights
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        if r > 0:
            self.norm2.weight.requires_grad = False
            self.norm2.bias.requires_grad = False
        self.mlp = Lora_Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop, 
            r=r, 
            lora_alpha=lora_alpha, 
            lora_dropout=lora_dropout, 
            fan_in_fan_out=fan_in_fan_out, 
            merge_weights=merge_weights
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self._init_weights()

    def lora(self, mode=True):
        """Call ``lora(mode)`` on every direct child module that defines it."""
        for child in self.children():
            if hasattr(child, "lora"):
                child.lora(mode)

    @staticmethod
    def _fill_norm_params(norm_a, norm_b, weight_value):
        """Set both norms' bias to 0 and weight to ``weight_value``.

        Raises AttributeError if a norm lacks ``weight``/``bias`` or they are
        ``None`` (i.e. the norm has no affine parameters).
        """
        nn.init.constant_(norm_a.bias, 0)
        nn.init.constant_(norm_a.weight, weight_value)
        nn.init.constant_(norm_b.bias, 0)
        nn.init.constant_(norm_b.weight, weight_value)

    def _init_weights(self):
        """Initialise the affine parameters of the two norms.

        Biases are set to 0.  Weights are set to 0 for ``"prepost_norm"`` (so
        the normalised branch output starts at 0) and to 1 otherwise.  If the
        norms do not expose ``weight``/``bias`` directly, the same
        initialisation is attempted on a wrapped ``.norm`` sub-module.  If
        that also fails, the norms are left as constructed.
        """
        norm1 = self.norm if self.origin_network else self.norm1
        weight_value = 0 if self.resnorm_type == "prepost_norm" else 1
        try:
            self._fill_norm_params(norm1, self.norm2, weight_value)
        except AttributeError:
            try:
                self._fill_norm_params(norm1.norm, self.norm2.norm, weight_value)
            except AttributeError:
                # No affine parameters to initialise (e.g. elementwise_affine=False).
                return

    def forward(self, x, mask=None):
        """Apply the attention and MLP residual branches.

        Args:
            x: input features; the expected layout is defined by the attention
                module (not visible here).
            mask: forwarded as ``x_mask`` to both the attention and the MLP.
        """
        if self.origin_network:
            norm1 = self.norm
        else:
            norm1 = self.norm1

        # NOTE: `torch.rand` is evaluated before `self.training` on purpose-
        # preserving order; it keeps RNG consumption identical to earlier runs.
        if self.resnorm_type == "prepost_norm":
            if self.sample_sub and torch.rand(1)[0] < 0.5 and self.training:
                x = x + 0. * self.ls1(norm1(self.attn(x, x_mask=mask)))
                x = x + 0. * self.ls2(self.norm2(self.mlp(x, x_mask=mask)))
            else:
                x = x + self.drop_path1(self.ls1(norm1(self.attn(x, x_mask=mask))))
                x = x + self.drop_path2(self.ls2(self.norm2(self.mlp(x, x_mask=mask))))
        elif self.resnorm_type == "pre_norm":
            if self.sample_sub and torch.rand(1)[0] < 0.5 and self.training:
                x = x + 0. * self.ls1(self.attn(norm1(x), x_mask=mask))
                x = x + 0. * self.ls2(self.mlp(self.norm2(x), x_mask=mask))
            else:
                x = x + self.drop_path1(self.ls1(self.attn(norm1(x), x_mask=mask)))
                x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x), x_mask=mask)))
        else:
            if self.sample_sub and torch.rand(1)[0] < 0.5 and self.training:
                x = norm1(x + 0. * self.ls1(self.attn(x, x_mask=mask)))
                x = self.norm2(x + 0. * self.ls2(self.mlp(x, x_mask=mask)))
            else:
                x = norm1(x + self.drop_path1(self.ls1(self.attn(x, x_mask=mask))))
                x = self.norm2(x + self.drop_path2(self.ls2(self.mlp(x, x_mask=mask))))

        return x
    
def modulate(x, shift, scale):
    """Apply an adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get two singleton axes inserted right after their
    first axis before broadcasting against ``x``.  For example, ``(B, C)``
    becomes ``(B, 1, 1, C)`` and ``(B, 1, C)`` becomes ``(B, 1, 1, 1, C)``.
    """
    scale_b = scale[:, None, None]
    shift_b = shift[:, None, None]
    return x * (1 + scale_b) + shift_b


class DiTLora_block(nn.Module):
    """LoRA transformer block with optional DiT-style adaLN conditioning.

    Attention is ``Lora_Attention_withmask`` for 2-D ``window_size`` and
    ``Lora_Attention3d`` otherwise; the MLP is ``Lora_Mlp`` and always uses a
    tanh-approximated GELU (``act_layer`` is ignored).  The norms are created
    with ``elementwise_affine=adaLN_init``, so ``norm_layer`` must accept that
    keyword.  When ``r > 0`` any affine parameters the norms have are frozen.

    When ``use_t`` is set, ``adaLN_modulation`` (SiLU + LoRA ``Linear``) maps
    a ``dim``-wide conditioning vector to six ``dim``-wide chunks.  In order,
    these are shift/scale/gate for the attention branch, then shift/scale/gate
    for the MLP branch.  With ``adaLN_init`` the projection weight is zeroed
    and its bias set so that initially shift = 0, scale = 0 and gate = 1.

    ``forward`` without ``condition`` follows the same ``resnorm_type``
    layouts as ``Lora_block``: ``"prepost_norm"``, ``"pre_norm"`` or
    post-norm.  With ``condition`` only ``"pre_norm"`` is distinguished; any
    other value (including ``"prepost_norm"``) uses the post-norm layout, and
    drop-path is not applied.
    """

    def __init__(self, dim, window_size, num_heads=1, mlp_ratio=4., 
                qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 
                act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                qk_norm=False, init_values=None,
                resnorm_type="pre_norm", pretrained_window_size=[0,0],
                shift_size=[0,0], use_flash=False, posembed_type=None,
                sample_sub=False, r=0, lora_alpha=1, lora_dropout=0,
                fan_in_fan_out=False, merge_weights=True, origin_network=False, 
                use_t=True, adaLN_init=False, **kwargs):
        super().__init__()

        self.resnorm_type = resnorm_type
        self.sample_sub = sample_sub
        self.origin_network = origin_network
        self.window_size = window_size
        # With adaLN_init=False the norms have no affine parameters (weight and
        # bias are None); freezing must skip them instead of crashing.
        if origin_network:
            self.norm = norm_layer(dim, elementwise_affine=adaLN_init)
            if r > 0:
                self._freeze_affine(self.norm)
        else:
            self.norm1 = norm_layer(dim, elementwise_affine=adaLN_init)
            if r > 0:
                self._freeze_affine(self.norm1)
        if len(window_size) == 2:
            attn_fn = Lora_Attention_withmask
        else:
            attn_fn = Lora_Attention3d
        self.attn = attn_fn(
            dim,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=pretrained_window_size,
            shift_size=shift_size,
            use_qknorm=qk_norm,
            use_flash=use_flash,
            posembed_type=posembed_type, 
            r=r, 
            lora_alpha=lora_alpha, 
            lora_dropout=lora_dropout, 
            fan_in_fan_out=fan_in_fan_out, 
            merge_weights=merge_weights
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim, elementwise_affine=adaLN_init)
        if r > 0:
            self._freeze_affine(self.norm2)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Lora_Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=approx_gelu,
            drop=drop, 
            r=r, 
            lora_alpha=lora_alpha, 
            lora_dropout=lora_dropout, 
            fan_in_fan_out=fan_in_fan_out, 
            merge_weights=merge_weights
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if use_t:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                Linear(dim, 6 * dim, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights)
            )
            if adaLN_init:
                # Chunk layout: [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
                self.adaLN_modulation[1].weight.data.fill_(0)
                self.adaLN_modulation[1].bias.data[:2*dim].fill_(0)
                self.adaLN_modulation[1].bias.data[2*dim:3*dim].fill_(1)
                self.adaLN_modulation[1].bias.data[3*dim:5*dim].fill_(0)
                self.adaLN_modulation[1].bias.data[5*dim:6*dim].fill_(1)
        else:
            self.adaLN_modulation = None

        self._init_weights()

    @staticmethod
    def _freeze_affine(norm):
        """Set ``requires_grad=False`` on ``norm.weight``/``norm.bias`` when present (not None)."""
        for param in (getattr(norm, "weight", None), getattr(norm, "bias", None)):
            if param is not None:
                param.requires_grad = False

    def lora(self, mode=True):
        """Call ``lora(mode)`` on every direct child module that defines it."""
        for child in self.children():
            if hasattr(child, "lora"):
                child.lora(mode)

    @staticmethod
    def _fill_norm_params(norm_a, norm_b, weight_value):
        """Set both norms' bias to 0 and weight to ``weight_value``.

        Raises AttributeError if a norm lacks ``weight``/``bias`` or they are
        ``None`` (i.e. the norm has no affine parameters).
        """
        nn.init.constant_(norm_a.bias, 0)
        nn.init.constant_(norm_a.weight, weight_value)
        nn.init.constant_(norm_b.bias, 0)
        nn.init.constant_(norm_b.weight, weight_value)

    def _init_weights(self):
        """Initialise the affine parameters of the two norms, if any.

        Biases are set to 0.  Weights are set to 0 for ``"prepost_norm"`` and
        to 1 otherwise.  If the norms do not expose ``weight``/``bias``
        directly, a wrapped ``.norm`` sub-module is tried.  If that also
        fails (e.g. ``adaLN_init=False``), nothing is initialised.
        """
        norm1 = self.norm if self.origin_network else self.norm1
        weight_value = 0 if self.resnorm_type == "prepost_norm" else 1
        try:
            self._fill_norm_params(norm1, self.norm2, weight_value)
        except AttributeError:
            try:
                self._fill_norm_params(norm1.norm, self.norm2.norm, weight_value)
            except AttributeError:
                return

    def forward(self, x, mask=None, condition=None):
        """Apply the attention and MLP branches, optionally adaLN-modulated.

        Args:
            x: input features; layout is defined by the attention module.
                ``modulate`` broadcasts over two axes after the batch axis
                (three for 3-D windows, via the extra ``unsqueeze``).
            mask: forwarded as ``x_mask`` to the attention and the MLP.
            condition: optional conditioning tensor fed to
                ``adaLN_modulation`` and split into six chunks along dim 1.
                It is presumably ``(batch, dim)``; confirm with callers.

        Raises:
            TypeError: if ``condition`` is given but the block was built with
                ``use_t=False``.
        """
        if self.origin_network:
            norm1 = self.norm
        else:
            norm1 = self.norm1

        if condition is None:
            if self.resnorm_type == "prepost_norm":
                if self.sample_sub and torch.rand(1)[0] < 0.5 and self.training:
                    x = x + 0. * self.ls1(norm1(self.attn(x, x_mask=mask)))
                    x = x + 0. * self.ls2(self.norm2(self.mlp(x, x_mask=mask)))
                else:
                    x = x + self.drop_path1(self.ls1(norm1(self.attn(x, x_mask=mask))))
                    x = x + self.drop_path2(self.ls2(self.norm2(self.mlp(x, x_mask=mask))))
            elif self.resnorm_type == "pre_norm":
                if self.sample_sub and torch.rand(1)[0] < 0.5 and self.training:
                    x = x + 0. * self.ls1(self.attn(norm1(x), x_mask=mask))
                    x = x + 0. * self.ls2(self.mlp(self.norm2(x), x_mask=mask))
                else:
                    x = x + self.drop_path1(self.ls1(self.attn(norm1(x), x_mask=mask)))
                    x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x), x_mask=mask)))
            else:
                if self.sample_sub and torch.rand(1)[0] < 0.5 and self.training:
                    x = norm1(x + 0. * self.ls1(self.attn(x, x_mask=mask)))
                    x = self.norm2(x + 0. * self.ls2(self.mlp(x, x_mask=mask)))
                else:
                    x = norm1(x + self.drop_path1(self.ls1(self.attn(x, x_mask=mask))))
                    x = self.norm2(x + self.drop_path2(self.ls2(self.mlp(x, x_mask=mask))))
        else:
            if self.adaLN_modulation is None:
                raise TypeError(
                    "DiTLora_block was constructed with use_t=False; "
                    "it has no adaLN_modulation and cannot use `condition`."
                )
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(condition).chunk(6, dim=1)
            if len(self.window_size) == 3:
                # One extra broadcast axis for 3-D (volumetric) windows.
                gate_msa = gate_msa.unsqueeze(1)
                gate_mlp = gate_mlp.unsqueeze(1)
                shift_msa = shift_msa.unsqueeze(1)
                shift_mlp = shift_mlp.unsqueeze(1)
                scale_msa = scale_msa.unsqueeze(1)
                scale_mlp = scale_mlp.unsqueeze(1)
            if self.resnorm_type == "pre_norm":
                x = x + gate_msa.unsqueeze(1).unsqueeze(1) * self.ls1(self.attn(modulate(norm1(x), shift_msa, scale_msa), x_mask=mask))
                x = x + gate_mlp.unsqueeze(1).unsqueeze(1) * self.ls2(self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), x_mask=mask))
            else:
                x = norm1(x + gate_msa.unsqueeze(1).unsqueeze(1) * self.ls1(self.attn(modulate(x, shift_msa, scale_msa), x_mask=mask)))
                x = self.norm2(x + gate_mlp.unsqueeze(1).unsqueeze(1) * self.ls2(self.mlp(modulate(x, shift_mlp, scale_mlp), x_mask=mask)))

        return x



class DiTLoraairsea_block(nn.Module):
    """LoRA DiT block with optional air-sea cross-attention.

    Self-attention is ``Lora_Attention_withmask`` for 2-D ``window_size1`` and
    ``Lora_Attention3d`` otherwise.  With ``use_cross``, an
    ``Airsea_CrossAttention`` (built with ``window_size1=img_size`` and
    ``window_size2``) attends from ``x`` to ``cross_x``.  ``cross_norm`` is
    ``norm_layer(dim)`` if ``use_cross_norm`` is set, else ``nn.Identity``.

    ``crossattn_parallel=True`` adds the cross-attention output inside the
    same residual update as self-attention.  Otherwise cross-attention is a
    separate residual step after self-attention.

    The norms are created with ``elementwise_affine=adaLN_init`` and are only
    frozen (for ``r > 0``) when they actually have affine parameters.  The MLP
    always uses a tanh-approximated GELU (``act_layer`` is ignored).

    With ``use_t``, ``adaLN_modulation`` yields shift/scale/gate for the
    attention and the MLP branches (six ``dim``-wide chunks).  With
    ``use_cross`` as well, ``adaLN_modulation_cross`` yields shift/scale/gate
    for the cross branch (three chunks).  ``adaLN_init`` initialises all of
    them to shift = 0, scale = 0, gate = 1.
    """

    def __init__(self, dim, window_size1, window_size2, img_size=[], num_heads=1, mlp_ratio=4., 
                qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 
                act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                qk_norm=False, init_values=None,
                resnorm_type="pre_norm", pretrained_window_size=[0,0],
                shift_size=[0,0], use_flash=False, posembed_type=None,
                sample_sub=False, r=0, lora_alpha=1, lora_dropout=0,
                fan_in_fan_out=False, merge_weights=True, crossattn_parallel=True,
                use_cross_norm = False, use_cross=False, origin_network=False, 
                use_t=True, adaLN_init=False, **kwargs):
        super().__init__()

        self.resnorm_type = resnorm_type
        self.sample_sub = sample_sub
        self.origin_network = origin_network
        self.window_size = window_size1
        self.crossattn_parallel = crossattn_parallel
        self.use_cross = use_cross
        if origin_network:
            self.norm = norm_layer(dim, elementwise_affine=adaLN_init)
            if r > 0 and adaLN_init:
                self.norm.weight.requires_grad = False
                self.norm.bias.requires_grad = False
        else:
            self.norm1 = norm_layer(dim, elementwise_affine=adaLN_init)
            if r > 0 and adaLN_init:
                self.norm1.weight.requires_grad = False
                self.norm1.bias.requires_grad = False
        if len(window_size1) == 2:
            attn_fn = Lora_Attention_withmask
        else:
            attn_fn = Lora_Attention3d
        self.attn = attn_fn(
            dim,
            window_size=window_size1,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=pretrained_window_size,
            shift_size=shift_size,
            use_qknorm=qk_norm,
            use_flash=use_flash,
            posembed_type=posembed_type, 
            r=r, 
            lora_alpha=lora_alpha, 
            lora_dropout=lora_dropout, 
            fan_in_fan_out=fan_in_fan_out, 
            merge_weights=merge_weights
        )
        if use_cross:
            self.cross_attn = Airsea_CrossAttention(
                dim,
                window_size1=img_size,
                window_size2=window_size2,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                attn_drop=attn_drop,
                proj_drop=drop,
                pretrained_window_size=pretrained_window_size,
                use_qknorm=qk_norm,
                use_flash=use_flash,
                posembed_type=posembed_type
            )
            self.cross_norm = norm_layer(dim) if use_cross_norm else nn.Identity()

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim, elementwise_affine=adaLN_init)
        if r > 0 and adaLN_init:
            self.norm2.weight.requires_grad = False
            self.norm2.bias.requires_grad = False
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Lora_Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=approx_gelu,
            drop=drop, 
            r=r, 
            lora_alpha=lora_alpha, 
            lora_dropout=lora_dropout, 
            fan_in_fan_out=fan_in_fan_out, 
            merge_weights=merge_weights
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if use_t:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                Linear(dim, 6 * dim, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights)
            )
            if adaLN_init:
                # Chunk layout: [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
                self.adaLN_modulation[1].weight.data.fill_(0)
                self.adaLN_modulation[1].bias.data[:2*dim].fill_(0)
                self.adaLN_modulation[1].bias.data[2*dim:3*dim].fill_(1)
                self.adaLN_modulation[1].bias.data[3*dim:5*dim].fill_(0)
                self.adaLN_modulation[1].bias.data[5*dim:6*dim].fill_(1)
            if use_cross:
                self.adaLN_modulation_cross = nn.Sequential(
                    nn.SiLU(),
                    Linear(dim, 3 * dim, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights)
                )
                if adaLN_init:
                    # Chunk layout: [cross_shift, cross_scale, cross_gate]
                    self.adaLN_modulation_cross[1].weight.data.fill_(0)
                    self.adaLN_modulation_cross[1].bias.data[:2*dim].fill_(0)
                    self.adaLN_modulation_cross[1].bias.data[2*dim:3*dim].fill_(1)
        else:
            self.adaLN_modulation = None

        self._init_weights()

    def lora(self, mode=True):
        """Call ``lora(mode)`` on every direct child module that defines it."""
        for child in self.children():
            if hasattr(child, "lora"):
                child.lora(mode)

    @staticmethod
    def _fill_norm_params(norm_a, norm_b, weight_value):
        """Set both norms' bias to 0 and weight to ``weight_value``.

        Raises AttributeError if a norm lacks ``weight``/``bias`` or they are
        ``None`` (i.e. the norm has no affine parameters).
        """
        nn.init.constant_(norm_a.bias, 0)
        nn.init.constant_(norm_a.weight, weight_value)
        nn.init.constant_(norm_b.bias, 0)
        nn.init.constant_(norm_b.weight, weight_value)

    def _init_weights(self):
        """Initialise the affine parameters of the two main norms, if any.

        Biases are set to 0.  Weights are set to 0 for ``"prepost_norm"`` and
        to 1 otherwise.  A wrapped ``.norm`` sub-module is tried as a
        fallback.  If neither has affine parameters, nothing is done.
        ``cross_norm`` is left as constructed.
        """
        norm1 = self.norm if self.origin_network else self.norm1
        weight_value = 0 if self.resnorm_type == "prepost_norm" else 1
        try:
            self._fill_norm_params(norm1, self.norm2, weight_value)
        except AttributeError:
            try:
                self._fill_norm_params(norm1.norm, self.norm2.norm, weight_value)
            except AttributeError:
                return

    def forward(self, x, mask=None, condition=None, cross_x=None):
        """Apply self-attention (+ optional cross-attention) and the MLP.

        Args:
            x: input features; layout defined by the attention modules.
            mask: forwarded as ``x_mask`` to self-attention and the MLP.
            condition: optional conditioning tensor for the adaLN modulation.
                It is split along dim 1; presumably ``(batch, dim)``, confirm.
            cross_x: optional features for cross-attention.  Used only when
                the block was built with ``use_cross``.  It is cast to ``x``'s
                dtype/device on every path.

        Raises:
            TypeError: if ``condition`` is given but the block was built with
                ``use_t=False``.
        """
        if self.origin_network:
            norm1 = self.norm
        else:
            norm1 = self.norm1

        if self.use_cross and cross_x is not None:
            # Previously only the unconditioned path did this cast, so the
            # adaLN path could mix dtypes/devices (and silently promote x).
            cross_x = cross_x.to(x)

        if condition is None:
            if self.use_cross and cross_x is not None:
                if self.resnorm_type == "prepost_norm":
                    if self.crossattn_parallel:
                        x = x + self.drop_path1(self.ls1(norm1(self.attn(x, x_mask=mask)) + self.cross_norm(self.cross_attn(x, cross_x))))
                    else:
                        x = x + self.drop_path1(self.ls1(norm1(self.attn(x, x_mask=mask))))
                        x = x + self.cross_norm(self.cross_attn(x, cross_x))
                    x = x + self.drop_path2(self.ls2(self.norm2(self.mlp(x, x_mask=mask))))
                elif self.resnorm_type == "pre_norm":
                    if self.crossattn_parallel:
                        x = x + self.drop_path1(self.ls1(self.attn(norm1(x), x_mask=mask) + self.cross_attn(self.cross_norm(x), self.cross_norm(cross_x))))
                    else:
                        x = x + self.drop_path1(self.ls1(self.attn(norm1(x), x_mask=mask)))
                        x = x + self.cross_attn(self.cross_norm(x), self.cross_norm(cross_x))
                    x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x), x_mask=mask)))
                else:
                    if self.crossattn_parallel:
                        x = norm1(x + self.drop_path1(self.ls1(self.attn(x, x_mask=mask) + self.cross_attn(x, cross_x))))
                    else:
                        x = norm1(x + self.drop_path1(self.ls1(self.attn(x, x_mask=mask))))
                        x = self.cross_norm(x + self.cross_attn(x, cross_x))
                    x = self.norm2(x + self.drop_path2(self.ls2(self.mlp(x, x_mask=mask))))
            else:
                if self.resnorm_type == "prepost_norm":
                    x = x + self.drop_path1(self.ls1(norm1(self.attn(x, x_mask=mask))))
                    x = x + self.drop_path2(self.ls2(self.norm2(self.mlp(x, x_mask=mask))))
                elif self.resnorm_type == "pre_norm":
                    x = x + self.drop_path1(self.ls1(self.attn(norm1(x), x_mask=mask)))
                    x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x), x_mask=mask)))
                else:
                    x = norm1(x + self.drop_path1(self.ls1(self.attn(x, x_mask=mask))))
                    x = self.norm2(x + self.drop_path2(self.ls2(self.mlp(x, x_mask=mask))))

        else:
            if self.adaLN_modulation is None:
                raise TypeError(
                    "DiTLoraairsea_block was constructed with use_t=False; "
                    "it has no adaLN_modulation and cannot use `condition`."
                )
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(condition).chunk(6, dim=1)
            if self.use_cross:
                cross_shift_msa, cross_scale_msa, cross_gate_msa = self.adaLN_modulation_cross(condition).chunk(3, dim=1)
            if len(self.window_size) == 3:
                # One extra broadcast axis for 3-D (volumetric) windows.
                gate_msa = gate_msa.unsqueeze(1)
                gate_mlp = gate_mlp.unsqueeze(1)
                shift_msa = shift_msa.unsqueeze(1)
                shift_mlp = shift_mlp.unsqueeze(1)
                scale_msa = scale_msa.unsqueeze(1)
                scale_mlp = scale_mlp.unsqueeze(1)
                if self.use_cross:
                    cross_scale_msa = cross_scale_msa.unsqueeze(1)
                    cross_gate_msa = cross_gate_msa.unsqueeze(1)
                    cross_shift_msa = cross_shift_msa.unsqueeze(1)
            if self.use_cross and cross_x is not None:
                if self.resnorm_type == "pre_norm":
                    if self.crossattn_parallel:
                        x = x + self.drop_path1(self.ls1(gate_msa.unsqueeze(1).unsqueeze(1) * self.attn(modulate(norm1(x), shift_msa, scale_msa), x_mask=mask) + cross_gate_msa.unsqueeze(1).unsqueeze(1) * self.cross_attn(modulate(self.cross_norm(x), cross_shift_msa, cross_scale_msa), modulate(self.cross_norm(cross_x), cross_shift_msa, cross_scale_msa))))
                    else:
                        x = x + self.drop_path1(self.ls1(gate_msa.unsqueeze(1).unsqueeze(1) * self.attn(modulate(norm1(x), shift_msa, scale_msa), x_mask=mask)))
                        x = x + cross_gate_msa.unsqueeze(1).unsqueeze(1) * self.cross_attn(modulate(self.cross_norm(x), cross_shift_msa, cross_scale_msa), modulate(self.cross_norm(cross_x), cross_shift_msa, cross_scale_msa))
                    x = x + self.drop_path2(gate_mlp.unsqueeze(1).unsqueeze(1) * self.ls2(self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), x_mask=mask)))
                else:
                    if self.crossattn_parallel:
                        x = norm1(x + self.drop_path1(self.ls1(gate_msa.unsqueeze(1).unsqueeze(1) * self.attn(modulate(x, shift_msa, scale_msa), x_mask=mask) + cross_gate_msa.unsqueeze(1).unsqueeze(1) * self.cross_attn(modulate(x, cross_shift_msa, cross_scale_msa), modulate(cross_x, cross_shift_msa, cross_scale_msa)))))
                    else:
                        x = norm1(x + self.drop_path1(self.ls1(gate_msa.unsqueeze(1).unsqueeze(1) * self.attn(modulate(x, shift_msa, scale_msa), x_mask=mask))))
                        x = self.cross_norm(x + cross_gate_msa.unsqueeze(1).unsqueeze(1) * self.cross_attn(modulate(x, cross_shift_msa, cross_scale_msa), modulate(cross_x, cross_shift_msa, cross_scale_msa)))
                    x = self.norm2(x + self.drop_path2(gate_mlp.unsqueeze(1).unsqueeze(1) * self.ls2(self.mlp(modulate(x, shift_mlp, scale_mlp), x_mask=mask))))
            else:
                if self.resnorm_type == "pre_norm":
                    x = x + gate_msa.unsqueeze(1).unsqueeze(1) * self.ls1(self.attn(modulate(norm1(x), shift_msa, scale_msa), x_mask=mask))
                    x = x + gate_mlp.unsqueeze(1).unsqueeze(1) * self.ls2(self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), x_mask=mask))
                else:
                    x = norm1(x + gate_msa.unsqueeze(1).unsqueeze(1) * self.ls1(self.attn(modulate(x, shift_msa, scale_msa), x_mask=mask)))
                    x = self.norm2(x + gate_mlp.unsqueeze(1).unsqueeze(1) * self.ls2(self.mlp(modulate(x, shift_mlp, scale_mlp), x_mask=mask)))

        return x


class AirSea_block(nn.Module):
    """Window self-attention + optional air-sea cross-attention + LoRA MLP.

    Self-attention is ``Swin_Attention_v2`` for 2-D ``window_size1`` and
    ``Lora_Attention3d`` otherwise.  The LoRA arguments are only forwarded to
    the MLP (``Lora_Mlp``), but when ``r > 0`` the norms' affine parameters
    are frozen.  With ``use_cross``, ``Airsea_CrossAttention`` attends from
    ``x`` to ``cross_x``.  ``cross_norm`` is ``norm_layer(dim)`` if
    ``use_cross_norm``, else ``nn.Identity``.

    NOTE(review): this block calls ``Swin_Attention_v2`` with
    ``use_qknorm``/``use_flash``/``posembed_type`` and, in ``forward``, with
    ``x_mask``.  Verify that its signature accepts them.

    ``resnorm_type`` selects ``"prepost_norm"`` (``x + norm(branch(x))``,
    norms zero-initialised), ``"pre_norm"`` (``x + branch(norm(x))``) or,
    for any other value, post-norm (``norm(x + branch(x))``).
    NOTE(review): the default ``"prenorm"`` (no underscore) therefore
    selects post-norm.  It is kept for backward compatibility; confirm intent.

    ``crossattn_parallel=True`` adds the cross-attention output inside the
    same residual update as self-attention.  Otherwise cross-attention is a
    separate residual step.
    """

    def __init__(self, dim, window_size1, window_size2, num_heads=1, mlp_ratio=4., 
                qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 
                act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                qk_norm=False, init_values=None,
                resnorm_type="prenorm", pretrained_window_size=[0,0],
                shift_size=[0,0], use_flash=False, posembed_type=None,
                sample_sub=False, r=0, lora_alpha=1, lora_dropout=0,
                fan_in_fan_out=False, merge_weights=True, crossattn_parallel=True,
                use_cross_norm = True, use_cross=True, origin_network=False, **kwargs):
        super().__init__()

        self.resnorm_type = resnorm_type
        self.sample_sub = sample_sub
        self.crossattn_parallel = crossattn_parallel
        self.use_cross = use_cross

        self.origin_network = origin_network
        if origin_network:
            self.norm = norm_layer(dim)
            if r > 0:
                self.norm.weight.requires_grad = False
                self.norm.bias.requires_grad = False
        else:
            self.norm1 = norm_layer(dim)
            if r > 0:
                self.norm1.weight.requires_grad = False
                self.norm1.bias.requires_grad = False
        
        if len(window_size1) == 2:
            attn_fn = Swin_Attention_v2
        else:
            attn_fn = Lora_Attention3d
        self.attn = attn_fn(
            dim,
            window_size=window_size1,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=pretrained_window_size,
            shift_size=shift_size,
            use_qknorm=qk_norm,
            use_flash=use_flash,
            posembed_type=posembed_type
        )
        if use_cross:
            self.cross_attn = Airsea_CrossAttention(
                dim,
                window_size1=window_size1,
                window_size2=window_size2,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                attn_drop=attn_drop,
                proj_drop=drop,
                pretrained_window_size=pretrained_window_size,
                use_qknorm=qk_norm,
                use_flash=use_flash,
                posembed_type=posembed_type
            )
            self.cross_norm = norm_layer(dim) if use_cross_norm else nn.Identity()


        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        if r > 0:
            self.norm2.weight.requires_grad = False
            self.norm2.bias.requires_grad = False
        self.mlp = Lora_Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            r=r, 
            lora_alpha=lora_alpha, 
            lora_dropout=lora_dropout, 
            fan_in_fan_out=fan_in_fan_out, 
            merge_weights=merge_weights
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self._init_weights()

    def lora(self, mode=True):
        """Call ``lora(mode)`` on every direct child module that defines it."""
        for child in self.children():
            if hasattr(child, "lora"):
                child.lora(mode)

    @staticmethod
    def _fill_norm_params(norm_a, norm_b, weight_value):
        """Set both norms' bias to 0 and weight to ``weight_value``.

        Raises AttributeError if a norm lacks ``weight``/``bias`` or they are
        ``None`` (i.e. the norm has no affine parameters).
        """
        nn.init.constant_(norm_a.bias, 0)
        nn.init.constant_(norm_a.weight, weight_value)
        nn.init.constant_(norm_b.bias, 0)
        nn.init.constant_(norm_b.weight, weight_value)

    def _init_weights(self):
        """Initialise the affine parameters of the two main norms, if any.

        Biases are set to 0.  Weights are set to 0 for ``"prepost_norm"`` and
        to 1 otherwise.  A wrapped ``.norm`` sub-module is tried as a
        fallback.  If neither has affine parameters, nothing is done.
        ``cross_norm`` is left as constructed.
        """
        norm1 = self.norm if self.origin_network else self.norm1
        weight_value = 0 if self.resnorm_type == "prepost_norm" else 1
        try:
            self._fill_norm_params(norm1, self.norm2, weight_value)
        except AttributeError:
            try:
                self._fill_norm_params(norm1.norm, self.norm2.norm, weight_value)
            except AttributeError:
                return

    def forward(self, x, cross_x=None, mask=None):
        """Apply self-attention (+ cross-attention when available) and the MLP.

        Args:
            x: input features; layout defined by the attention modules.
            cross_x: optional cross-attention features.  Used only when built
                with ``use_cross``; cast to ``x``'s dtype/device first.
            mask: forwarded as ``x_mask`` to self-attention and the MLP.
        """
        if self.origin_network:
            norm1 = self.norm
        else:
            norm1 = self.norm1
        if self.use_cross and cross_x is not None:
            cross_x = cross_x.to(x)
            if self.resnorm_type == "prepost_norm":
                if self.crossattn_parallel:
                    x = x + self.drop_path1(self.ls1(norm1(self.attn(x, x_mask=mask)) + self.cross_norm(self.cross_attn(x, cross_x))))
                else:
                    x = x + self.drop_path1(self.ls1(norm1(self.attn(x, x_mask=mask))))
                    x = x + self.cross_norm(self.cross_attn(x, cross_x))
                x = x + self.drop_path2(self.ls2(self.norm2(self.mlp(x, x_mask=mask))))
            elif self.resnorm_type == "pre_norm":
                if self.crossattn_parallel:
                    x = x + self.drop_path1(self.ls1(self.attn(norm1(x), x_mask=mask) + self.cross_attn(self.cross_norm(x), self.cross_norm(cross_x))))
                else:
                    x = x + self.drop_path1(self.ls1(self.attn(norm1(x), x_mask=mask)))
                    x = x + self.cross_attn(self.cross_norm(x), self.cross_norm(cross_x))
                x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x), x_mask=mask)))
            else:
                if self.crossattn_parallel:
                    x = norm1(x + self.drop_path1(self.ls1(self.attn(x, x_mask=mask) + self.cross_attn(x, cross_x))))
                else:
                    x = norm1(x + self.drop_path1(self.ls1(self.attn(x, x_mask=mask))))
                    x = self.cross_norm(x + self.cross_attn(x, cross_x))
                x = self.norm2(x + self.drop_path2(self.ls2(self.mlp(x, x_mask=mask))))
        else:
            if self.resnorm_type == "prepost_norm":
                x = x + self.drop_path1(self.ls1(norm1(self.attn(x, x_mask=mask))))
                x = x + self.drop_path2(self.ls2(self.norm2(self.mlp(x, x_mask=mask))))
            elif self.resnorm_type == "pre_norm":
                x = x + self.drop_path1(self.ls1(self.attn(norm1(x), x_mask=mask)))
                x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x), x_mask=mask)))
            else:
                x = norm1(x + self.drop_path1(self.ls1(self.attn(x, x_mask=mask))))
                x = self.norm2(x + self.drop_path2(self.ls2(self.mlp(x, x_mask=mask))))
        
        return x


class SD_block(nn.Module):
    """Window-attention (``flash_SD_attn``) + MLP transformer block with optional LoRA.

    The residual / normalisation layout is chosen by ``resnorm_type``:

    * ``"prepost_norm"``: ``x + drop_path(ls(norm(f(x))))``. The norm is applied to
      the branch output. ``_init_weights`` zeroes the norm weights and biases, so
      each branch initially contributes zero.
    * ``"pre_norm"``: ``x + drop_path(ls(f(norm(x))))``.
    * anything else: post-norm, ``norm(x + drop_path(ls(f(x))))``.

    NOTE(review): the default ``resnorm_type="prenorm"`` is not equal to
    ``"pre_norm"``, so it selects the post-norm branch. The default is left
    unchanged to preserve existing behaviour; confirm whether this is intended.

    Args:
        dim: Channel dimension passed to the norms, attention and MLP.
        window_size: Per-axis attention window size (indexable sequence).
        dilated_size: Per-axis dilation forwarded to ``flash_SD_attn``.
            ``None`` (the default) means ``[1, 1]``.
        pretrained_window_size: Forwarded to ``flash_SD_attn``. ``None`` (the
            default) means ``[0, 0]``.
        window_shift: If true, the shift is ``window_size[i] * dilated_size[i] // 2``
            per axis; otherwise it is 0 on every axis.
        init_values: If truthy, a ``LayerScale`` is used on both branches.
        r, lora_alpha, lora_dropout, fan_in_fan_out, merge_weights: LoRA settings
            forwarded to the attention and MLP. ``r > 0`` also freezes every
            parameter of both norm layers (presumably for LoRA fine-tuning).
        mlp_type: ``"glumlp"`` selects ``Lora_GluMlp``; anything else selects
            ``Lora_Mlp``.
        sample_sub, dilate_attn_parallel, use_dilate: Stored as attributes only;
            not otherwise used by this class.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, window_size, dilated_size=None,
                 num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 qk_norm=False, init_values=None,
                 resnorm_type="prenorm", pretrained_window_size=None,
                 window_shift=True, use_flash=False, posembed_type=None,
                 sample_sub=False, r=0, lora_alpha=1, lora_dropout=0,
                 fan_in_fan_out=False, merge_weights=True, dilate_attn_parallel=True,
                 use_dilate=True, mlp_type="mlp", **kwargs):
        super().__init__()
        # Fresh per-instance lists instead of shared mutable default arguments.
        if dilated_size is None:
            dilated_size = [1, 1]
        if pretrained_window_size is None:
            pretrained_window_size = [0, 0]

        self.resnorm_type = resnorm_type
        self.sample_sub = sample_sub
        self.dilate_attn_parallel = dilate_attn_parallel
        self.use_dilate = use_dilate

        self.norm1 = norm_layer(dim)
        if r > 0:
            self._freeze(self.norm1)

        if window_shift:
            shift_size = [window_size[i] * dilated_size[i] // 2 for i in range(len(window_size))]
        else:
            # One zero per window axis (identical to [0, 0] for 2-D windows).
            shift_size = [0] * len(window_size)

        self.attn = flash_SD_attn(
            dim,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=pretrained_window_size,
            shift_size=shift_size,
            use_qknorm=qk_norm,
            use_flash=use_flash,
            posembed_type=posembed_type,
            dilated_size=dilated_size,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            fan_in_fan_out=fan_in_fan_out,
            merge_weights=merge_weights
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        if r > 0:
            self._freeze(self.norm2)
        if mlp_type == "glumlp":
            self.mlp = Lora_GluMlp(
                in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                drop=drop,
                r=r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                fan_in_fan_out=fan_in_fan_out,
                merge_weights=merge_weights
            )
        else:
            self.mlp = Lora_Mlp(
                in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                act_layer=act_layer,
                drop=drop,
                r=r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                fan_in_fan_out=fan_in_fan_out,
                merge_weights=merge_weights
            )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self._init_weights()

    @staticmethod
    def _freeze(module):
        """Set ``requires_grad = False`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    @staticmethod
    def _affine_params(norm):
        """Return ``(weight, bias)`` of ``norm``, or of its ``.norm`` submodule.

        The fallback covers norm wrappers that hold the real norm as ``.norm``.
        Either entry is ``None`` when no such parameter exists, e.g. for
        ``LayerNorm(elementwise_affine=False)`` or ``nn.Identity``.
        """
        target = norm
        if getattr(norm, "weight", None) is None and getattr(norm, "bias", None) is None:
            target = getattr(norm, "norm", None)
        return getattr(target, "weight", None), getattr(target, "bias", None)

    def lora(self, mode=True):
        """Forward ``lora(mode)`` to every direct child that defines it."""
        for child in self.children():
            if hasattr(child, "lora"):
                child.lora(mode)

    def _init_weights(self):
        """Initialise the affine parameters of ``norm1`` and ``norm2``.

        Biases are set to 0. Weights are set to 0 for ``"prepost_norm"``, which
        makes each residual branch start at zero, and to 1 otherwise. Missing
        parameters are skipped rather than aborting the whole initialisation.
        """
        weight_value = 0.0 if self.resnorm_type == "prepost_norm" else 1.0
        for norm in (self.norm1, self.norm2):
            weight, bias = self._affine_params(norm)
            if bias is not None:
                nn.init.constant_(bias, 0)
            if weight is not None:
                nn.init.constant_(weight, weight_value)

    def forward(self, x):
        """Apply the attention branch, then the MLP branch, per ``resnorm_type``."""
        if self.resnorm_type == "prepost_norm":
            x = x + self.drop_path1(self.ls1(self.norm1(self.attn(x))))
            x = x + self.drop_path2(self.ls2(self.norm2(self.mlp(x))))
        elif self.resnorm_type == "pre_norm":
            x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
            x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        else:
            # Post-norm; this is also what the default "prenorm" selects.
            x = self.norm1(x + self.drop_path1(self.ls1(self.attn(x))))
            x = self.norm2(x + self.drop_path2(self.ls2(self.mlp(x))))

        return x

class attn_block(nn.Module):
    """Pre-norm residual block around one unshifted, undilated ``flash_SD_attn``.

    ``forward`` computes ``x + drop_path1(attn(norm1(x)))``; there is no MLP
    branch. ``_init_weights`` zeroes the weight and bias of ``self.attn.proj``.
    That is presumably the attention's output projection, which would make the
    branch start at zero; confirm against ``flash_SD_attn``.

    The signature mirrors ``SD_block``. ``dilated_size``, ``mlp_ratio``,
    ``act_layer``, ``init_values``, ``window_shift`` and ``**kwargs`` are
    accepted but not used. ``resnorm_type`` and ``sample_sub`` are only stored
    as attributes.
    """

    def __init__(self, dim, window_size, dilated_size=[1,1],
                 num_heads=1, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 qk_norm=False, init_values=None,
                 resnorm_type="prenorm", pretrained_window_size=[0,0],
                 window_shift=True, use_flash=False, posembed_type=None,
                 sample_sub=False, **kwargs):
        super().__init__()

        # Kept for config parity with SD_block; not read anywhere in this class.
        self.resnorm_type = resnorm_type
        self.sample_sub = sample_sub

        self.norm1 = norm_layer(dim)

        # Plain window attention: no shift, no dilation.
        attn_cfg = dict(
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            pretrained_window_size=pretrained_window_size,
            shift_size=[0, 0],
            use_qknorm=qk_norm,
            use_flash=use_flash,
            posembed_type=posembed_type,
            dilated_size=[1, 1],
        )
        self.attn = flash_SD_attn(dim, **attn_cfg)

        if drop_path > 0.:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()

        self._init_weights()

    def _init_weights(self):
        """Zero-initialise ``self.attn.proj``: bias first, then weight."""
        projection = self.attn.proj
        for tensor in (projection.bias, projection.weight):
            nn.init.constant_(tensor, 0)

    def forward(self, x):
        """Return ``x + drop_path1(attn(norm1(x)))``."""
        attn_out = self.attn(self.norm1(x))
        return x + self.drop_path1(attn_out)


class CLSTM_cell(nn.Module):
    """Convolutional LSTM cell (ConvLSTM).

    At each step, one ``nn.Conv2d`` over the channel concatenation of the
    current input frame and the previous hidden state produces all four gates.
    A ``nn.GroupNorm`` follows the convolution. ``forward`` unrolls the cell
    for ``seq_len`` steps.

    Args:
        shape: Spatial size ``(H, W)`` of the input and hidden maps.
        input_channels: Number of channels of each input frame.
        filter_size: Convolution kernel size. Padding ``(filter_size - 1) // 2``
            keeps ``(H, W)`` unchanged for odd sizes.
        num_features: Number of hidden-state and cell-state channels.
    """

    def __init__(self, shape, input_channels, filter_size, num_features):
        super(CLSTM_cell, self).__init__()

        self.shape = shape  # H, W
        self.input_channels = input_channels
        self.filter_size = filter_size
        self.num_features = num_features
        # "Same" padding for odd kernel sizes, so the output keeps (H, W).
        self.padding = (filter_size - 1) // 2
        gate_channels = 4 * self.num_features
        # One group per 32 gate channels, clamped to at least 1. Without the
        # clamp, num_features < 8 gives GroupNorm zero groups and construction fails.
        num_groups = max(1, gate_channels // 32)
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_channels + self.num_features,
                      gate_channels, self.filter_size, 1,
                      self.padding),
            nn.GroupNorm(num_groups, gate_channels))

    def forward(self, inputs=None, hidden_state=None, seq_len=10, t=None):
        """Unroll the cell for ``seq_len`` steps.

        Args:
            inputs: Sequence indexed by step along dim 0 (``inputs[step]`` is one
                frame), with batch on dim 1, presumably ``(S, B, C_in, H, W)``.
                If ``None``, an all-zero frame is fed at every step.
            hidden_state: ``(hx, cx)`` tuple. Zero-initialised when ``None``.
            seq_len: Number of steps to unroll (default 10, as for moving MNIST).
            t: Optional conditioning tensor. ``t.unsqueeze(-1).unsqueeze(-1)`` is
                added to both the input frame and the hidden state before the
                gate convolution, so it must broadcast against both
                ``input_channels`` and ``num_features`` channels (e.g. ``(B, 1)``).

        Returns:
            ``(outputs, (hy, cy))``. ``outputs`` stacks the hidden state of every
            step along a new dim 0. ``hy`` and ``cy`` are the final states,
            detached from the autograd graph.

        Raises:
            ValueError: if ``seq_len < 1``, or if both ``inputs`` and
                ``hidden_state`` are ``None`` (the batch size cannot be inferred).
        """
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")

        if hidden_state is None:
            if inputs is None:
                raise ValueError(
                    "CLSTM_cell.forward needs `inputs` or `hidden_state` "
                    "to infer the batch size")
            # Allocate the initial state like `inputs` (same device and dtype)
            # instead of a hard-coded `.cuda()`, which broke CPU runs,
            # non-default GPUs and non-float32 inputs.
            state_shape = (inputs.size(1), self.num_features,
                           self.shape[0], self.shape[1])
            hx = inputs.new_zeros(state_shape)
            cx = inputs.new_zeros(state_shape)
        else:
            hx, cx = hidden_state

        # Loop-invariant tensors, built once rather than at every step.
        zero_frame = None
        if inputs is None:
            zero_frame = hx.new_zeros((hx.size(0), self.input_channels,
                                       self.shape[0], self.shape[1]))
        t_map = None if t is None else t.unsqueeze(-1).unsqueeze(-1)

        output_inner = []
        for index in range(seq_len):
            x = zero_frame if inputs is None else inputs[index, ...]
            if t_map is None:
                combined = torch.cat((x, hx), 1)
            else:
                combined = torch.cat((x + t_map, hx + t_map), 1)
            gates = self.conv(combined)  # (batch, 4 * num_features, H, W)
            # Split into input, forget, cell-candidate and output gates.
            ingate, forgetgate, cellgate, outgate = torch.split(
                gates, self.num_features, dim=1)
            ingate = torch.sigmoid(ingate)
            forgetgate = torch.sigmoid(forgetgate)
            cellgate = torch.tanh(cellgate)
            outgate = torch.sigmoid(outgate)

            cy = (forgetgate * cx) + (ingate * cellgate)
            hy = outgate * torch.tanh(cy)
            output_inner.append(hy)
            hx, cx = hy, cy
        return torch.stack(output_inner), (hy.detach(), cy.detach())


