from re import X
from turtle import forward
import torch.nn as nn
import torch
from timm.models.layers import to_2tuple
from typing import Optional
from .positional_encodings import rope3, rope2, RelativePositionalBias, rope3_maskflatten
from .utils import DropPath, window_partition, window_reverse, ScaleOffset, attn_norm, Linear
import torch.nn.functional as F
from torchvision import utils as vutils
from .moe_utils import TaskMoE, router_z_loss_func, load_balancing_loss_func, Top1Router
import copy
# from megatron_utils.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
# from megatron_utils import mpu
# import megatron_utils.utils
from timm.models.layers import trunc_normal_
import numpy as np
try:
    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
except:
    flash_attn_qkvpacked_func = None
    flash_attn_func = None



class Cross_attn(nn.Module):
    """Cross-attention from every position of ``x`` to a pooled version of ``y``.

    Queries are projected from ``x``. Keys and values are projected from ``y``
    after 2-D average pooling with ``kernel_size = stride = window_size``, so each
    query attends to one token per pooled window of ``y``.

    Args:
        dim: channel dimension of ``x`` and ``y``.
        window_size: pooling window for ``y``. It is also used to build a rotary
            positional encoder (``position_enc``), which ``forward`` does not use.
        num_heads: number of attention heads.
        qkv_bias: whether the query and key/value projections carry a bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Rotary encoder picked by window rank (not applied in forward).
        rank = len(self.window_size)
        if rank == 2:
            self.position_enc = rope2(self.window_size, self.head_dim)
        elif rank == 3:
            self.position_enc = rope3(self.window_size, self.head_dim)

        # Submodules are created in this fixed order so parameter registration
        # (and therefore seeded initialisation) stays stable.
        self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
        self.l_q = nn.Linear(self.dim, self.dim, bias=qkv_bias)
        self.l_kv = nn.Linear(self.dim, self.dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.l_proj = nn.Linear(self.dim, self.dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y):
        """Attend from ``x`` to pooled ``y``.

        Args:
            x: query map of shape (B, H, W, C).
            y: key/value map, channels-last (B, Hy, Wy, C); its spatial axes are
                average-pooled before projection.

        Returns:
            Tensor of shape (B, H, W, dim).
        """
        batch, height, width, channels = x.shape
        heads, hd = self.num_heads, self.head_dim

        # (B, heads, H*W, head_dim)
        query = self.l_q(x).reshape(batch, height * width, heads, hd).permute(0, 2, 1, 3)

        # Pool y channels-first, then flatten its spatial grid into tokens: (B, Ny, C)
        pooled = self.sr(y.permute(0, 3, 1, 2))
        tokens = pooled.reshape(batch, channels, -1).transpose(1, 2)

        # (2, B, heads, Ny, head_dim) -> separate key and value
        key, value = self.l_kv(tokens).reshape(batch, -1, 2, heads, hd).permute(2, 0, 3, 1, 4)

        weights = self.softmax((query @ key.transpose(-2, -1)) * self.scale)
        weights = self.attn_drop(weights)

        merged = (weights @ value).transpose(1, 2).reshape(batch, height, width, self.dim)
        return self.proj_drop(self.l_proj(merged))





class Conv_attn(nn.Module):
    """Window self-attention with four channel groups that use different window shifts.

    The qkv projection is split into four equal channel groups. Each group runs
    windowed attention on a cyclically shifted copy of its features. The shifts
    act on the last two window axes: none, half a window along W, half a window
    along H, and half a window along both. The group outputs are concatenated
    along channels and projected back to ``dim``.

    Args:
        dim: channel dimension of the input. It must be divisible by
            ``4 * num_heads`` so each group splits evenly into heads.
        window_size: ``[Mh, Mw]`` for input (B, H, W, C) or ``[Mt, Mh, Mw]`` for
            input (B, T, H, W, C).
        num_heads: number of attention heads per channel group.
        qkv_bias: whether the qkv projection carries a bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        # Each of the four channel groups has dim // 4 channels, split over num_heads heads.
        self.head_dim = dim // num_heads // 4
        self.scale = self.head_dim ** -0.5

        if len(self.window_size) == 2:
            self.position_enc = rope2(self.window_size, self.head_dim)
        elif len(self.window_size) == 3:
            self.position_enc = rope3(self.window_size, self.head_dim)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def _group_shift_sizes(self):
        """Return the cyclic shift (one entry per window axis) for each of the four groups.

        The shifts are (0, 0), (0, Mw//2), (Mh//2, 0) and (Mh//2, Mw//2) over the
        last two window axes. A leading time axis (3-D windows) is never shifted,
        so every shift list has exactly ``len(window_size)`` entries.
        """
        half_h = self.window_size[-2] // 2
        half_w = self.window_size[-1] // 2
        lead = [0] * (len(self.window_size) - 2)
        return [lead + [0, 0], lead + [0, half_w], lead + [half_h, 0], lead + [half_h, half_w]]

    def create_mask(self, x, shift_size):
        """Build the additive attention mask for one shifted window grid (SW-MSA).

        Args:
            x: the shifted tensor; only its shape and device are read.
            shift_size: per-axis shift with the same length as ``window_size``.

        Returns:
            Tensor (num_windows, N, N) with N = prod(window_size). Entries are 0
            where two tokens carry the same region label and -100 otherwise.

        Only the leading axes (T and H, or just H) are split into regions. The W
        slices are ``slice(-Mw, 0)`` (empty) and ``slice(0, None)`` (everything),
        so the last assignment covers the whole W axis and no masking happens
        along W.
        """
        # The label map has the same axis layout as the feature map so it can be
        # fed through window_partition. Spatial sizes are assumed to be multiples
        # of window_size.
        if len(self.window_size) == 3:
            _, T, H, W, _ = x.shape
            img_mask = torch.zeros((1, T, H, W, 1), device=x.device)  # [1, T, H, W, 1]
            t_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -shift_size[0]),
                        slice(-shift_size[0], None))
            h_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -shift_size[1]),
                        slice(-shift_size[1], None))
            w_slices = (slice(0, -self.window_size[2]),
                        slice(-self.window_size[2], 0),
                        slice(0, None))
            cnt = 0
            for t in t_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, t, h, w, :] = cnt
                        cnt += 1
        elif len(self.window_size) == 2:
            _, H, W, _ = x.shape
            img_mask = torch.zeros((1, H, W, 1), device=x.device)  # [1, H, W, 1]
            h_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -shift_size[0]),
                        slice(-shift_size[0], None))
            w_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], 0),
                        slice(0, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)
        if len(self.window_size) == 3:
            mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2])  # [nW, N]
        elif len(self.window_size) == 2:
            mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1])  # [nW, N]

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, N, N]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """Run the four shifted-window attention groups and merge them.

        Args:
            x: (B, H, W, C) for a 2-D ``window_size`` or (B, T, H, W, C) for a
                3-D one.

        Returns:
            The projected, channels-last output with ``dim`` channels. Its
            spatial layout comes from ``window_reverse`` (presumably that of ``x``).
        """
        T = 1
        if len(self.window_size) == 2:
            B, H, W, C = x.shape
        elif len(self.window_size) == 3:
            B, T, H, W, C = x.shape

        spatial_dims = tuple(range(1, len(self.window_size) + 1))
        N = 1
        for size in self.window_size:
            N *= size

        qkv_groups = self.qkv(x).chunk(4, dim=-1)  # four chunks of 3*C//4 channels each
        group_outputs = []
        for qkv_part, shift_size in zip(qkv_groups, self._group_shift_sizes()):
            is_shifted = any(s > 0 for s in shift_size)
            if is_shifted:
                shifted_qkv = torch.roll(qkv_part, shifts=tuple(-s for s in shift_size), dims=spatial_dims)
            else:
                shifted_qkv = qkv_part

            # create_mask never separates regions along W, so a W-only shift
            # would produce an all-zero mask; skip building it.
            if any(s > 0 for s in shift_size[:-1]):
                mask = self.create_mask(shifted_qkv, shift_size)
            else:
                mask = None

            qkv_windows = window_partition(shifted_qkv, self.window_size).reshape(-1, N, 3 * C // 4)
            B_ = qkv_windows.shape[0]
            q, k, v = qkv_windows.reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
            q = self.position_enc(q.reshape(-1, *self.window_size, self.head_dim)).reshape(B_, self.num_heads, -1, self.head_dim)
            k = self.position_enc(k.reshape(-1, *self.window_size, self.head_dim)).reshape(B_, self.num_heads, -1, self.head_dim)

            attn = (q * self.scale) @ k.transpose(-2, -1)  # [B_, heads, N, N]
            if mask is not None:
                nW = mask.shape[0]
                # Cast the float32 mask to the score dtype so half-precision
                # scores are not silently upcast (which breaks `attn @ v`).
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0).to(attn.dtype)
                attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
            attn = self.attn_drop(attn)

            attn_windows = (attn @ v).transpose(1, 2).reshape(B_, N, C // 4)
            attn_windows = attn_windows.view(-1, *self.window_size, C // 4)
            shifted_out = window_reverse(attn_windows, self.window_size, T, H, W)

            # Undo the cyclic shift.
            if is_shifted:
                group_outputs.append(torch.roll(shifted_out, shifts=tuple(shift_size), dims=spatial_dims))
            else:
                group_outputs.append(shifted_out)

        x = torch.cat(group_outputs, dim=-1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
            




class Dilated_attn(nn.Module):
    """Dilated window self-attention with rotary position encoding.

    The input is first partitioned into windows of ``total_window_size``
    (``window_size`` times ``dilated_size`` per axis). Each such window is then
    partitioned into blocks of ``dilated_size``. Tokens at the same offset inside
    their block form one attention group of ``prod(window_size)`` tokens, giving
    ``prod(dilated_size)`` groups per window. This reading assumes
    ``window_partition`` returns contiguous blocks in raster order (TODO confirm).

    Args:
        dim: channel dimension of the input.
        window_size: attention window per group. Use ``[Mh, Mw]`` for input
            (B, H, W, C) or ``[Mt, Mh, Mw]`` for input (B, T, H, W, C).
        num_heads: number of attention heads.
        qkv_bias: whether the qkv projection carries a bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        dilated_size: dilation per axis. Only the trailing ``len(window_size)``
            entries are used, so a 3-entry value also works for 2-D windows.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., dilated_size=(1, 1, 1)) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Keep only the trailing entries matching the window rank.
        self.dilated_size = list(dilated_size)[-len(window_size):]
        self.window_size = window_size
        # Use the *sliced* dilation so the outer window and the inner block
        # partition agree (the unsliced list mismatched them for 2-D windows).
        self.total_window_size = [w * d for w, d in zip(window_size, self.dilated_size)]

        # Not used in forward. It is kept so existing state_dict keys (if rope2/rope3
        # register parameters or buffers) still load. TODO confirm whether it can go.
        if len(self.window_size) == 2:
            self.rope_quad = rope2(self.window_size, head_dim)
        elif len(self.window_size) == 3:
            self.rope_quad = rope3(self.window_size, head_dim)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

        if len(window_size) == 2:
            self.position_enc = rope2(window_size, head_dim)
        elif len(window_size) == 3:
            self.position_enc = rope3(window_size, head_dim)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """Apply dilated window attention.

        Args:
            x: (B, H, W, C) for a 2-D ``window_size`` or (B, T, H, W, C) for a 3-D one.
            mask: optional additive mask of shape (nW, N, N), N = prod(window_size).
                nW must divide the total number of attention groups.

        Returns:
            The projected output. Its layout comes from ``window_reverse``
            (presumably that of ``x``).
        """
        T = 1
        if len(self.window_size) == 2:
            B, H, W, C = x.shape
        elif len(self.window_size) == 3:
            B, T, H, W, C = x.shape

        num_groups = int(np.prod(self.dilated_size))

        # Outer partition: windows of total_window_size.
        x_windows = window_partition(x, self.total_window_size)
        x_windows = x_windows.reshape(-1, *self.total_window_size, C)
        B_ = x_windows.shape[0]
        # Inner partition: gather tokens sharing an offset inside each dilated block.
        x_windows = (window_partition(x_windows, self.dilated_size)
                     .reshape(B_, -1, num_groups, C)
                     .permute(0, 2, 1, 3)
                     .reshape(B_ * num_groups, -1, C))
        B__, N, C = x_windows.shape
        head_dim = C // self.num_heads

        # [3, B__, heads, N, head_dim]
        q, k, v = self.qkv(x_windows).reshape(B__, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4).unbind(0)

        q = self.position_enc(q.reshape(-1, *self.window_size, head_dim)).reshape(B__, self.num_heads, -1, head_dim)
        k = self.position_enc(k.reshape(-1, *self.window_size, head_dim)).reshape(B__, self.num_heads, -1, head_dim)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # [B__, heads, N, N]

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B__ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        attn_windows = (attn @ v).transpose(1, 2).reshape(B__, N, C)

        # Undo the inner partition, then the outer one.
        attn_windows = attn_windows.reshape(B_, -1, N, C).permute(0, 2, 1, 3).reshape(-1, num_groups, C)
        if len(self.window_size) == 3:
            attn_windows = window_reverse(attn_windows, self.dilated_size, *self.total_window_size)
        elif len(self.window_size) == 2:
            attn_windows = window_reverse(attn_windows, self.dilated_size, 1, *self.total_window_size)

        attn_windows = window_reverse(attn_windows, self.total_window_size, T, H, W)

        x = self.proj(attn_windows)
        x = self.proj_drop(x)
        return x


class Swin_attn(nn.Module):
    """(Shifted-)window self-attention with rotary position encoding.

    Args:
        dim: channel dimension of the input.
        window_size: ``[Mh, Mw]`` for input (B, H, W, C) or ``[Mt, Mh, Mw]`` for
            input (B, T, H, W, C).
        num_heads: number of attention heads.
        qkv_bias: whether the qkv projection carries a bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        shift_size: cyclic shift per window axis. Shifting is enabled when the
            last entry is positive.
        local: when True (default), windows are merged back directly. When False,
            the legacy regrouping branch (2-D only) runs before ``window_reverse``.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., shift_size=(0, 0, 0), local=True) -> None:
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Copy so instances never share (or mutate) the caller's / default list.
        self.shift_size = list(shift_size)
        # Previously read in forward but never assigned, which raised AttributeError.
        self.local = bool(local)

        if len(window_size) == 2:
            self.position_enc = rope2(window_size, head_dim)
        elif len(window_size) == 3:
            self.position_enc = rope3(window_size, head_dim)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def create_mask(self, x):
        """Build the additive attention mask for the shifted window grid (SW-MSA).

        Args:
            x: tensor whose shape and device define the mask (unshifted input).

        Returns:
            Tensor (num_windows, N, N) with N = prod(window_size). Entries are 0
            where two tokens share a region label and -100 otherwise. Only the
            leading axes (T and H, or just H) are split into regions. The W slices
            cover the whole axis in their last assignment, so W is never masked.
        """
        if len(self.window_size) == 3:
            _, T, H, W, _ = x.shape
            img_mask = torch.zeros((1, T, H, W, 1), device=x.device)
            t_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            h_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -self.shift_size[1]),
                        slice(-self.shift_size[1], None))
            w_slices = (slice(0, -self.window_size[2]),
                        slice(-self.window_size[2], 0),
                        slice(0, None))
            cnt = 0
            for t in t_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, t, h, w, :] = cnt
                        cnt += 1
            n_tokens = self.window_size[0] * self.window_size[1] * self.window_size[2]
        elif len(self.window_size) == 2:
            _, H, W, _ = x.shape
            img_mask = torch.zeros((1, H, W, 1), device=x.device)
            h_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            w_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], 0),
                        slice(0, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            n_tokens = self.window_size[0] * self.window_size[1]

        mask_windows = window_partition(img_mask, self.window_size).view(-1, n_tokens)  # [nW, N]
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, N, N]
        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    def forward(self, x):
        """Apply (shifted-)window attention.

        Args:
            x: (B, H, W, C) for a 2-D ``window_size`` or (B, T, H, W, C) for a 3-D one.

        Returns:
            Tensor of shape (B, L, C), where L is the number of spatial positions
            produced by ``window_reverse`` (presumably [T*]H*W).
        """
        T = 1
        if len(self.window_size) == 2:
            B, H, W, C = x.shape
        elif len(self.window_size) == 3:
            B, T, H, W, C = x.shape

        ndim = len(self.window_size)
        spatial_dims = tuple(range(1, ndim + 1))
        shifts = tuple(self.shift_size[:ndim])
        # One predicate gates both the forward and the reverse roll so they stay paired.
        shifted = self.shift_size[-1] > 0

        if shifted:
            shifted_x = torch.roll(x, shifts=tuple(-s for s in shifts), dims=spatial_dims)
            if self.window_size[-1] == W:
                mask = None
            else:
                mask = self.create_mask(x).to(x)
        else:
            shifted_x = x
            mask = None

        N = 1
        for size in self.window_size:
            N *= size
        x = window_partition(shifted_x, self.window_size).view(-1, N, C)  # [nW*B, N, C]
        B_, N, C = x.shape
        head_dim = C // self.num_heads

        # [3, B_, heads, N, head_dim]
        q, k, v = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4).unbind(0)

        q = self.position_enc(q.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
        k = self.position_enc(k.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)

        attn = (q * self.scale) @ k.transpose(-2, -1)  # [B_, heads, N, N]

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        attn_windows = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        if not self.local:
            attn_windows = attn_windows.reshape(B, -1, N, C).permute(0, 2, 1, 3).reshape(-1, self.window_size[0] * self.window_size[1], C)
        shifted_x = window_reverse(attn_windows, self.window_size, T, H, W)

        # Undo the shift on the spatial tensor *before* flattening. Rolling after
        # the flatten would roll the token and channel axes instead.
        if shifted:
            x = torch.roll(shifted_x, shifts=shifts, dims=spatial_dims)
        else:
            x = shifted_x
        x = x.reshape(B, -1, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x




class SD_attn(nn.Module):
    """Shifted + dilated window self-attention with rotary position encoding.

    The input is optionally rolled by ``shift_size`` and partitioned into windows
    of ``total_window_size`` (``window_size`` times ``dilated_size`` per axis).
    Each window is partitioned again into ``dilated_size`` blocks. Tokens at the
    same offset in every block form one attention group of ``prod(window_size)``
    tokens. This assumes ``window_partition`` returns contiguous blocks in raster
    order (TODO confirm).

    Args:
        dim: channel dimension of the input.
        window_size: ``[Mh, Mw]`` for input (B, H, W, C) or ``[Mt, Mh, Mw]`` for
            input (B, T, H, W, C).
        num_heads: number of attention heads.
        qkv_bias: whether the qkv projection carries a bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        shift_size: cyclic shift per window axis. Shifting is enabled when the
            last entry is positive.
        dilated_size: dilation per axis. Only the trailing ``len(window_size)``
            entries are used.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., shift_size=(0, 0, 0), dilated_size=(1, 1, 1)) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = torch.tensor(head_dim ** -0.5)

        # Keep only the trailing dilation entries matching the window rank.
        self.dilated_size = list(dilated_size)[-len(window_size):]
        self.window_size = window_size
        # Copy so instances never share (or mutate) the caller's / default list.
        self.shift_size = list(shift_size)
        # Use the *sliced* dilation so the outer window and the inner block
        # partition agree (the unsliced list mismatched them for 2-D windows).
        self.total_window_size = [w * d for w, d in zip(window_size, self.dilated_size)]

        # Not used in forward. It is kept so existing state_dict keys (if rope2/rope3
        # register parameters or buffers) still load. TODO confirm whether it can go.
        if len(self.window_size) == 2:
            self.rope_quad = rope2(self.window_size, head_dim)
        elif len(self.window_size) == 3:
            self.rope_quad = rope3(self.window_size, head_dim)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

        if len(window_size) == 2:
            self.position_enc = rope2(window_size, head_dim)
        elif len(window_size) == 3:
            self.position_enc = rope3(window_size, head_dim)

    def create_mask(self, x):
        """Build the additive attention mask for the shifted, dilated groups (SW-MSA).

        Args:
            x: tensor whose shape and device define the mask.

        Returns:
            Tensor (num_groups_total, N, N) with N = prod(window_size). Entries
            are 0 where two tokens share a region label and -inf otherwise. The W
            axis is never split into regions, so nothing is masked along W.

        NOTE(review): the region slices use ``window_size`` while the grid is
        partitioned by ``total_window_size``. With dilation > 1 this may mask
        token pairs that are genuinely adjacent. Left unchanged to preserve the
        behaviour of trained models; confirm the intent.
        """
        # The label map has the same axis layout as the feature map so it can go
        # through window_partition. Spatial sizes are assumed to be multiples of
        # the window size.
        if len(self.window_size) == 3:
            _, T, H, W, _ = x.shape
            img_mask = torch.zeros((1, T, H, W, 1), device=x.device)  # [1, T, H, W, 1]
            t_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            h_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -self.shift_size[1]),
                        slice(-self.shift_size[1], None))
            w_slices = (slice(0, -self.window_size[2]),
                        slice(-self.window_size[2], 0),
                        slice(0, None))
            cnt = 0
            for t in t_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, t, h, w, :] = cnt
                        cnt += 1
        elif len(self.window_size) == 2:
            _, H, W, _ = x.shape
            img_mask = torch.zeros((1, H, W, 1), device=x.device)  # [1, H, W, 1]
            h_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            w_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], 0),
                        slice(0, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

        num_groups = int(np.prod(self.dilated_size))
        # Same two-level partition as forward, so mask rows line up with groups.
        mask_windows = window_partition(img_mask, self.total_window_size)
        mask_windows = mask_windows.reshape(-1, *self.total_window_size, 1)
        n_win = mask_windows.shape[0]
        mask_windows = (window_partition(mask_windows, self.dilated_size)
                        .reshape(n_win, -1, num_groups, 1)
                        .permute(0, 2, 1, 3)
                        .reshape(n_win * num_groups, -1))

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, N, N]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -torch.inf).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """Apply shifted + dilated window attention.

        Args:
            x: (B, H, W, C) for a 2-D ``window_size`` or (B, T, H, W, C) for a 3-D one.

        Returns:
            A tuple ``(out, save_attn)``.
            - ``out``: the projected output, with the layout produced by
              ``window_reverse`` (presumably that of ``x``).
            - ``save_attn``: the post-dropout attention weights averaged over
              heads, shape (num_groups_total, N, N).
        """
        T = 1
        if len(self.window_size) == 2:
            _, H, W, C = x.shape
        elif len(self.window_size) == 3:
            _, T, H, W, C = x.shape

        ndim = len(self.window_size)
        spatial_dims = tuple(range(1, ndim + 1))
        shifts = tuple(self.shift_size[:ndim])
        # One predicate gates both the forward and the reverse roll. Previously
        # the reverse used shift_size[0], so e.g. [0, s, s] was never un-rolled.
        shifted = self.shift_size[-1] > 0
        num_groups = int(np.prod(self.dilated_size))

        if shifted:
            shifted_x = torch.roll(x, shifts=tuple(-s for s in shifts), dims=spatial_dims)
            if self.total_window_size[-1] == W:
                mask = None
            else:
                mask = self.create_mask(x).to(x)
        else:
            shifted_x = x
            mask = None

        # Outer partition: windows of total_window_size.
        x_windows = window_partition(shifted_x, self.total_window_size)
        x_windows = x_windows.reshape(-1, *self.total_window_size, C)
        n_win = x_windows.shape[0]
        # Inner partition: gather tokens sharing an offset inside each dilated block.
        x_windows = (window_partition(x_windows, self.dilated_size)
                     .reshape(n_win, -1, num_groups, C)
                     .permute(0, 2, 1, 3)
                     .reshape(n_win * num_groups, -1, C))
        B_, N, C = x_windows.shape
        head_dim = C // self.num_heads

        # [3, B_, heads, N, head_dim]
        q, k, v = self.qkv(x_windows).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4).unbind(0)

        q = self.position_enc(q.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
        k = self.position_enc(k.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)

        q = q * self.scale.to(q)
        attn = q @ k.transpose(-2, -1)  # [B_, heads, N, N]

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        save_attn = attn.mean(dim=-3)

        attn_windows = (attn @ v).transpose(1, 2).reshape(B_, N, C)

        # Undo the inner partition, then the outer one.
        attn_windows = attn_windows.reshape(n_win, -1, N, C).permute(0, 2, 1, 3).reshape(-1, num_groups, C)
        if ndim == 3:
            attn_windows = window_reverse(attn_windows, self.dilated_size, *self.total_window_size)
        elif ndim == 2:
            attn_windows = window_reverse(attn_windows, self.dilated_size, 1, *self.total_window_size)

        shifted_x = window_reverse(attn_windows, self.total_window_size, T, H, W)

        if shifted:
            x = torch.roll(shifted_x, shifts=shifts, dims=spatial_dims)
        else:
            x = shifted_x

        x = self.proj(x)
        x = self.proj_drop(x)
        return x, save_attn


def spo_softmax(data, dim=-1):
    """Softmax with an extra implicit zero logit ("softmax plus one").

    Computes ``exp(x_i) / (1 + sum_j exp(x_j))`` along ``dim``, so a row may put
    less than total mass on its entries.

    The exponentials are shifted by ``m = max(max_j x_j, 0)``, detached; the
    shift cancels mathematically. Clamping ``m`` at zero has two effects: it keeps
    ``exp(-m) <= 1`` so the denominator cannot overflow, and it makes rows whose
    logits are all ``-inf`` (fully masked) return zeros instead of NaN.

    Args:
        data: input logits.
        dim: axis to normalise over (default: last).

    Returns:
        Tensor with the same shape as ``data``.
    """
    row_max = torch.max(data, dim=dim, keepdim=True).values.detach().clamp(min=0)
    data_exp = torch.exp(data - row_max)
    return data_exp / (torch.exp(-row_max) + data_exp.sum(dim=dim, keepdim=True))



class SPO_attn(nn.Module):
    """Shifted, dilated window self-attention normalised with ``spo_softmax``.

    The (optionally cyclically shifted) feature map is partitioned into
    windows of ``window_size * dilated_size`` tokens. Each of those is
    partitioned again by ``dilated_size`` and regrouped so that every attention
    group holds ``prod(window_size)`` tokens. Assuming ``window_partition``
    yields contiguous blocks, these are tokens spaced ``dilated_size`` apart.
    Attention scores are normalised with ``spo_softmax`` (softmax with an extra
    implicit zero logit).

    Args:
        dim: channel dimension C.
        window_size: (Mh, Mw) or (Mt, Mh, Mw).
        num_heads: number of heads; each gets ``dim // num_heads`` channels.
        qkv_bias: add a bias to the fused qkv projection.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        shift_size: cyclic shift per window dimension. Shifting (and the
            shifted-window mask) is enabled only when the last entry is > 0.
        dilated_size: dilation per dimension; only the trailing
            ``len(window_size)`` entries are used.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., shift_size=(0, 0, 0), dilated_size=(1, 1, 1)) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        n_dims = len(window_size)
        # Keep only the trailing entries so a 3-element dilation also works
        # for 2-D windows.
        self.dilated_size = list(dilated_size[-n_dims:])
        self.window_size = window_size
        self.shift_size = list(shift_size)
        # Use the *trimmed* dilation so that the total window and the dilation
        # agree per dimension. Previously the leading entries of the untrimmed
        # ``dilated_size`` were used here.
        self.total_window_size = [window_size[i] * self.dilated_size[i] for i in range(n_dims)]

        # NOTE(review): ``rope_quad`` is not used in forward; presumably kept so
        # existing checkpoints still load. Confirm before removing.
        if len(self.window_size) == 2:
            self.rope_quad = rope2(self.window_size, head_dim)
        elif len(self.window_size) == 3:
            self.rope_quad = rope3(self.window_size, head_dim)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Rotary position encoding applied to q and k inside each attention group.
        if len(window_size) == 2:
            self.position_enc = rope2(window_size, head_dim)
        elif len(window_size) == 3:
            self.position_enc = rope3(window_size, head_dim)

    def create_mask(self, x):
        """Build the additive (0 / -inf) mask for shifted-window attention.

        Regions of the shifted map get integer labels. Tokens in the same
        attention group may attend to each other only if their labels match.
        The last (W) dimension always gets a single label: the slices
        ``slice(-ws, 0)`` and ``slice(0, None)`` make every W position share
        one label. So the wrap-around along W is never masked.

        Args:
            x: feature map (1-or-more, H, W, C) or (., T, H, W, C); only its
               shape and device are used.

        Returns:
            Tensor of shape (num_groups, N, N) with 0 where attention is
            allowed and -inf elsewhere.
        """
        # The label map has the same layout as the feature map, so it can be
        # fed through the same window_partition steps as x in forward.
        if len(self.window_size) == 3:
            _, T, H, W, _ = x.shape
            img_mask = torch.zeros((1, T, H, W, 1), device=x.device)  # [1, T, H, W, 1]
            t_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            h_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -self.shift_size[1]),
                        slice(-self.shift_size[1], None))
            w_slices = (slice(0, -self.window_size[2]),
                        slice(-self.window_size[2], 0),
                        slice(0, None))
            cnt = 0
            for t in t_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, t, h, w, :] = cnt
                        cnt += 1
        elif len(self.window_size) == 2:
            _, H, W, _ = x.shape
            img_mask = torch.zeros((1, H, W, 1), device=x.device)  # [1, H, W, 1]
            h_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            w_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], 0),
                        slice(0, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

        # Same two-level (total window, then dilation) grouping as forward.
        mask_windows = window_partition(img_mask, self.total_window_size)  # [B, nW, Mt, Mh, Mw, C]
        mask_windows = mask_windows.reshape(-1, *self.total_window_size, 1)
        B_ = mask_windows.shape[0]
        if len(self.dilated_size) == 3:
            mask_windows = window_partition(mask_windows, self.dilated_size).reshape(B_, -1,
                                        self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], 1).permute(
                                        0, 2, 1, 3).reshape(B_*self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], -1)
        elif len(self.dilated_size) == 2:
            mask_windows = window_partition(mask_windows, self.dilated_size).reshape(B_, -1,
                                        self.dilated_size[0]*self.dilated_size[1], 1).permute(
                                        0, 2, 1, 3).reshape(B_*self.dilated_size[0]*self.dilated_size[1], -1)

        # Pairwise label difference: non-zero means different regions.
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nG, N, N]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -torch.inf).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """Apply shifted, dilated window attention.

        Args:
            x: feature map of shape (B, H, W, C) for 2-D windows or
               (B, T, H, W, C) for 3-D windows.

        Returns:
            Tensor with the same layout as ``x`` (restored via window_reverse
            and the inverse roll).
        """
        T = 1

        if len(self.window_size) == 2:
            _, H, W, C = x.shape
        elif len(self.window_size) == 3:
            _, T, H, W, C = x.shape

        # A single flag gates both the forward and the reverse roll, so the
        # output is always realigned with the input. Previously the reverse
        # roll checked shift_size[0] instead of shift_size[-1].
        shifted = self.shift_size[-1] > 0

        # A mask is only needed when shifting and the window does not already
        # span the full W extent.
        if shifted and self.total_window_size[-1] != W:
            mask = self.create_mask(x)
        else:
            mask = None

        if shifted:
            if len(self.window_size) == 3:
                shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3))
            elif len(self.window_size) == 2:
                shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
        else:
            shifted_x = x

        # Level 1: total windows -> [num_total_windows, *total_window_size, C]
        x_windows = window_partition(shifted_x, self.total_window_size)  # [B, nW, Mt, Mh, Mw, C]
        x_windows = x_windows.reshape(-1, *self.total_window_size, C)
        B = x_windows.shape[0]  # number of total windows across the batch
        # Level 2: regroup by dilation -> [B * prod(dilated_size), N, C]
        if len(self.dilated_size) == 3:
            x_windows = window_partition(x_windows, self.dilated_size).reshape(B, -1,
                                        self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], C).permute(
                                        0, 2, 1, 3).reshape(B*self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], -1, C)
        elif len(self.dilated_size) == 2:
            x_windows = window_partition(x_windows, self.dilated_size).reshape(B, -1,
                                        self.dilated_size[0]*self.dilated_size[1], C).permute(
                                        0, 2, 1, 3).reshape(B*self.dilated_size[0]*self.dilated_size[1], -1, C)
        B_, N, C = x_windows.shape

        # qkv: [3, B_, num_heads, N, head_dim]
        qkv = self.qkv(x_windows).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # Rotary encoding is applied on the window grid, then flattened back.
        q = self.position_enc(q.reshape(-1, *self.window_size, C // self.num_heads)).reshape(B_, self.num_heads, -1, C // self.num_heads)
        k = self.position_enc(k.reshape(-1, *self.window_size, C // self.num_heads)).reshape(B_, self.num_heads, -1, C // self.num_heads)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # [B_, num_heads, N, N]

        if mask is not None:
            # mask: [nG, N, N]; broadcast over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = spo_softmax(attn)

        attn = self.attn_drop(attn)

        # [B_, num_heads, N, head_dim] -> [B_, N, C]
        attn_windows = (attn @ v).transpose(1, 2).reshape(B_, N, C)

        # Undo level 2 (dilation grouping), then level 1 (total windows).
        if len(self.window_size) == 3:
            attn_windows = attn_windows.reshape(B, -1, N, C).permute(0, 2, 1, 3).reshape(
                                            -1, self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], C)
            attn_windows = window_reverse(attn_windows, self.dilated_size, *self.total_window_size)
        elif len(self.window_size) == 2:
            attn_windows = attn_windows.reshape(B, -1, N, C).permute(0, 2, 1, 3).reshape(
                                            -1, self.dilated_size[0]*self.dilated_size[1], C)
            attn_windows = window_reverse(attn_windows, self.dilated_size, 1, *self.total_window_size)

        shifted_x = window_reverse(attn_windows, self.total_window_size, T, H, W)

        if shifted:
            if len(self.window_size) == 3:
                x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3))
            elif len(self.window_size) == 2:
                x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))
        else:
            x = shifted_x

        x = self.proj(x)
        x = self.proj_drop(x)
        return x




class Flash_attn(nn.Module):
    """Gated window attention: ``proj(u * (quadratic [+ linear]))``.

    One projection produces a gate ``u``, values ``v`` and a shared ``base``.
    Scaled/offset copies of ``base`` act as queries/keys.
    The *quadratic* branch is single-head attention inside each window. It has
    a rotary encoding and a relative position bias, and is normalised with
    ``attn_norm(method='squared_relu')``.
    When ``attn_type == 'lin'`` a *linear* branch is added. It aggregates
    key/value statistics over all windows of a sample.

    Args:
        dim: channel dimension C.
        window_size: (Mh, Mw) or (Mt, Mh, Mw).
        uv_bias: bias for the u/v/base projection.
        attn_drop: dropout on the quadratic attention weights.
        proj_drop: dropout after the output projection.
        expansion_factor: hidden width of u and v is ``expansion_factor * dim``.
        attn_type: ``'lin'`` adds the linear branch; anything else disables it.
    """

    def __init__(self, dim, window_size, uv_bias=True, attn_drop=0., proj_drop=0., expansion_factor=2, attn_type='lin') -> None:
        super().__init__()
        self.attn_type = attn_type
        self.dim = dim
        self.window_size = window_size
        self.hidden_dim = expansion_factor * dim
        self.s = 128  # query/key width of the shared base

        seq_len = 1
        for i in window_size:
            seq_len *= i

        # Quadratic scores are scaled by 1 / (tokens per window).
        self.scale = 1. / seq_len

        self.uv = nn.Linear(dim, 2*self.hidden_dim+self.s, bias=uv_bias)
        self.quad_q_scaleoffset = ScaleOffset(self.s)
        self.quad_k_scaleoffset = ScaleOffset(self.s)
        self.quad_attn_drop = nn.Dropout(attn_drop)

        if self.attn_type == "lin":
            self.lin_q_scaleoffset = ScaleOffset(self.s)
            self.lin_k_scaleoffset = ScaleOffset(self.s)
            # NOTE(review): hard-coded (32, 64) grid; presumably must match the
            # (H, W) of the input feature map. Confirm.
            self.rope_lin = rope2((32, 64), self.s)

        self.proj = nn.Linear(self.hidden_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_norm = attn_norm(dim=-1, method='squared_relu')

        self.rel_postion_bias = RelativePositionalBias(window_size, 1)
        if len(self.window_size) == 2:
            self.rope_quad = rope2(self.window_size, self.s)
        elif len(self.window_size) == 3:
            self.rope_quad = rope3(self.window_size, self.s)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """Apply gated window attention.

        Args:
            x: feature map (B, H, W, C) for 2-D windows or (B, T, H, W, C)
               for 3-D windows.
            mask: optional boolean tensor of shape (B, nW, N); True marks
               tokens to exclude. They get -inf in the quadratic scores and are
               zeroed in the linear branch.

        Returns:
            Output of ``window_reverse(..., T, H, W)`` (presumably the same
            layout as ``x``; confirm for the 2-D case).
        """
        T = 1

        if len(self.window_size) == 2:
            B, H, W, C = x.shape
        elif len(self.window_size) == 3:
            B, T, H, W, C = x.shape

        x_windows = window_partition(x, self.window_size)  # [B, nW, Mt, Mh, Mw, C]
        if len(self.window_size) == 3:
            x = x_windows.view(B, -1, self.window_size[0] * self.window_size[1] * self.window_size[2], C)  # [B, nW, N, C]
        elif len(self.window_size) == 2:
            x = x_windows.view(B, -1, self.window_size[0] * self.window_size[1], C)  # [B, nW, N, C]

        B, nW, N, C = x.shape
        x = x.view(-1, N, C)
        B_ = x.shape[0]  # B * nW
        # u, v: [B_, N, hidden_dim]; base: [B_, N, s]
        u, v, base = torch.split(F.silu(self.uv(x)), [self.hidden_dim, self.hidden_dim, self.s], dim=-1)
        quad_q, quad_k = self.quad_q_scaleoffset(base), self.quad_k_scaleoffset(base)

        quad_q = self.rope_quad(quad_q.reshape(-1, *self.window_size, self.s)).reshape(B_, N, self.s)
        quad_k = self.rope_quad(quad_k.reshape(-1, *self.window_size, self.s)).reshape(B_, N, self.s)

        if self.attn_type == 'lin':
            # The linear-branch rotary encoding is applied on the full map, so
            # un-window, encode, and re-window.
            lin_q, lin_k = self.lin_q_scaleoffset(base), self.lin_k_scaleoffset(base)
            lin_q = window_reverse(lin_q, self.window_size, T, H, W)
            lin_q = self.rope_lin(lin_q)
            lin_q = window_partition(lin_q, self.window_size).reshape(B_, N, self.s)
            lin_k = window_reverse(lin_k, self.window_size, T, H, W)
            lin_k = self.rope_lin(lin_k)
            lin_k = window_partition(lin_k, self.window_size).reshape(B_, N, self.s)

        # Quadratic (in-window) branch: scores [B_, N, N]
        quad_q = quad_q * self.scale
        quad_attn = quad_q @ quad_k.transpose(-2, -1)
        quad_attn = self.rel_postion_bias(quad_attn)

        if mask is not None:
            # Flatten the (B, nW) mask to the B_ layout of quad_attn so it
            # broadcasts over query rows for any batch size. The previous
            # [B, nW, 1, N] view only broadcast when B == 1.
            key_mask = mask.reshape(B_, 1, N)
            attn_mask = torch.zeros_like(key_mask, dtype=quad_q.dtype)
            attn_mask = attn_mask.masked_fill(key_mask, float("-inf"))
            quad_attn = quad_attn + attn_mask
        quad_attn = self.attn_norm(quad_attn)
        quad_attn = self.quad_attn_drop(quad_attn)

        # quadratic: [B_, N, hidden_dim]
        quadratic = quad_attn @ v

        if self.attn_type == 'lin':
            if mask is not None:
                # lin_mask: [B_, N, 1], zero for masked tokens
                lin_mask = torch.logical_not(mask).reshape(B_, N, 1)
                # NOTE(review): this path divides by an extra ``self.s`` compared
                # with the unmasked path below. Confirm which normalisation is intended.
                lin_v = lin_mask * v / (N * nW * self.s)
            else:
                lin_v = v / (N * nW)

            # Per-window key/value statistics: [B_, s, hidden_dim]
            lin_kv = lin_k.transpose(-2, -1) @ lin_v
            # Sum over the windows of each sample only. Previously the sum ran
            # over the flattened B*nW axis, mixing samples across the batch.
            lin_kv = lin_kv.reshape(B, nW, self.s, self.hidden_dim).sum(dim=1, keepdim=True)
            # linear: [B, nW, N, hidden_dim] -> [B_, N, hidden_dim]
            linear = (lin_q.reshape(B, nW, N, self.s) @ lin_kv).reshape(B_, N, self.hidden_dim)
            x = u * (quadratic + linear)
        else:
            x = u * quadratic
        x = self.proj(x)
        x = self.proj_drop(x)

        # Merge windows back into the feature map.
        if len(self.window_size) == 3:
            attn_windows = x.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C)  # [B*nW, Mt, Mh, Mw, C]
        elif len(self.window_size) == 2:
            attn_windows = x.view(-1, self.window_size[0], self.window_size[1], C)  # [B*nW, Mh, Mw, C]
        x = window_reverse(attn_windows, self.window_size, T, H, W)

        return x

class Hydra_attn(nn.Module):
    """Window self-attention, or a "hydra" linear-attention variant.

    With ``use_attn=True`` this is multi-head softmax attention with rotary
    position encoding. It runs either inside each window (``local=True``) or
    across windows at the same in-window position (``local=False``). With
    ``use_attn=False`` it computes
    ``normalize(q) * sum_tokens(normalize(k) * v)`` per sample.

    Args:
        dim: channel dimension C.
        window_size: (Mh, Mw) or (Mt, Mh, Mw).
        num_heads: number of heads; each gets ``dim // num_heads`` channels.
        qkv_bias: bias for the q/k/v projections.
        attn_drop: dropout on attention weights.
        proj_drop: dropout after the output projection.
        expansion_factor: stored as ``hidden_dim``; not used in forward.
        local: attend within windows (True) or across windows (False).
        use_attn: softmax attention (True) or the hydra variant (False).
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., expansion_factor=1, local=True, use_attn=True) -> None:
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.hidden_dim = expansion_factor * dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.use_attn = use_attn
        self.local = local

        if self.use_attn:
            self.qkv = nn.Linear(dim, 3*self.dim, bias=qkv_bias)
        else:
            self.q = nn.Linear(dim, dim, bias=qkv_bias)
        # NOTE(review): ``kv`` is only used when use_attn is False but is always
        # created; presumably for checkpoint compatibility. Confirm.
        self.kv = nn.Linear(dim, 2*dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

        if len(window_size) == 2:
            self.position_enc = rope2(window_size, head_dim)
        elif len(window_size) == 3:
            self.position_enc = rope3(window_size, head_dim)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """Apply the attention.

        Args:
            x: feature map (B, H, W, C) for 2-D windows or (B, T, H, W, C)
               for 3-D windows.
            mask: optional additive (0/-inf) mask of shape (nW, N, N), added
               to the attention scores; only used when ``use_attn`` is True.

        Returns:
            Tensor with the same shape as ``x``. Previously a 3-D input with
            T == 1 lost its T axis.
        """
        in_shape = x.shape
        T = 1

        if len(self.window_size) == 2:
            B, H, W, C = x.shape
        elif len(self.window_size) == 3:
            B, T, H, W, C = x.shape

        # Tokens per window (Mh*Mw or Mt*Mh*Mw). Used by both layouts so the
        # non-local path also works for 3-D windows.
        win_tokens = 1
        for size in self.window_size:
            win_tokens *= size

        if self.use_attn:
            x_windows = window_partition(x, self.window_size)  # [B, nW, Mt, Mh, Mw, C]
            if self.local:
                # Attend within each window: [B*nW, win_tokens, C]
                x = x_windows.view(-1, win_tokens, C)
            else:
                # Attend across windows at the same in-window position:
                # [B*win_tokens, nW, C]
                x = x_windows.view(B, -1, win_tokens, C).permute(0, 2, 1, 3).reshape(B * win_tokens, -1, C)
            B_, N, C = x.shape

            # qkv: [3, B_, num_heads, N, head_dim]
            qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

            q = self.position_enc(q.reshape(-1, *self.window_size, C // self.num_heads)).reshape(B_, self.num_heads, -1, C // self.num_heads)
            k = self.position_enc(k.reshape(-1, *self.window_size, C // self.num_heads)).reshape(B_, self.num_heads, -1, C // self.num_heads)

            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))  # [B_, num_heads, N, N]

            if mask is not None:
                # mask: [nW, N, N]; broadcast over batch and heads.
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
            attn = self.attn_drop(attn)

            # [B_, num_heads, N, head_dim] -> [B_, N, C]
            attn_windows = (attn @ v).transpose(1, 2).reshape(B_, N, C)
            # Same condition as the partition step above, so the layouts match.
            if not self.local:
                attn_windows = attn_windows.reshape(B, -1, N, C).permute(0, 2, 1, 3).reshape(-1, win_tokens, C)
            x = window_reverse(attn_windows, self.window_size, T, H, W).reshape(B, -1, C)
        else:
            # Hydra variant. Computed only here because its result is unused
            # on the softmax-attention path.
            hy_k, hy_v = torch.split(self.kv(x), [self.dim, self.dim], dim=-1)
            hy_k = hy_k / hy_k.norm(dim=-1, keepdim=True)
            # Per-sample global summary: [B, 1, C]
            hy_kv = (hy_k * hy_v).reshape(B, -1, C).sum(dim=-2, keepdim=True)
            hy_q = self.q(x).reshape(B, -1, C)
            hy_q = hy_q / hy_q.norm(dim=-1, keepdim=True)
            x = hy_q * hy_kv

        x = self.proj(x)
        x = self.proj_drop(x).reshape(in_shape)

        return x


class WindowAttention(nn.Module):
    """Multi-head self-attention within pre-partitioned 3-D windows.

    Queries and keys receive a rotary position encoding (``rope3``) over the
    (Mt, Mh, Mw) window grid before the scaled dot product.

    Args:
        dim: channel dimension C.
        window_size: 3-sequence (Mt, Mh, Mw); its length is asserted to be 3.
        num_heads: number of heads, each with ``dim // num_heads`` channels.
        qkv_bias: whether the fused qkv projection has a bias.
        attn_drop: dropout probability on attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.) -> None:
        super().__init__()
        per_head = dim // num_heads
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = per_head ** -0.5

        assert len(window_size) == 3

        # Submodules registered in the original order.
        self.position_enc = rope3(window_size, per_head)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """Attend within each window.

        Args:
            x: windowed tokens of shape (num_windows*B, Mt*Mh*Mw, C).
            mask: optional additive (0/-inf) mask of shape
                (num_windows, N, N), broadcast over batch and heads.

        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        n_batch_win, n_tok, channels = x.shape
        heads = self.num_heads
        per_head = channels // heads

        # Fused projection -> [3, B_, heads, N, per_head]
        fused = self.qkv(x).reshape(n_batch_win, n_tok, 3, heads, per_head).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)

        # Rotary encoding is applied on the window grid, then flattened back.
        grid_shape = (-1, *self.window_size, per_head)
        flat_shape = (n_batch_win, heads, -1, per_head)
        q = self.position_enc(q.reshape(grid_shape)).reshape(flat_shape)
        k = self.position_enc(k.reshape(grid_shape)).reshape(flat_shape)

        scores = (q * self.scale) @ k.transpose(-2, -1)

        if mask is not None:
            n_masks = mask.shape[0]
            # [B, nW, heads, N, N] + [1, nW, 1, N, N]
            scores = scores.view(n_batch_win // n_masks, n_masks, heads, n_tok, n_tok) + mask[None, :, None]
            scores = scores.view(-1, heads, n_tok, n_tok)

        weights = self.attn_drop(self.softmax(scores))

        # [B_, heads, N, per_head] -> [B_, N, C]
        merged = (weights @ v).transpose(1, 2).reshape(n_batch_win, n_tok, channels)
        return self.proj_drop(self.proj(merged))


class HiLo(nn.Module):
    """HiLo attention: heads are split between a high- and a low-frequency branch.

    * Hi-Fi: ``h_heads`` heads attend within non-overlapping local windows
      of ``window_size``.
    * Lo-Fi: ``l_heads`` heads let every query attend to keys/values taken
      from an average-pooled (by ``window_size``) copy of the map.

    The two outputs are concatenated on the channel axis (h_dim + l_dim = dim).

    Args:
        dim: channel dimension; must be divisible by ``num_heads``.
        num_heads: total number of heads.
        qkv_bias: bias for the q/k/v projections.
        qk_scale: attention scale override (default ``head_dim ** -0.5``).
        attn_drop: accepted for API compatibility; not used.
        proj_drop: accepted for API compatibility; not used.
        window_size: int or (h, w) pair. An int ``s`` means ``(s, s)``; a
            size of 1 in both dims makes this plain global attention.
        alpha: fraction of heads assigned to the Lo-Fi branch.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=2, alpha=0.5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        head_dim = int(dim/num_heads)
        self.dim = dim

        # self-attention heads in Lo-Fi
        self.l_heads = int(num_heads * alpha)
        # token dimension in Lo-Fi
        self.l_dim = self.l_heads * head_dim

        # self-attention heads in Hi-Fi
        self.h_heads = num_heads - self.l_heads
        # token dimension in Hi-Fi
        self.h_dim = self.h_heads * head_dim

        # Local window size (the `s` in the HiLo paper). Normalise an int (the
        # default) to a pair: the code below indexes ``self.ws[0]``/``[1]``,
        # which previously raised TypeError for the int default.
        if isinstance(window_size, int):
            window_size = (window_size, window_size)
        self.ws = window_size

        if (self.ws[0] == 1) and (self.ws[1] == 1):
            # ws == 1 is equal to a standard multi-head self-attention
            self.h_heads = 0
            self.h_dim = 0
            self.l_heads = num_heads
            self.l_dim = dim

        self.scale = qk_scale or head_dim ** -0.5

        # Low frequency attention (Lo-Fi)
        if self.l_heads > 0:
            if (self.ws[0] != 1) or (self.ws[1] != 1):
                self.sr = nn.AvgPool2d(kernel_size=self.ws, stride=self.ws)
            self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)
            self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)
            self.l_proj = nn.Linear(self.l_dim, self.l_dim)

        # High frequency attention (Hi-Fi)
        if self.h_heads > 0:
            self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)
            self.h_proj = nn.Linear(self.h_dim, self.h_dim)

    def hifi(self, x):
        """Local-window attention on (B, H, W, C); H and W must be multiples of ws.

        Returns a (B, H, W, h_dim) tensor.
        """
        B, H, W, C = x.shape
        h_group, w_group = H // self.ws[0], W // self.ws[1]

        total_groups = h_group * w_group

        # [B, h_group, w_group, ws0, ws1, C]
        x = x.reshape(B, h_group, self.ws[0], w_group, self.ws[1], C).transpose(2, 3)

        qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B, hw, n_head, ws*ws, head_dim

        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, hw, n_head, ws*ws, ws*ws
        attn = attn.softmax(dim=-1)
        attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws[0], self.ws[1], self.h_dim)
        x = attn.transpose(2, 3).reshape(B, h_group * self.ws[0], w_group * self.ws[1], self.h_dim)

        x = self.h_proj(x)
        return x

    def lofi(self, x):
        """Attention from every token to pooled keys/values on (B, H, W, C).

        Returns a (B, H, W, l_dim) tensor.
        """
        B, H, W, C = x.shape

        q = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)

        if self.ws[0] > 1 or self.ws[1] > 1:
            # Pool on a channels-first view, then back to tokens: [B, HW/ws^2, C]
            x_ = x.permute(0, 3, 1, 2)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
        x = self.l_proj(x)
        return x

    def forward(self, x):
        """Run the active branches on (B, H, W, C) and concatenate to (B, H, W, dim)."""
        if self.h_heads == 0:
            return self.lofi(x)

        if self.l_heads == 0:
            return self.hifi(x)

        hifi_out = self.hifi(x)
        lofi_out = self.lofi(x)

        return torch.cat((hifi_out, lofi_out), dim=-1)



class SD_attn_withmoe(nn.Module):
    """Shifted / dilated window attention with per-token mixture-of-experts projections.

    ``self.gate`` (a ``Top1Router``) assigns every token to one expert. Each expert
    owns its own qkv ``nn.Linear(dim, 3 * dim)`` and output ``nn.Linear(dim, dim)``.
    The attention itself is computed inside (optionally shifted and dilated)
    windows with ``self.position_enc`` applied to q and k.

    When training with ``drop_tokens=True``, tokens routed to an expert beyond its
    capacity are dropped:
    * their q/k/v stay zero;
    * they are masked out as attention keys;
    * the input value is passed through unchanged at the output.
    """

    def __init__(self, dim, attr_len, attr_hidden_size, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., shift_size=[0, 0, 0], dilated_size=[1,1,1],
                num_experts=1, expert_capacity=1., router_bias=True, router_noise=1e-2, is_scale_prob=True, drop_tokens=True) -> None:
        super().__init__()
        self.dim = dim
        self.attr_len = attr_len
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.expert_capacity = expert_capacity
        self.is_scale_prob = is_scale_prob
        self.drop_tokens = drop_tokens
        self.n_experts = num_experts

        # Keep only the trailing dilation entries that match the window rank, so the
        # 3-element default also works for a 2-D window.
        self.dilated_size = dilated_size[-len(window_size):]
        self.window_size = window_size
        self.shift_size = shift_size
        # Pair each window axis with the rank-matched dilation so total_window_size
        # agrees with self.dilated_size (used for the second-level partition).
        self.total_window_size = [window_size[i] * self.dilated_size[i] for i in range(len(window_size))]

        # NOTE(review): rope_quad is not read by forward (position_enc is); kept so
        # existing checkpoints and attribute access stay compatible.
        if len(self.window_size) == 2:
            self.rope_quad = rope2(self.window_size, head_dim)
        elif len(self.window_size) == 3:
            self.rope_quad = rope3(self.window_size, head_dim)

        # Template layer; each expert gets an independent deep copy below.
        qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.gate = Top1Router(attr_len,
                            attr_hidden_size,
                            num_experts,
                            router_bias=router_bias,
                            router_noise=router_noise)
        self.qkv = torch.nn.ModuleList(
            [copy.deepcopy(qkv) for _ in range(num_experts)])

        self.attn_drop = nn.Dropout(attn_drop)

        proj = nn.Linear(dim, dim)
        self.proj = torch.nn.ModuleList(
            [copy.deepcopy(proj) for _ in range(num_experts)])

        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

        if len(window_size) == 2:
            self.position_enc = rope2(window_size, head_dim)
        elif len(window_size) == 3:
            self.position_enc = rope3(window_size, head_dim)

    def _dilation_volume(self):
        """Number of dilated sub-windows per window (product of ``self.dilated_size``)."""
        volume = 1
        for d in self.dilated_size:
            volume *= d
        return volume

    def _shift_spec(self):
        """Return ``(shifts, dims)`` for the cyclic window shift, one entry per window axis.

        Axis ``i`` of the window is tensor dim ``i + 1`` (dim 0 is the batch).
        """
        rank = len(self.window_size)
        shifts = tuple(self.shift_size[i] for i in range(rank))
        dims = tuple(range(1, rank + 1))
        return shifts, dims

    def create_mask(self, x):
        """Build the additive attention mask for shifted-window attention.

        Every grid position is labelled by region, and pairs of tokens in the same
        (dilated) window get 0 if their labels match and -inf otherwise.
        * Along T (3-D only) and H, the regions follow the three-band split defined
          by ``window_size`` and ``shift_size``.
        * Along W, every column of a band gets the same label: the middle slice is
          empty and the last slice covers all columns. So nothing is masked along W.

        Args:
            x: the input feature map; only its shape and device are read.

        Returns:
            Tensor of shape (M, N, N). M is the number of windows times
            ``prod(dilated_size)``; N is the token count of one dilated sub-window.
        """
        # The mask has the same axis order as the feature map so it can go
        # through the same window_partition calls.
        if len(self.window_size) == 3:
            _, T, H, W, _ = x.shape
            img_mask = torch.zeros((1, T, H, W, 1), device=x.device)
            t_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            h_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -self.shift_size[1]),
                        slice(-self.shift_size[1], None))
            w_slices = (slice(0, -self.window_size[2]),
                        slice(-self.window_size[2], 0),
                        slice(0, None))
            cnt = 0
            for t in t_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, t, h, w, :] = cnt
                        cnt += 1
        elif len(self.window_size) == 2:
            _, H, W, _ = x.shape
            img_mask = torch.zeros((1, H, W, 1), device=x.device)
            h_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            w_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], 0),
                        slice(0, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

        mask_windows = window_partition(img_mask, self.total_window_size)
        mask_windows = mask_windows.reshape(-1, *self.total_window_size, 1)
        num_windows = mask_windows.shape[0]
        n_dil = self._dilation_volume()
        # Split each window into its dilated sub-windows, which become extra rows.
        mask_windows = window_partition(mask_windows, self.dilated_size).reshape(
            num_windows, -1, n_dil, 1).permute(0, 2, 1, 3).reshape(num_windows * n_dil, -1)

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [M, N, N]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -torch.inf).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x, attr=None):
        """Route tokens to experts, run windowed attention, and project per expert.

        Args:
            x: input feature map. Shape (B, H, W, C) when ``window_size`` has 2
                entries, or (B, T, H, W, C) when it has 3.
            attr: optional routing input.
                * If ``attr_len > dim``, it is concatenated to ``x`` on the last dim
                  before routing.
                * Otherwise it is routed alone.
                * When None, ``x`` is routed.

        Returns:
            ``(out, z_loss, balance_loss)``. ``out`` has the shape of ``x``. The two
            losses come from ``router_z_loss_func`` and ``load_balancing_loss_func``.
        """
        T = 1
        if len(self.window_size) == 2:
            Bs, H, W, C = x.shape
        elif len(self.window_size) == 3:
            Bs, T, H, W, C = x.shape
        # Token grid without channels; every flatten/unflatten below uses it so the
        # 3-D case keeps its T axis.
        grid_shape = tuple(x.shape[:-1])
        num_tokens = Bs * T * H * W

        if (self.shift_size[-1] == 0) or (self.total_window_size[-1] == W):
            mask = None
        else:
            mask = self.create_mask(x)

        if self.attr_len > self.dim and attr is not None:
            expert_index, router_probs, router_logits = self.gate(torch.cat((x, attr), dim=-1))
        elif attr is not None:
            expert_index, router_probs, router_logits = self.gate(attr)
        else:
            expert_index, router_probs, router_logits = self.gate(x)

        route_prob_max, routes = torch.max(router_probs, dim=-1)
        # Flatten so the indices below address rows of the flattened (num_tokens, C)
        # input. On a multi-dimensional tensor, nonzero(as_tuple=True)[0] returns only
        # the first coordinate. This is a no-op when the router output is already flat.
        routes = routes.reshape(-1)
        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]

        x = x.reshape(-1, C)
        qkv = x.new_zeros(num_tokens, 3 * C)

        # expert capacity = (tokens per batch / number of experts) * capacity factor
        capacity = int(self.expert_capacity * num_tokens / self.n_experts)

        dropped = []
        if self.drop_tokens and self.training:
            for i in range(self.n_experts):
                if len(indexes_list[i]) <= capacity:
                    continue
                # Shuffle before truncating so the dropped tokens are chosen at random.
                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
                dropped.append(indexes_list[i][capacity:])
                indexes_list[i] = indexes_list[i][:capacity]
        dropped_tensor = torch.cat(dropped) if dropped else None

        # Per-expert qkv projection. Dropped tokens are in no index list, so their
        # rows of the zero-initialised qkv buffer stay zero.
        for i in range(self.n_experts):
            qkv[indexes_list[i], :] = self.qkv[i](x[indexes_list[i], :])

        qkv = qkv.reshape(*grid_shape, 3 * C)

        # Additive key mask: -inf at dropped tokens so no query attends to them.
        moe_mask = qkv.new_zeros(num_tokens, 1)
        if dropped_tensor is not None:
            moe_mask[dropped_tensor, :] = -torch.inf
        moe_mask = moe_mask.reshape(*grid_shape, 1)

        # One flag decides both the shift here and its reversal further down.
        shifted = self.shift_size[-1] > 0
        if shifted:
            shifts, dims = self._shift_spec()
            neg_shifts = tuple(-s for s in shifts)
            shifted_qkv = torch.roll(qkv, shifts=neg_shifts, dims=dims)
            shifted_moe_mask = torch.roll(moe_mask, shifts=neg_shifts, dims=dims)
        else:
            shifted_qkv = qkv
            shifted_moe_mask = moe_mask
            mask = None

        n_dil = self._dilation_volume()
        qkv_windows = window_partition(shifted_qkv, self.total_window_size)
        moe_mask_windows = window_partition(shifted_moe_mask, self.total_window_size)
        qkv_windows = qkv_windows.reshape(-1, *self.total_window_size, 3 * C)
        moe_mask_windows = moe_mask_windows.reshape(-1, *self.total_window_size, 1)
        B = qkv_windows.shape[0]
        # Split every window into n_dil dilated sub-windows that become extra batch rows.
        qkv_windows = window_partition(qkv_windows, self.dilated_size).reshape(
            B, -1, n_dil, 3 * C).permute(0, 2, 1, 3).reshape(B * n_dil, -1, 3 * C)
        moe_mask_windows = window_partition(moe_mask_windows, self.dilated_size).reshape(
            B, -1, n_dil, 1).permute(0, 2, 1, 3).reshape(B * n_dil, -1, 1)

        B_, N, _ = qkv_windows.shape
        head_dim = C // self.num_heads

        # Heads are split from the *windowed* tensor, so each row of q/k/v is one
        # token of one shifted, dilated window.
        # [3, B_, num_heads, N, head_dim]
        qkv = qkv_windows.reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        moe_mask_windows = moe_mask_windows.reshape(B_, 1, 1, N).expand(-1, self.num_heads, N, -1)

        q, k, v = qkv.unbind(0)

        q = self.position_enc(q.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
        k = self.position_enc(k.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # [B_, num_heads, N, N]

        if mask is not None:
            # mask is [nW, N, N]; broadcast it over the batch and head dims.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N) + moe_mask_windows
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        attn_windows = (attn @ v).transpose(1, 2).reshape(B_, N, C)

        # Undo the dilated split, then the window partition.
        attn_windows = attn_windows.reshape(B, -1, N, C).permute(0, 2, 1, 3).reshape(-1, n_dil, C)
        if len(self.window_size) == 3:
            attn_windows = window_reverse(attn_windows, self.dilated_size, *self.total_window_size)
        elif len(self.window_size) == 2:
            attn_windows = window_reverse(attn_windows, self.dilated_size, 1, *self.total_window_size)
        shifted_x = window_reverse(attn_windows, self.total_window_size, T, H, W)

        x_out = torch.roll(shifted_x, shifts=shifts, dims=dims) if shifted else shifted_x

        # Per-expert output projection.
        x_out = x_out.reshape(-1, C)
        final_output = x_out.new_zeros(num_tokens, C)
        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = self.proj[i](x_out[indexes_list[i], :])
        if dropped_tensor is not None:
            # Dropped tokens skip attention and projection: pass the input through.
            final_output[dropped_tensor, :] = x[dropped_tensor, :]

        route_prob_max = route_prob_max.reshape(-1, 1)
        if self.is_scale_prob:
            # Scale expert outputs by the routing probability: y = p_i(x) * E_i(x).
            final_output = final_output * route_prob_max
        else:
            # Multiply by p / p.detach() (numerically 1) so gradients still reach the router.
            final_output = final_output * (route_prob_max / route_prob_max.detach())
        x = final_output.reshape(*grid_shape, C)

        x = self.proj_drop(x)
        z_loss = router_z_loss_func(router_logits=router_logits)
        balance_loss = load_balancing_loss_func(router_probs=router_probs, expert_indices=expert_index)

        return x, z_loss, balance_loss



class SD_attn_parallel(nn.Module):
    """Tensor-parallel shifted / dilated window attention.

    The qkv projection is a ``ColumnParallelLinear`` (``gather_output=False``), and
    the output projection is a ``RowParallelLinear`` (``input_is_parallel=True``).
    The number of heads handled locally is ``num_heads / world_size``.

    NOTE(review): this class relies on ``mpu``, ``megatron_utils``,
    ``ColumnParallelLinear`` and ``RowParallelLinear``. Their imports are commented
    out at the top of this file, so construction raises NameError unless those
    names are provided some other way.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0.,
                 proj_drop=0., shift_size=[0, 0, 0], dilated_size=[1,1,1], use_cpu_initialization=True) -> None:
        super().__init__()
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Keep only the trailing dilation entries that match the window rank.
        self.dilated_size = dilated_size[-len(window_size):]
        self.window_size = window_size
        self.shift_size = shift_size
        # Pair each window axis with the rank-matched dilation so total_window_size
        # agrees with self.dilated_size.
        self.total_window_size = [window_size[i] * self.dilated_size[i] for i in range(len(window_size))]
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.num_attention_heads_per_partition = megatron_utils.utils.divide(
            num_heads, world_size)

        # NOTE(review): rope_quad is not read by forward (position_enc is); kept for
        # checkpoint / attribute compatibility.
        if len(self.window_size) == 2:
            self.rope_quad = rope2(self.window_size, head_dim)
        elif len(self.window_size) == 3:
            self.rope_quad = rope3(self.window_size, head_dim)

        self.qkv = ColumnParallelLinear(dim, dim*3, bias=qkv_bias, gather_output=False,
                                        async_tensor_model_parallel_allreduce=False,
                                        use_cpu_initialization=use_cpu_initialization)

        # attn_drop / proj_drop are accepted for interface parity but not applied.
        self.proj = RowParallelLinear(dim, dim, input_is_parallel=True,
                                      use_cpu_initialization=use_cpu_initialization)

        self.softmax = nn.Softmax(dim=-1)

        if len(window_size) == 2:
            self.position_enc = rope2(window_size, head_dim)
        elif len(window_size) == 3:
            self.position_enc = rope3(window_size, head_dim)

    def _dilation_volume(self):
        """Number of dilated sub-windows per window (product of ``self.dilated_size``)."""
        volume = 1
        for d in self.dilated_size:
            volume *= d
        return volume

    def _shift_spec(self):
        """Return ``(shifts, dims)`` for the cyclic window shift, one entry per window axis.

        Axis ``i`` of the window is tensor dim ``i + 1`` (dim 0 is the batch).
        """
        rank = len(self.window_size)
        shifts = tuple(self.shift_size[i] for i in range(rank))
        dims = tuple(range(1, rank + 1))
        return shifts, dims

    def create_mask(self, x):
        """Build the additive attention mask for shifted-window attention.

        Every grid position is labelled by region, and pairs of tokens in the same
        (dilated) window get 0 if their labels match and -inf otherwise.
        * Along T (3-D only) and H, the regions follow the three-band split defined
          by ``window_size`` and ``shift_size``.
        * Along W, every column of a band gets the same label: the middle slice is
          empty and the last slice covers all columns. So nothing is masked along W.

        Args:
            x: the input feature map; only its shape and device are read.

        Returns:
            Tensor of shape (M, N, N). M is the number of windows times
            ``prod(dilated_size)``; N is the token count of one dilated sub-window.
        """
        # The mask has the same axis order as the feature map so it can go
        # through the same window_partition calls.
        if len(self.window_size) == 3:
            _, T, H, W, _ = x.shape
            img_mask = torch.zeros((1, T, H, W, 1), device=x.device)
            t_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            h_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -self.shift_size[1]),
                        slice(-self.shift_size[1], None))
            w_slices = (slice(0, -self.window_size[2]),
                        slice(-self.window_size[2], 0),
                        slice(0, None))
            cnt = 0
            for t in t_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, t, h, w, :] = cnt
                        cnt += 1
        elif len(self.window_size) == 2:
            _, H, W, _ = x.shape
            img_mask = torch.zeros((1, H, W, 1), device=x.device)
            h_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            w_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], 0),
                        slice(0, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

        mask_windows = window_partition(img_mask, self.total_window_size)
        mask_windows = mask_windows.reshape(-1, *self.total_window_size, 1)
        num_windows = mask_windows.shape[0]
        n_dil = self._dilation_volume()
        # Split each window into its dilated sub-windows, which become extra rows.
        mask_windows = window_partition(mask_windows, self.dilated_size).reshape(
            num_windows, -1, n_dil, 1).permute(0, 2, 1, 3).reshape(num_windows * n_dil, -1)

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [M, N, N]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -torch.inf).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """Tensor-parallel shifted / dilated window attention.

        Args:
            x: input feature map. Shape (B, H, W, C) when ``window_size`` has 2
                entries, or (B, T, H, W, C) when it has 3.

        Returns:
            Tensor with the grid shape of ``x``; the last dim is whatever
            ``self.proj`` produces.
        """
        T = 1
        if len(self.window_size) == 2:
            Bs, H, W, C = x.shape
        elif len(self.window_size) == 3:
            Bs, T, H, W, C = x.shape
        # Token grid without channels; keeps the T axis in the 3-D case.
        grid_shape = tuple(x.shape[:-1])

        if (self.shift_size[-1] == 0) or (self.total_window_size[-1] == W):
            mask = None
        else:
            mask = self.create_mask(x)

        # self.qkv returns a 2-tuple; only the first element is used. Its channel
        # count (qkv_C) is presumably this rank's share of the 3 * C qkv channels.
        qkv, _ = self.qkv(x.reshape(Bs, -1, C))
        qkv = qkv.reshape(*grid_shape, -1)
        qkv_C = qkv.shape[-1]
        heads = self.num_attention_heads_per_partition
        head_dim = qkv_C // heads // 3
        out_C = qkv_C // 3

        # One flag decides both the shift here and its reversal further down.
        shifted = self.shift_size[-1] > 0
        if shifted:
            shifts, dims = self._shift_spec()
            shifted_qkv = torch.roll(qkv, shifts=tuple(-s for s in shifts), dims=dims)
        else:
            shifted_qkv = qkv
            mask = None

        qkv_windows = window_partition(shifted_qkv, self.total_window_size)
        B = qkv_windows.shape[0]
        n_dil = self._dilation_volume()
        # Split every window into n_dil dilated sub-windows that become extra batch rows.
        qkv_windows = window_partition(qkv_windows, self.dilated_size).reshape(
            B, -1, n_dil, qkv_C).permute(0, 2, 1, 3).reshape(B * n_dil, -1, qkv_C)

        B_, N, _ = qkv_windows.shape
        # Heads are split from the *windowed* tensor, so each row of q/k/v is one
        # token of one shifted, dilated window.
        # [3, B_, heads, N, head_dim]
        qkv = qkv_windows.reshape(B_, N, 3, heads, head_dim).permute(2, 0, 3, 1, 4)

        q, k, v = qkv.unbind(0)

        q = self.position_enc(q.reshape(-1, *self.window_size, head_dim)).reshape(B_, heads, -1, head_dim)
        k = self.position_enc(k.reshape(-1, *self.window_size, head_dim)).reshape(B_, heads, -1, head_dim)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # [B_, heads, N, N]

        if mask is not None:
            # mask is [nW, N, N]; broadcast it over the batch and head dims.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, heads, N, N)
        attn = self.softmax(attn)

        attn_windows = (attn @ v).transpose(1, 2).reshape(B_, N, out_C)

        # Undo the dilated split, then the window partition.
        attn_windows = attn_windows.reshape(B, -1, N, out_C).permute(0, 2, 1, 3).reshape(-1, n_dil, out_C)
        if len(self.window_size) == 3:
            attn_windows = window_reverse(attn_windows, self.dilated_size, *self.total_window_size)
        elif len(self.window_size) == 2:
            attn_windows = window_reverse(attn_windows, self.dilated_size, 1, *self.total_window_size)
        shifted_x = window_reverse(attn_windows, self.total_window_size, T, H, W)

        x = torch.roll(shifted_x, shifts=shifts, dims=dims) if shifted else shifted_x

        x = x.reshape(Bs, -1, x.shape[-1])
        x, _ = self.proj(x)
        x = x.reshape(*grid_shape, -1)
        return x


def get_activation_fn(activation):
    """Return the functional activation named by ``activation``.

    Args:
        activation: ``"swish"`` (returns ``F.silu``) or ``"gelu"`` (returns ``F.gelu``).

    Returns:
        The corresponding ``torch.nn.functional`` callable.

    Raises:
        NotImplementedError: for any other name; the message includes the value.
    """
    if activation == "swish":
        return F.silu
    if activation == "gelu":
        return F.gelu
    raise NotImplementedError(f"Unsupported activation: {activation!r}")

class Rentention_attn(nn.Module):
    """Shifted / dilated window retention attention (RetNet-style).

    The input is partitioned into windows of ``window_size * dilated_size``
    tokens per axis, and each window is regrouped into attention groups of
    ``prod(window_size)`` tokens. Scores ``q @ k^T`` are multiplied (not
    softmaxed) by a per-head exponential decay of the Manhattan distance
    between token positions, normalised, applied to ``v``, layer-normed, gated
    by ``swish(g)`` and projected back to ``dim`` channels.

    Args:
        dim (int): number of input channels.
        window_size (list[int]): window extent, length 2 (H, W) or 3 (T, H, W).
        num_heads (int): number of heads; each head has its own decay rate.
        factor (int): expansion factor of the value / gate channels.
        qkv_bias (bool): not used (the q/k/v/g projections always have a bias).
        attn_drop (float): not used.
        proj_drop (float): dropout ratio after the output projection.
        shift_size (list[int]): cyclic shift applied before partitioning.
        dilated_size (list[int]): dilation per axis; only its last
            ``len(window_size)`` entries are used.
    """

    def __init__(self, dim, window_size, num_heads, factor=1, qkv_bias=True, attn_drop=0., proj_drop=0., shift_size=[0, 0, 0], dilated_size=[1,1,1]) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.factor = factor

        # Only the trailing ``len(window_size)`` dilation entries are used.
        self.dilated_size = dilated_size[-len(window_size):]
        self.window_size = window_size
        self.shift_size = shift_size
        # Each partition spans window_size * dilation tokens per axis. Use the
        # trimmed dilation (the one forward()/create_mask() split with) so each
        # attention group holds exactly prod(window_size) tokens.
        self.total_window_size = [w * d for w, d in zip(window_size, self.dilated_size)]

        # Per-head log decay rate: log(1 - 2 ** (-5 - head_index)).
        decay = torch.log(1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float)))
        self.register_buffer("decay", decay)

        # NOTE(review): ``rope_quad`` is never used in forward() (``position_enc``
        # below is); kept so the module's attributes stay unchanged.
        if len(self.window_size) == 2:
            self.rope_quad = rope2(self.window_size, head_dim)
        elif len(self.window_size) == 3:
            self.rope_quad = rope3(self.window_size, head_dim)

        self.q_proj = nn.Linear(dim, dim, bias=True)
        self.k_proj = nn.Linear(dim, dim, bias=True)
        self.v_proj = nn.Linear(dim, dim * self.factor, bias=True)
        self.g_proj = nn.Linear(dim, dim * self.factor, bias=True)

        self.gate_fn = get_activation_fn(activation=str("swish"))
        self.group_norm = nn.LayerNorm(self.factor * head_dim, eps=1e-6, elementwise_affine=False)

        self.proj = nn.Linear(self.factor * dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if len(window_size) == 2:
            self.position_enc = rope2(window_size, head_dim)
        elif len(window_size) == 3:
            self.position_enc = rope3(window_size, head_dim)

    def create_mask(self, x):
        """Build the multiplicative region mask used for shifted windows.

        Every position of a zero image shaped like ``x`` is labelled with the
        region it belongs to after the cyclic shift. The labels are grouped
        exactly like the tokens in forward(), and pairs of tokens get ``1.0``
        when their labels match and ``0.0`` otherwise.

        Args:
            x: ``[B, H, W, C]`` (2-D windows) or ``[B, T, H, W, C]`` (3-D
                windows); only its shape and device are read.

        Returns:
            Tensor ``[num_groups, N, N]`` of 1/0 values, ``N = prod(window_size)``.
        """
        if len(self.window_size) == 3:
            _, T, H, W, _ = x.shape
            img_mask = torch.zeros((1, T, H, W, 1), device=x.device)  # [1, T, H, W, 1]
            t_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            h_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -self.shift_size[1]),
                        slice(-self.shift_size[1], None))
            # Along W the middle slice is empty and the last one spans the
            # whole axis, so no region boundary is introduced in W
            # (presumably because W is periodic -- confirm).
            w_slices = (slice(0, -self.window_size[2]),
                        slice(-self.window_size[2], 0),
                        slice(0, None))
            cnt = 0
            for t in t_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, t, h, w, :] = cnt
                        cnt += 1
        elif len(self.window_size) == 2:
            _, H, W, _ = x.shape
            img_mask = torch.zeros((1, H, W, 1), device=x.device)  # [1, H, W, 1]
            h_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            # Same W treatment as the 3-D case: no boundary along W.
            w_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], 0),
                        slice(0, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

        # Group the labels exactly like forward() groups the tokens.
        mask_windows = window_partition(img_mask, self.total_window_size)
        mask_windows = mask_windows.reshape(-1, *self.total_window_size, 1)
        B_ = mask_windows.shape[0]
        if len(self.dilated_size) == 3:
            mask_windows = window_partition(mask_windows, self.dilated_size).reshape(B_, -1, 
                                        self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], 1).permute(
                                        0, 2, 1, 3).reshape(B_*self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], -1)
        elif len(self.dilated_size) == 2:
            mask_windows = window_partition(mask_windows, self.dilated_size).reshape(B_, -1, 
                                        self.dilated_size[0]*self.dilated_size[1], 1).permute(
                                        0, 2, 1, 3).reshape(B_*self.dilated_size[0]*self.dilated_size[1], -1)

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, N, N]
        # Same region -> 1 (keep), different region -> 0 (drop); multiplicative mask.
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(0.)).masked_fill(attn_mask == 0, float(1.0))
        return attn_mask

    def _decay_mask(self):
        """Per-head retention decay over one attention group.

        Entry ``[h, i, j]`` is ``prod_axis exp(decay[h] * |pos_i - pos_j|_axis)``,
        i.e. an exponential decay of the Manhattan distance. Tokens are numbered
        row-major over ``window_size`` (first axis slowest), which is the order
        forward() assumes when it reshapes q/k to ``[-1, *window_size, hd]``.

        Returns:
            Tensor ``[num_heads, prod(window_size), prod(window_size)]``.
        """
        decay_mask = None
        for size in self.window_size:
            index = torch.arange(size).to(self.decay)
            # [nH, size, size] decay along this single axis.
            axis_mask = torch.exp(torch.abs(index[:, None] - index[None, :]) * self.decay[:, None, None])
            if decay_mask is None:
                decay_mask = axis_mask
            else:
                n = decay_mask.shape[-1]
                # Outer product: previous axes are major, this axis is minor.
                decay_mask = (decay_mask[:, :, None, :, None] * axis_mask[:, None, :, None, :]).reshape(-1, n * size, n * size)
        return decay_mask

    def forward(self, x):
        """Apply shifted / dilated window retention.

        Args:
            x: ``[B, H, W, C]`` for 2-D windows or ``[B, T, H, W, C]`` for 3-D
                windows.

        Returns:
            Tensor with the same leading dimensions as ``x`` and ``dim`` channels.
        """
        T = 1
        if len(self.window_size) == 2:
            _, H, W, C = x.shape
        elif len(self.window_size) == 3:
            _, T, H, W, C = x.shape

        if (self.shift_size[-1] == 0) or (self.total_window_size[-1] == W):
            mask = None
        else:
            mask = self.create_mask(x)

        # One flag drives both the roll here and the roll back at the end, so
        # the output is always realigned with the input.
        shifted = self.shift_size[-1] > 0
        if shifted:
            if len(self.window_size) == 3:
                shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3))
            elif len(self.window_size) == 2:
                shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
        else:
            shifted_x = x
            mask = None

        # Partition into total_window_size windows, then split each window into
        # dilated_size sub-blocks and regroup so every group holds one token per
        # sub-block: [B_, N, C] with N = prod(window_size) (assumes
        # window_partition emits windows row-major -- TODO confirm).
        x_windows = window_partition(shifted_x, self.total_window_size)
        x_windows = x_windows.reshape(-1, *self.total_window_size, C)
        B = x_windows.shape[0]
        if len(self.dilated_size) == 3:
            x_windows = window_partition(x_windows, self.dilated_size).reshape(B, -1, 
                                        self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], C).permute(
                                        0, 2, 1, 3).reshape(B*self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], -1, C)
        elif len(self.dilated_size) == 2:
            x_windows = window_partition(x_windows, self.dilated_size).reshape(B, -1, 
                                        self.dilated_size[0]*self.dilated_size[1], C).permute(
                                        0, 2, 1, 3).reshape(B*self.dilated_size[0]*self.dilated_size[1], -1, C)
        B_, N, C = x_windows.shape
        head_dim = C // self.num_heads

        # [B_, nH, N, hd] for q/k; [B_, nH, N, factor*hd] for v/g.
        q = self.q_proj(x_windows).reshape(B_, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(x_windows).reshape(B_, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        v = self.v_proj(x_windows).reshape(B_, N, self.num_heads, self.factor * C // self.num_heads).permute(0, 2, 1, 3)
        g = self.g_proj(x_windows).reshape(B_, N, self.num_heads, self.factor * C // self.num_heads).permute(0, 2, 1, 3)

        k = k * self.scale

        # Rotary encoding over the window grid.
        q = self.position_enc(q.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
        k = self.position_enc(k.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)

        attn = q @ k.transpose(-2, -1)  # [B_, nH, N, N]
        decay_mask = self._decay_mask()  # [nH, N, N]

        if mask is not None:
            nW = mask.shape[0]
            # [nW, 1, N, N] * [1, nH, N, N] -> [nW, nH, N, N]
            mask = mask.unsqueeze(1) * decay_mask.unsqueeze(0)
            mask = torch.nan_to_num(mask)
            mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) * mask
            attn = attn / attn.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)
            attn = attn.view(-1, self.num_heads, N, N)
        else:
            # NOTE(review): unlike the masked branch, the decay is not divided
            # by the sqrt of its row sum here -- confirm this is intended.
            attn = attn * decay_mask
            attn = attn / attn.detach().sum(dim=-1, keepdim=True).abs().clamp(min=1)

        attn_windows = attn @ v
        attn_windows = self.group_norm(attn_windows)
        attn_windows = self.gate_fn(g) * attn_windows
        attn_windows = attn_windows.transpose(1, 2).reshape(B_, N, self.factor * C)

        # Undo the dilation regrouping, then the window partition.
        if len(self.window_size) == 3:
            attn_windows = attn_windows.reshape(B, -1, N, self.factor * C).permute(0, 2, 1, 3).reshape(
                                            -1, self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], self.factor * C)
            attn_windows = window_reverse(attn_windows, self.dilated_size, *self.total_window_size)
        elif len(self.window_size) == 2:
            attn_windows = attn_windows.reshape(B, -1, N, self.factor * C).permute(0, 2, 1, 3).reshape(
                                            -1, self.dilated_size[0]*self.dilated_size[1], self.factor * C)
            attn_windows = window_reverse(attn_windows, self.dilated_size, 1, *self.total_window_size)

        shifted_x = window_reverse(attn_windows, self.total_window_size, T, H, W)

        if shifted:
            if len(self.window_size) == 3:
                x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3))
            elif len(self.window_size) == 2:
                x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))
        else:
            x = shifted_x

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MHA_withflash(nn.Module):
    """Multi-head self-attention with rotary encoding, computed by FlashAttention.

    The whole input is treated as one window: its spatial shape must equal
    ``window_size``, because both the rotary encoding and the output reshape
    use ``window_size``.

    Args:
        dim (int): number of input channels.
        window_size (list[int]): spatial extent, length 2 (H, W) or 3 (T, H, W).
        num_heads (int): number of attention heads.
        qkv_bias (bool): bias on the fused qkv projection.
        attn_drop (float): attention dropout, applied only in training mode.
        proj_drop (float): dropout after the output projection.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.window_size = window_size

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Dropout probability handed to flash_attn_func (not an nn.Dropout).
        self.attn_drop = torch.tensor(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        if len(window_size) == 2:
            self.position_enc = rope2(window_size, head_dim)
        elif len(window_size) == 3:
            self.position_enc = rope3(window_size, head_dim)

    def forward(self, x):
        """Attend over all tokens of ``x``.

        Args:
            x: ``[B, H, W, C]`` (2-D) or ``[B, T, H, W, C]`` (3-D) with spatial
                extent equal to ``window_size``.

        Returns:
            Tensor ``[B, *window_size, C]``.
        """
        T = 1
        if len(self.window_size) == 2:
            B, H, W, C = x.shape
        elif len(self.window_size) == 3:
            B, T, H, W, C = x.shape

        x_windows = x.reshape(B, T * H * W, C)
        B_, N, C = x_windows.shape
        head_dim = C // self.num_heads

        # [3, B_, nH, N, hd]
        qkv = self.qkv(x_windows).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Rotary encoding over the window grid, then FlashAttention layout [B_, N, nH, hd].
        q = self.position_enc(q.reshape(B * self.num_heads, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim).permute(0, 2, 1, 3)
        k = self.position_enc(k.reshape(B * self.num_heads, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim).permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # FlashAttention applies dropout whenever dropout_p > 0, so disable it in eval mode.
        dropout_p = float(self.attn_drop) if self.training else 0.0
        x = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.scale)
        x = x.reshape(B, *self.window_size, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Swin_Attention_v2(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
        shift_size (list[int]): cyclic shift (H, W) applied before window partitioning.
        use_qknorm (bool): cosine attention with a learnable per-head logit scale
            instead of dot-product attention scaled by ``head_dim ** -0.5``.
        use_flash (bool): use ``flash_attn_func`` on paths without a shift mask.
        posembed_type (str): ``"swinv2"`` (continuous relative position bias),
            ``"rotaty"`` (rotary encoding) or any other value for none.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0], shift_size=[0, 0, 0], use_qknorm=True, 
                 use_flash=False, posembed_type="swinv2"):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads
        self.shift_size = shift_size
        self.use_qknorm = use_qknorm

        self.use_flash = use_flash
        self.posembed_type = posembed_type
        head_dim = dim // num_heads
        if self.use_qknorm:
            # Cosine attention: learnable per-head logit scale; separate q/v
            # biases (the key bias is always zero, see forward).
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
            self.qkv = nn.Linear(dim, dim * 3, bias=False)
            if qkv_bias:
                self.q_bias = nn.Parameter(torch.zeros(dim))
                self.v_bias = nn.Parameter(torch.zeros(dim))
            else:
                self.q_bias = None
                self.v_bias = None
        else:
            # Plain scaled dot-product attention.
            self.logit_scale = torch.tensor(head_dim ** -0.5)
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        if posembed_type == "swinv2":
            # mlp to generate continuous relative position bias
            self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                        nn.ReLU(inplace=True),
                                        nn.Linear(512, num_heads, bias=False))

            # get relative_coords_table
            relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
            relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
            relative_coords_table = torch.stack(
                torch.meshgrid([relative_coords_h,
                                relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
            if pretrained_window_size[0] > 0:
                relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
            else:
                relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
            relative_coords_table *= 8  # normalize to -8, 8
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                torch.abs(relative_coords_table) + 1.0) / np.log2(8)

            self.register_buffer("relative_coords_table", relative_coords_table)

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)
        elif posembed_type == "rotaty":
            self.position_enc = rope2(window_size, head_dim)

        # Dropout probability handed to flash_attn_func.
        self.attn_drop_rate = torch.tensor(attn_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal qkv weight with zero qkv bias; zero projection weight."""
        trunc_normal_(self.qkv.weight, std=.02)
        if self.qkv.bias is not None:
            nn.init.constant_(self.qkv.bias, 0)

        # NOTE(review): the projection *weight* is zeroed while its *bias* is
        # drawn from a truncated normal -- confirm the roles are not swapped.
        trunc_normal_(self.proj.bias, std=.02)
        nn.init.constant_(self.proj.weight, 0)

    def create_mask(self, x):
        """Additive SW-MSA mask: 0 within a region, ``-inf`` across regions.

        Args:
            x: ``[B, H, W, C]``; only its shape and device are read. H and W are
                expected to be multiples of ``window_size``.

        Returns:
            Tensor ``[nW, Wh*Ww, Wh*Ww]``.
        """
        _, H, W, _ = x.shape
        img_mask = torch.zeros((1, H, W, 1), device=x.device)  # [1, H, W, 1]
        h_slices = (slice(0, -self.window_size[0]),
                    slice(-self.window_size[0], -self.shift_size[0]),
                    slice(-self.shift_size[0], None))
        # Along W the middle slice is empty and the last one spans the whole
        # axis, so no region boundary is introduced in W (presumably W is
        # periodic -- confirm).
        w_slices = (slice(0, -self.window_size[1]),
                    slice(-self.window_size[1], 0),
                    slice(0, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1])  # [nW, Mh*Mw]

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, Mh*Mw, Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -torch.inf).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def _softmax_attention(self, attn, v, mask):
        """Softmax attention from raw scores.

        Args:
            attn: scores ``[B_, nH, N, N]`` (already scaled / biased).
            v: values ``[B_, nH, N, hd]``.
            mask: additive mask ``[nW, N, N]`` or ``None``.

        Returns:
            Tensor ``[B_, N, nH, hd]``.
        """
        if mask is not None:
            B_, nH, N, _ = attn.shape
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, nH, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, nH, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        return (attn @ v).transpose(1, 2)

    def _flash_attention(self, q, k, v):
        """FlashAttention on ``[B_, nH, N, hd]`` inputs; returns ``[B_, N, nH, hd]``.

        ``q`` is already scaled, hence ``softmax_scale=1.0``. FlashAttention only
        accepts fp16/bf16, so fp32 inputs are computed in bf16 and cast back.
        Dropout is applied only in training mode.
        """
        origin_type = q.dtype
        if origin_type == torch.float32:
            q = q.to(torch.bfloat16)
            k = k.to(torch.bfloat16)
            v = v.to(torch.bfloat16)
        dropout_p = float(self.attn_drop_rate) if self.training else 0.0
        x = flash_attn_func(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), dropout_p, 1.0)
        return x.to(origin_type)

    def forward(self, x, x_mask=None):
        """
        Args:
            x: input features with shape ``[B, H, W, C]``.
            x_mask: unused; kept for interface compatibility.

        Returns:
            Tensor ``[B, H, W, C]``.
        """
        _, H, W, C = x.shape

        if (self.shift_size[-1] == 0) or (self.window_size[-1] == W):
            mask = None
        else:
            mask = self.create_mask(x).to(x)

        # One flag drives both the roll here and the roll back at the end.
        shifted = self.shift_size[-1] > 0 and (self.window_size[-1] != W)
        if shifted:
            x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))

        x = window_partition(x, self.window_size)
        x = x.reshape(-1, self.window_size[0] * self.window_size[1], C)
        B_, N, C = x.shape

        if self.use_qknorm:
            # cosine attention
            qkv_bias = None
            if self.q_bias is not None:
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # [B_, nH, N, hd]
            logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale)).exp()  # [nH, 1, 1]
            q = F.normalize(q, dim=-1) * logit_scale
            k = F.normalize(k, dim=-1)
        else:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # [B_, nH, N, hd]
            q = q * self.logit_scale.to(q)

        if self.posembed_type == "rotaty":
            q = self.position_enc(q.reshape(-1, *self.window_size, C // self.num_heads)).reshape(B_, self.num_heads, -1, C // self.num_heads)
            k = self.position_enc(k.reshape(-1, *self.window_size, C // self.num_heads)).reshape(B_, self.num_heads, -1, C // self.num_heads)
            if mask is not None:
                x = self._softmax_attention(q @ k.transpose(-2, -1), v, mask)
            elif self.use_flash:
                x = self._flash_attention(q, k, v)
            else:
                x = self._softmax_attention(q @ k.transpose(-2, -1), v, None)

        elif self.posembed_type == "swinv2":
            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)

            attn = q @ k.transpose(-2, -1)
            attn = attn + relative_position_bias.to(attn).unsqueeze(0)
            x = self._softmax_attention(attn, v, mask)

        else:
            # NOTE(review): the shifted-window mask computed above is not
            # applied on this path -- confirm this is intended.
            if self.use_flash:
                x = self._flash_attention(q, k, v)
            else:
                x = self._softmax_attention(q @ k.transpose(-2, -1), v, None)

        x = x.reshape(-1, self.window_size[0] * self.window_size[1], C)
        x = window_reverse(x, self.window_size, 1, H, W)

        if shifted:
            x = torch.roll(x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Vitattn_withflash(nn.Module):
    """Global multi-head self-attention over an ``H x W`` grid using FlashAttention.

    Args:
        dim (int): number of input channels.
        num_heads (int): number of attention heads.
        qkv_bias (bool): bias on the fused qkv projection.
        attn_drop (float): attention dropout, applied only in training mode.
        proj_drop (float): dropout after the output projection.
        qk_norm (bool): normalise q and k per head with ``norm_layer``.
        norm_layer: normalisation layer class used when ``qk_norm`` is set.
    """

    def __init__(self, dim, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., qk_norm=True, norm_layer=nn.LayerNorm) -> None:
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        # Dropout probability handed to flash_attn_func (not an nn.Dropout).
        self.attn_drop = torch.tensor(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over all ``H * W`` tokens.

        Args:
            x: ``[B, H, W, C]`` tensor.

        Returns:
            Tensor ``[B, H, W, C]``.
        """
        B, H, W, C = x.shape

        x_windows = x.reshape(B, H * W, C)
        B_, N, C = x_windows.shape

        # [3, B_, N, nH, hd]: already FlashAttention's (batch, seqlen, nheads, headdim) layout.
        qkv = self.qkv(x_windows).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # FlashAttention applies dropout whenever dropout_p > 0, so disable it in eval mode.
        dropout_p = float(self.attn_drop) if self.training else 0.0
        x = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.scale)
        x = x.reshape(B, H, W, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x




class Airsea_CrossAttention(nn.Module):
    """Cross-attention from a query grid ``x1`` to a key/value grid ``x2``.

    Args:
        dim (int): channels of both inputs.
        window_size1 (list[int]): (H1, W1) of ``x1``; used by the rotary encoding.
        window_size2 (list[int]): (H2, W2) of ``x2``; used by the rotary encoding.
        num_heads (int): number of attention heads.
        qkv_bias (bool): bias on the q and v projections (k never has one).
        attn_drop (float): attention dropout (training mode only on the flash path).
        proj_drop (float): dropout after the output projection.
        pretrained_window_size (list[int]): stored, not otherwise used here.
        use_qknorm (bool): cosine attention with a learnable per-head logit scale
            instead of dot-product attention scaled by ``head_dim ** -0.5``.
        use_flash (bool): compute attention with ``flash_attn_func``.
        posembed_type (str): ``"rotaty"`` for rotary encoding, anything else for none.
    """

    def __init__(self, dim, window_size1, window_size2, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0], use_qknorm=True, 
                 use_flash=False, posembed_type="swinv2"):

        super().__init__()
        self.dim = dim
        self.window_size1 = window_size1  # Wh, Ww
        self.window_size2 = window_size2
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads
        self.use_qknorm = use_qknorm

        self.use_flash = use_flash
        self.posembed_type = posembed_type
        head_dim = dim // num_heads
        if self.use_qknorm:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
        else:
            self.logit_scale = torch.tensor(head_dim ** -0.5)

        if posembed_type == "rotaty":
            self.position_enc1 = rope2(window_size1, head_dim)
            self.position_enc2 = rope2(window_size2, head_dim)

        self.q = nn.Linear(dim, dim, bias=True if qkv_bias else False)
        self.k = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, dim, bias=True if qkv_bias else False)

        # Dropout probability handed to flash_attn_func.
        self.attn_drop_rate = torch.tensor(attn_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal q/k/v weights, zero biases, zero output projection."""
        trunc_normal_(self.q.weight, std=.02)
        trunc_normal_(self.k.weight, std=.02)
        trunc_normal_(self.v.weight, std=.02)
        if self.q.bias is not None:
            nn.init.constant_(self.q.bias, 0)
        if self.k.bias is not None:
            nn.init.constant_(self.k.bias, 0)
        if self.v.bias is not None:
            nn.init.constant_(self.v.bias, 0)

        nn.init.constant_(self.proj.bias, 0)
        nn.init.constant_(self.proj.weight, 0)

    def _flash_attention(self, q, k, v):
        """FlashAttention on ``[B_, nH, N, hd]`` inputs; returns ``[B_, N1, nH, hd]``.

        ``q`` is already scaled, hence ``softmax_scale=1.0``. FlashAttention only
        accepts fp16/bf16, so fp32 inputs are computed in bf16 and cast back.
        Dropout is applied only in training mode.
        """
        origin_type = q.dtype
        if origin_type == torch.float32:
            q = q.to(torch.bfloat16)
            k = k.to(torch.bfloat16)
            v = v.to(torch.bfloat16)
        dropout_p = float(self.attn_drop_rate) if self.training else 0.0
        x = flash_attn_func(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), dropout_p, 1.0)
        return x.to(origin_type)

    def forward(self, x1, x2):
        """
        Args:
            x1: query features ``[B, H1, W1, C]``.
            x2: key/value features ``[B, H2, W2, C]`` (same batch and channels).

        Returns:
            Tensor ``[B, H1, W1, C]``.
        """
        _, H1, W1, C = x1.shape
        _, H2, W2, C = x2.shape

        x1 = x1.reshape(-1, H1 * W1, C)
        B_, N1, C = x1.shape
        x2 = x2.reshape(-1, H2 * W2, C)
        B_, N2, C = x2.shape

        # [B_, nH, N, hd]
        q = self.q(x1).reshape(B_, N1, self.num_heads, -1).permute(0, 2, 1, 3)
        k = self.k(x2).reshape(B_, N2, self.num_heads, -1).permute(0, 2, 1, 3)
        v = self.v(x2).reshape(B_, N2, self.num_heads, -1).permute(0, 2, 1, 3)

        if self.use_qknorm:
            # cosine attention
            logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale)).exp()  # [nH, 1, 1]
            q = F.normalize(q, dim=-1) * logit_scale
            k = F.normalize(k, dim=-1)
        else:
            q = q * self.logit_scale.to(q)

        if self.posembed_type == "rotaty":
            q = self.position_enc1(q.reshape(-1, *self.window_size1, C // self.num_heads)).reshape(B_, self.num_heads, -1, C // self.num_heads)
            k = self.position_enc2(k.reshape(-1, *self.window_size2, C // self.num_heads)).reshape(B_, self.num_heads, -1, C // self.num_heads)

        if self.use_flash:
            x = self._flash_attention(q, k, v)
        else:
            attn = q @ k.transpose(-2, -1)  # [B_, nH, N1, N2]
            attn = self.softmax(attn)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2)  # [B_, N1, nH, hd]

        x = x.reshape(-1, H1, W1, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Lora_Attention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module whose projections are the
    project's ``Linear`` layers (which take LoRA parameters ``r``, ``lora_alpha``, ...).
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
            With ``use_qknorm`` only q and v get a bias (the k bias is fixed at zero).
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
            Only used to normalise the swinv2 relative-coordinate table. Default: [0, 0]
        shift_size (list[int]): ``shift_size[0]`` / ``shift_size[1]`` are the roll offsets along
            H / W; ``shift_size[-1] > 0`` enables the shifted-window path. Default: [0, 0, 0]
        use_qknorm (bool, optional): If True, cosine attention with a learnable, clamped per-head
            logit scale; otherwise dot-product scaled by ``head_dim ** -0.5``. Default: True
        use_flash (bool, optional): Use ``flash_attn_func`` in the "rotaty" branch (when no shift
            mask is needed) and in the no-position-encoding branch. Default: False
        posembed_type (str, optional): "swinv2" (continuous relative position bias MLP), "rotaty"
            (rotary encoding via ``rope2``) or any other value for no positional term. Default: "swinv2"
        r, lora_alpha, lora_dropout, fan_in_fan_out, merge_weights: forwarded unchanged to every
            ``Linear`` layer of this module.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0], shift_size=[0, 0, 0], use_qknorm=True,
                 use_flash=False, posembed_type="swinv2", r=0, lora_alpha=1, lora_dropout=0,
                 fan_in_fan_out=False, merge_weights=True):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads
        self.shift_size = shift_size
        self.use_qknorm = use_qknorm

        self.use_flash = use_flash
        self.posembed_type = posembed_type
        head_dim = dim // num_heads
        if self.use_qknorm:
            # Learnable per-head temperature for cosine attention, initialised to log(10).
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
            self.qkv = Linear(dim, dim * 3, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=False)
            if qkv_bias:
                # Separate q / v biases (k gets none); frozen when LoRA rank r > 0.
                self.q_bias = nn.Parameter(torch.zeros(dim))
                self.v_bias = nn.Parameter(torch.zeros(dim))
                if r > 0:
                    self.q_bias.requires_grad = False
                    self.v_bias.requires_grad = False
            else:
                self.q_bias = None
                self.v_bias = None
        else:
            # Plain tensor attribute (not a parameter/buffer): the standard 1/sqrt(head_dim) scale.
            self.logit_scale = torch.tensor(head_dim ** -0.5)
            self.qkv = Linear(dim, dim * 3, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=qkv_bias)

        if posembed_type == "swinv2":
            # mlp to generate continuous relative position bias
            self.cpb_mlp = nn.Sequential(Linear(2, 512, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=True),
                                        nn.ReLU(inplace=True),
                                        Linear(512, num_heads, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=False))

            # get relative_coords_table
            relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
            relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
            relative_coords_table = torch.stack(
                torch.meshgrid([relative_coords_h,
                                relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
            if pretrained_window_size[0] > 0:
                relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
            else:
                relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
            relative_coords_table *= 8  # normalize to -8, 8
            # Log-spaced coordinates: sign(x) * log2(|x| + 1) / log2(8).
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                torch.abs(relative_coords_table) + 1.0) / np.log2(8)

            self.register_buffer("relative_coords_table", relative_coords_table)

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)
        elif posembed_type == "rotaty":
            self.position_enc = rope2(window_size, head_dim)

        # Kept for backward compatibility; the flash path now reads the rate via
        # _flash_dropout_p() so that it is zero in eval mode.
        self.attn_drop_rate = torch.tensor(attn_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = Linear(dim, dim, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def lora(self, mode=True):
        """Call ``lora(mode)`` on every direct child module that defines it."""
        for child in self.children():
            if hasattr(child, "lora"):
                child.lora(mode)

    def create_mask(self, x, mask=None):
        """Build the additive (0 / -inf) attention mask for shifted windows.

        Args:
            x: feature map with H and W on dims 1 and 2; only its shape and device are used.
            mask: unused; kept for signature compatibility.

        Returns:
            Tensor of shape (nW, Mh*Mw, Mh*Mw): 0 where query and key carry the same region
            label, -inf otherwise.
        """
        # No padding is done here: H and W are presumably multiples of window_size
        # (window_partition needs that). img_mask uses the same channel-last layout as the
        # feature map so it can go through window_partition.
        _, H, W, _ = x.shape
        img_mask = torch.zeros((1, H, W, 1), device=x.device)  # [1, H, W, 1]
        h_slices = (slice(0, -self.window_size[0]),
                    slice(-self.window_size[0], -self.shift_size[0]),
                    slice(-self.shift_size[0], None))
        # NOTE(review): unlike standard Swin, these W slices never split the wrapped strip:
        # slice(-Mw, 0) is empty and slice(0, None) overwrites every column, so labels vary
        # only along H and no masking happens across the W roll seam. Presumably intentional
        # for a periodic W axis — confirm.
        w_slices = (slice(0, -self.window_size[1]),
                    slice(-self.window_size[1], 0),
                    slice(0, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1])  # [nW, Mh*Mw]

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -torch.inf).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def _flash_dropout_p(self):
        """Attention-dropout probability to hand to ``flash_attn_func``.

        ``flash_attn_func`` applies dropout whenever ``dropout_p > 0``; it does not look at
        ``self.training``, so the probability has to be zeroed explicitly in eval mode.
        """
        return self.attn_drop.p if self.training else 0.0

    def _project_qkv(self, x, B_, N):
        """Project windows ``x`` of shape (B_, N, C) to q, k, v of shape (B_, nH, N, head_dim).

        ``q`` is returned already scaled (cosine-normalised times the clamped logit scale
        with ``use_qknorm``, else multiplied by ``head_dim ** -0.5``).
        """
        if self.use_qknorm:
            qkv = self.qkv(x)
            if self.q_bias is not None:
                # k has no bias: zeros are spliced between the q and v biases.
                qkv = qkv + torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple) [B_, nH, N, C]
            # Temperature capped at 1 / 0.01 = 100.
            logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale)).exp()  # [nH, 1, 1]
            q = F.normalize(q, dim=-1) * logit_scale
            k = F.normalize(k, dim=-1)
        else:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple) [B_, nH, N, C]
            q = q * self.logit_scale.to(q)
        return q, k, v

    def forward(self, x, x_mask=None):
        """Apply (shifted-)window attention to a feature map.

        Args:
            x: input feature map of shape (B, H, W, C).
            x_mask: unused here; accepted so the signature matches ``Lora_Attention_withmask``.

        Returns:
            The attended feature map after ``proj`` / ``proj_drop``; presumably (B, H, W, C),
            as produced by ``window_reverse`` — confirm against that helper.
        """
        _, H, W, C = x.shape

        if (self.shift_size[-1] == 0) or (self.window_size[-1] == W):
            mask = None
        else:
            mask = self.create_mask(x).to(x)

        # One predicate gates both the forward and the reverse roll, so the cyclic shift is
        # undone exactly when it was applied (the reverse roll used to test shift_size[0]).
        shifted = self.shift_size[-1] > 0 and (self.window_size[-1] != W)
        if shifted:
            x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))

        # [B, H, W, C] -> [batch_size*num_windows, Mh*Mw, C]
        x = window_partition(x, self.window_size)
        x = x.reshape(-1, self.window_size[0] * self.window_size[1], C)
        B_, N, C = x.shape

        q, k, v = self._project_qkv(x, B_, N)

        if self.posembed_type == "rotaty":
            head_dim = C // self.num_heads
            q = self.position_enc(q.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
            k = self.position_enc(k.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
            if mask is not None:
                attn = q @ k.transpose(-2, -1)
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
                attn = self.softmax(attn)
                attn = self.attn_drop(attn)
                x = (attn @ v).transpose(1, 2)
            elif self.use_flash:
                # q is already scaled, hence softmax_scale=1.0.
                x = flash_attn_func(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3),
                                    self._flash_dropout_p(), 1.0)
            else:
                attn = q @ k.transpose(-2, -1)
                attn = self.softmax(attn)
                attn = self.attn_drop(attn)
                x = (attn @ v).transpose(1, 2)

        elif self.posembed_type == "swinv2":
            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)

            attn = q @ k.transpose(-2, -1)
            attn = attn + relative_position_bias.to(attn).unsqueeze(0)
            if mask is not None:
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2)

        else:
            # NOTE(review): the shift mask is not applied in this branch (the masked variant
            # was commented out upstream); tokens attend across the roll seam.
            if self.use_flash:
                x = flash_attn_func(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3),
                                    self._flash_dropout_p(), 1.0)
            else:
                attn = q @ k.transpose(-2, -1)
                attn = self.softmax(attn)
                attn = self.attn_drop(attn)
                x = (attn @ v).transpose(1, 2)

        # [B_, N, nH, head_dim] -> windows -> feature map
        x = x.reshape(-1, self.window_size[0] * self.window_size[1], C)
        x = window_reverse(x, self.window_size, 1, H, W)

        if shifted:
            x = torch.roll(x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    
def safe_softmax(input, dim=-1, epsilon=1e-9):
    """
    Softmax that returns an all-zero slice instead of NaN for degenerate inputs.

    A slice along ``dim`` whose maximum is infinite (every entry is ``-inf``, as for a
    fully masked attention row, or some entry is ``+inf``) is returned as zeros. Every
    other slice gets an ordinary max-shifted softmax divided by ``sum + epsilon``.

    Both the values and the gradients are NaN-free. The stability shift is replaced by 0
    for degenerate slices, and those slices are filled with ``-inf`` *before* ``exp``, so
    ``-inf - (-inf)`` is never formed. A NaN hidden behind ``torch.where`` would still
    poison the backward pass.

    Args:
        input (Tensor): input logits.
        dim (int or tuple[int]): dimension(s) along which softmax is computed.
        epsilon (float): A small number to add to the denominator for numerical stability.

    Returns:
        Tensor of the same shape as ``input``.
    """
    # amax accepts a single dim or several dims (Tensor.max(dim=...) only takes an int).
    max_val = input.amax(dim=dim, keepdim=True)
    is_inf = torch.isinf(max_val)
    # Finite shift for degenerate slices so the subtraction below never produces NaN.
    shift = torch.where(is_inf, torch.zeros_like(max_val), max_val)
    shifted = (input - shift).masked_fill(is_inf, float("-inf"))
    exps = torch.exp(shifted)
    sum_exps = exps.sum(dim=dim, keepdim=True) + epsilon
    return exps / sum_exps


class Lora_Attention_withmask(nn.Module):
    r""" Window based multi-head self attention (W-MSA) with an optional per-token validity
    mask and projections built from the project's ``Linear`` layers (which take LoRA
    parameters ``r``, ``lora_alpha``, ...). It supports both of shifted and non-shifted window.

    Windows without any valid token skip attention. In the "rotaty" branch, invalid tokens
    are additionally masked out of attention (unless the window covers the whole H x W map).

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
            With ``use_qknorm`` only q and v get a bias (the k bias is fixed at zero).
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
            Only used to normalise the swinv2 relative-coordinate table. Default: [0, 0]
        shift_size (list[int]): ``shift_size[0]`` / ``shift_size[1]`` are the roll offsets along
            H / W; ``shift_size[-1] > 0`` enables the shifted-window path. Default: [0, 0, 0]
        use_qknorm (bool, optional): If True, cosine attention with a learnable, clamped per-head
            logit scale; otherwise dot-product scaled by ``head_dim ** -0.5``. Default: True
        use_flash (bool, optional): Use ``flash_attn_func`` where no mask has to be applied.
            Default: False
        posembed_type (str, optional): "swinv2" (continuous relative position bias MLP), "rotaty"
            (rotary encoding via ``rope2``) or any other value for no positional term. Default: "swinv2"
        r, lora_alpha, lora_dropout, fan_in_fan_out, merge_weights: forwarded unchanged to every
            ``Linear`` layer of this module.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0], shift_size=[0, 0, 0], use_qknorm=True,
                 use_flash=False, posembed_type="swinv2", r=0, lora_alpha=1, lora_dropout=0,
                 fan_in_fan_out=False, merge_weights=True):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads
        self.shift_size = shift_size
        self.use_qknorm = use_qknorm

        self.use_flash = use_flash
        self.posembed_type = posembed_type
        head_dim = dim // num_heads
        if self.use_qknorm:
            # Learnable per-head temperature for cosine attention, initialised to log(10).
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
            self.qkv = Linear(dim, dim * 3, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=False)
            if qkv_bias:
                # Separate q / v biases (k gets none); frozen when LoRA rank r > 0.
                self.q_bias = nn.Parameter(torch.zeros(dim))
                self.v_bias = nn.Parameter(torch.zeros(dim))
                if r > 0:
                    self.q_bias.requires_grad = False
                    self.v_bias.requires_grad = False
            else:
                self.q_bias = None
                self.v_bias = None
        else:
            # Plain tensor attribute (not a parameter/buffer): the standard 1/sqrt(head_dim) scale.
            self.logit_scale = torch.tensor(head_dim ** -0.5)
            self.qkv = Linear(dim, dim * 3, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=qkv_bias)

        if posembed_type == "swinv2":
            # mlp to generate continuous relative position bias
            self.cpb_mlp = nn.Sequential(Linear(2, 512, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=True),
                                        nn.ReLU(inplace=True),
                                        Linear(512, num_heads, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=False))

            # get relative_coords_table
            relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
            relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
            relative_coords_table = torch.stack(
                torch.meshgrid([relative_coords_h,
                                relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
            if pretrained_window_size[0] > 0:
                relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
            else:
                relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
            relative_coords_table *= 8  # normalize to -8, 8
            # Log-spaced coordinates: sign(x) * log2(|x| + 1) / log2(8).
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                torch.abs(relative_coords_table) + 1.0) / np.log2(8)

            self.register_buffer("relative_coords_table", relative_coords_table)

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)
        elif posembed_type == "rotaty":
            self.position_enc = rope2(window_size, head_dim)

        # Kept for backward compatibility; the flash path now reads the rate via
        # _flash_dropout_p() so that it is zero in eval mode.
        self.attn_drop_rate = torch.tensor(attn_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = Linear(dim, dim, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights)
        self.proj_drop = nn.Dropout(proj_drop)

    def lora(self, mode=True):
        """Call ``lora(mode)`` on every direct child module that defines it."""
        for child in self.children():
            if hasattr(child, "lora"):
                child.lora(mode)

    def create_mask(self, x):
        """Build the additive (0 / -inf) attention mask for shifted windows.

        Args:
            x: feature map with H and W on dims 1 and 2; only its shape and device are used.

        Returns:
            Tensor of shape (nW, Mh*Mw, Mh*Mw): 0 where query and key carry the same region
            label, -inf otherwise.
        """
        # No padding is done here: H and W are presumably multiples of window_size
        # (window_partition needs that). img_mask uses the same channel-last layout as the
        # feature map so it can go through window_partition.
        _, H, W, _ = x.shape
        img_mask = torch.zeros((1, H, W, 1), device=x.device)  # [1, H, W, 1]
        h_slices = (slice(0, -self.window_size[0]),
                    slice(-self.window_size[0], -self.shift_size[0]),
                    slice(-self.shift_size[0], None))
        # NOTE(review): unlike standard Swin, these W slices never split the wrapped strip:
        # slice(-Mw, 0) is empty and slice(0, None) overwrites every column, so labels vary
        # only along H and no masking happens across the W roll seam. Presumably intentional
        # for a periodic W axis — confirm.
        w_slices = (slice(0, -self.window_size[1]),
                    slice(-self.window_size[1], 0),
                    slice(0, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
        mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1])  # [nW, Mh*Mw]

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
        # [nW, Mh*Mw, Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -torch.inf).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def _flash_dropout_p(self):
        """Attention-dropout probability to hand to ``flash_attn_func``.

        ``flash_attn_func`` applies dropout whenever ``dropout_p > 0``; it does not look at
        ``self.training``, so the probability has to be zeroed explicitly in eval mode.
        """
        return self.attn_drop.p if self.training else 0.0

    def _project_qkv(self, x, B_, N):
        """Project windows ``x`` of shape (B_, N, C) to q, k, v of shape (B_, nH, N, head_dim).

        ``q`` is returned already scaled (cosine-normalised times the clamped logit scale
        with ``use_qknorm``, else multiplied by ``head_dim ** -0.5``).
        """
        if self.use_qknorm:
            qkv = self.qkv(x)
            if self.q_bias is not None:
                # k has no bias: zeros are spliced between the q and v biases.
                qkv = qkv + torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple) [B_, nH, N, C]
            # Temperature capped at 1 / 0.01 = 100.
            logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale)).exp()  # [nH, 1, 1]
            q = F.normalize(q, dim=-1) * logit_scale
            k = F.normalize(k, dim=-1)
        else:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple) [B_, nH, N, C]
            q = q * self.logit_scale.to(q)
        return q, k, v

    def forward(self, x, x_mask=None):
        """Apply (shifted-)window attention, optionally restricted to valid tokens.

        Args:
            x: input feature map of shape (B, H, W, C).
            x_mask: optional validity mask of shape (B, H, W); entries equal to 0 mark
                invalid tokens. Windows with no valid token skip attention and their input
                features go straight to the output projection.

        Returns:
            The attended feature map after ``proj`` / ``proj_drop``; presumably (B, H, W, C),
            as produced by ``window_reverse`` — confirm against that helper.
        """
        _, H, W, C = x.shape
        win_tokens = self.window_size[0] * self.window_size[1]

        if (self.shift_size[-1] == 0) or (self.window_size[-1] == W):
            mask = None
        else:
            mask = self.create_mask(x).to(x)

        # One predicate gates both the forward and the reverse roll, so the cyclic shift is
        # undone exactly when it was applied (the reverse roll used to test shift_size[0]).
        shifted = self.shift_size[-1] > 0 and (self.window_size[-1] != W)
        if shifted:
            x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
            if x_mask is not None:
                x_mask = torch.roll(x_mask, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))

        x = window_partition(x, self.window_size)
        x = x.reshape(-1, win_tokens, C)  # [B_, N, C]

        if x_mask is not None:
            x_mask = window_partition(x_mask.unsqueeze(-1), self.window_size)
            x_mask = x_mask.reshape(-1, win_tokens)  # [B_, N]

            # Keep only windows containing at least one valid token; the dropped windows keep
            # their input rows in pre_x, which is written back after attention.
            pre_x = x.clone()
            if mask is not None:
                # NOTE(review): the shift mask is filtered using the first nW windows only
                # (the first batch element, assuming window_partition is batch-major), so this
                # presumes all batch elements share one window-validity pattern — confirm.
                # NOTE(review): if no window survives, nW becomes 0 and B_ // nW below fails.
                mask = mask[x_mask[:mask.shape[0]].any(dim=-1).nonzero(as_tuple=True)].reshape(-1, win_tokens, win_tokens)
            x_mask_mask = x_mask.any(dim=-1).nonzero(as_tuple=True)
            x_mask = x_mask[x_mask_mask].reshape(-1, win_tokens)
            x = pre_x[x_mask_mask].reshape(-1, win_tokens, C)

        B_, N, C = x.shape

        q, k, v = self._project_qkv(x, B_, N)

        if x_mask is not None:
            # [B_, N] -> [B_ * nH, N]: the same token mask repeated for every head.
            x_mask = x_mask.unsqueeze(-1).unsqueeze(-1).expand(B_, N, self.num_heads, 1).permute(0, 2, 1, 3).reshape(-1, N)
        # Token masking is skipped when a single window covers the whole map. Compare by
        # value: window_size may be a tuple, and a tuple never equals a list.
        apply_token_mask = x_mask is not None and tuple(self.window_size) != (H, W)

        if self.posembed_type == "rotaty":
            head_dim = C // self.num_heads
            q = self.position_enc(q.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
            k = self.position_enc(k.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
            if mask is not None:
                attn = q @ k.transpose(-2, -1)
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                if apply_token_mask:
                    # -inf on invalid query rows and key columns; safe_softmax turns fully
                    # masked rows into zeros rather than NaN.
                    attn = attn.masked_fill(x_mask.reshape(B_ // nW, nW, self.num_heads, N, 1) == 0, -torch.inf)
                    attn = attn.masked_fill(x_mask.reshape(B_ // nW, nW, self.num_heads, 1, N) == 0, -torch.inf)
                attn = attn.view(-1, self.num_heads, N, N)
                attn = safe_softmax(attn)
                attn = self.attn_drop(attn)
                x = (attn @ v).transpose(1, 2)
            elif self.use_flash and not apply_token_mask:
                # flash_attn_func takes no token mask, so it is only used when none is needed;
                # q is already scaled, hence softmax_scale=1.0.
                x = flash_attn_func(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3),
                                    self._flash_dropout_p(), 1.0)
            else:
                attn = q @ k.transpose(-2, -1)
                if apply_token_mask:
                    attn = attn.masked_fill(x_mask.reshape(B_, self.num_heads, N, 1) == 0, -torch.inf)
                    attn = attn.masked_fill(x_mask.reshape(B_, self.num_heads, 1, N) == 0, -torch.inf)
                attn = safe_softmax(attn)
                attn = self.attn_drop(attn)
                x = (attn @ v).transpose(1, 2)

        elif self.posembed_type == "swinv2":
            # NOTE(review): x_mask only drops empty windows here; invalid tokens inside a
            # kept window are not masked out of attention.
            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)

            attn = q @ k.transpose(-2, -1)
            attn = attn + relative_position_bias.to(attn).unsqueeze(0)
            if mask is not None:
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
            attn = safe_softmax(attn)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2)

        else:
            # NOTE(review): neither the shift mask nor the token mask is applied in this
            # branch (the masked variant was commented out upstream).
            if self.use_flash:
                x = flash_attn_func(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3),
                                    self._flash_dropout_p(), 1.0)
            else:
                attn = q @ k.transpose(-2, -1)
                attn = safe_softmax(attn)
                attn = self.attn_drop(attn)
                x = (attn @ v).transpose(1, 2)

        x = x.reshape(-1, win_tokens, C)

        if x_mask is not None:
            # Scatter the attended windows back among the skipped (all-invalid) ones.
            pre_x[x_mask_mask] = x
            x = pre_x

        x = window_reverse(x, self.window_size, 1, H, W)

        if shifted:
            x = torch.roll(x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


    
class Lora_Attention3d(nn.Module):
    r""" Window based 3D multi-head self attention (W-MSA) with LoRA-capable projections.

    Supports shifted and non-shifted windows, cosine attention with a learnable
    per-head logit scale (``use_qknorm``), rotary position encoding
    (``posembed_type="rotaty"``), optional ``flash_attn_func`` and an optional
    per-token validity mask passed to ``forward`` as ``x_mask``.

    Args:
        dim (int): Number of input channels.
        window_size (list[int]): Window size along (D, H, W).
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query and value
            (with ``use_qknorm`` the key gets a zero bias). Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (list[int]): Stored on the module; not used here.
        shift_size (list[int]): Cyclic shift along (D, H, W). Shifting (roll + mask)
            is active when ``shift_size[-1] > 0`` and the window does not span W.
        use_qknorm (bool): Cosine attention (L2-normalised q/k) with learnable scale.
        use_flash (bool): Use ``flash_attn_func`` on the rotary path when no shift
            mask is required.
        posembed_type (str): ``"rotaty"`` enables rotary encoding; any other value
            uses no positional encoding.
        r, lora_alpha, lora_dropout, fan_in_fan_out, merge_weights: Forwarded to
            the project ``Linear`` layers used for qkv and proj.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0], shift_size=[0, 0, 0], use_qknorm=True,
                 use_flash=False, posembed_type="rotaty", r=0, lora_alpha=1, lora_dropout=0,
                 fan_in_fan_out=False, merge_weights=True):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wd, Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads
        self.shift_size = shift_size
        self.use_qknorm = use_qknorm

        self.use_flash = use_flash
        self.posembed_type = posembed_type
        head_dim = dim // num_heads
        if self.use_qknorm:
            # Learnable per-head temperature for cosine attention (stored in log space).
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
            self.qkv = Linear(dim, dim * 3, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=False)
            if qkv_bias:
                self.q_bias = nn.Parameter(torch.zeros(dim))
                self.v_bias = nn.Parameter(torch.zeros(dim))
                if r > 0:
                    # With LoRA enabled the biases are frozen.
                    self.q_bias.requires_grad = False
                    self.v_bias.requires_grad = False
            else:
                self.q_bias = None
                self.v_bias = None
        else:
            # Plain scaled dot-product attention.
            self.logit_scale = torch.tensor(head_dim ** -0.5)
            self.qkv = Linear(dim, dim * 3, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=qkv_bias)
        if self.posembed_type == "rotaty":
            self.position_enc = rope3_maskflatten(window_size, head_dim)

        self.attn_drop_rate = torch.tensor(attn_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = Linear(dim, dim, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights)
        self.proj_drop = nn.Dropout(proj_drop)

    def lora(self, mode=True):
        """Forward ``lora(mode)`` to every direct child that implements it."""
        for child in self.children():
            if hasattr(child, "lora"):
                child.lora(mode)

    def create_mask(self, x):
        """Build the additive (0 / -inf) attention mask for shifted windows.

        The input is unpacked as (_, D, H, W, _). Region labels are painted along
        D and H using the usual three shifted-window slices. The W slices
        collapse: the second one is empty and the third one covers the whole W
        axis, so no region boundary is drawn along W.
        NOTE(review): presumably because W is treated as periodic — confirm.

        Returns:
            Tensor of shape (nW, Md*Mh*Mw, Md*Mh*Mw). The entry is 0 where the
            two tokens carry the same region label and -inf otherwise.
        """
        _, D, H, W, _ = x.shape
        img_mask = torch.zeros((1, D, H, W, 1), device=x.device)
        d_slices = (slice(0, -self.window_size[0]),
                    slice(-self.window_size[0], -self.shift_size[0]),
                    slice(-self.shift_size[0], None))
        h_slices = (slice(0, -self.window_size[1]),
                    slice(-self.window_size[1], -self.shift_size[1]),
                    slice(-self.shift_size[1], None))
        w_slices = (slice(0, -self.window_size[2]),
                    slice(-self.window_size[2], 0),
                    slice(0, None))
        cnt = 0
        for d in d_slices:
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, d, h, w, :] = cnt
                    cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2])  # [nW, Md*Mh*Mw]

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, Md*Mh*Mw, Md*Mh*Mw]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -torch.inf).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def _dense_attention(self, q, k, v, mask=None, token_mask=None):
        """Explicit softmax attention with optional shift mask and token mask.

        Args:
            q, k, v: Tensors unpacked as (B_, nH, N, head_dim). q is already scaled.
            mask: Optional additive (nW, N, N) shift mask from ``create_mask``.
            token_mask: Optional validity mask reshapeable to (B_, nH, N). Query
                rows and key columns whose entry is 0 are filled with -inf.

        Returns:
            Tensor of shape (B_, N, nH, head_dim).
        """
        B_, num_heads, N, _ = q.shape
        attn = q @ k.transpose(-2, -1)
        if token_mask is not None:
            attn = attn.masked_fill(token_mask.reshape(B_, num_heads, N, 1) == 0, -torch.inf)
            attn = attn.masked_fill(token_mask.reshape(B_, num_heads, 1, N) == 0, -torch.inf)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, num_heads, N, N)
        attn = safe_softmax(attn)
        attn = self.attn_drop(attn)
        return (attn @ v).transpose(1, 2)

    def forward(self, x, x_mask=None):
        """
        Args:
            x: Input features, unpacked as (B, D, H, W, C).
            x_mask: Optional token-validity mask of shape (B, D, H, W). It is
                padded, rolled and window-partitioned together with ``x``.
                Entries equal to 0 are excluded from dense attention.

        Returns:
            Tensor in the layout produced by ``window_reverse``, with the padding
            along D removed and ``proj`` / ``proj_drop`` applied.
        """
        B, D, H, W, C = x.shape

        if D % self.window_size[0] != 0 and (not (self.window_size == [D, H, W])):
            # Zero-pad D up to a multiple of window_size[0]; removed again at the end.
            padding = self.window_size[0] - (D % self.window_size[0])
            pad_tensor = torch.zeros((B, padding, H, W, C), dtype=x.dtype, device=x.device)
            x = torch.cat([x, pad_tensor], dim=1)
            if x_mask is not None:
                mask_pad_tensor = torch.zeros((B, padding, H, W), dtype=x_mask.dtype, device=x_mask.device)
                x_mask = torch.cat([x_mask, mask_pad_tensor], dim=1)
            origin_D = D
            D += padding
        else:
            pad_tensor = None

        # Decide once whether the shifted-window path is active. The forward roll,
        # the shift mask and the reverse roll must agree; previously the reverse
        # roll tested shift_size[0] while the others tested shift_size[-1].
        shifted = self.shift_size[-1] > 0 and self.window_size[-1] != W
        mask = self.create_mask(x).to(x) if shifted else None
        if shifted:
            shifts = (-self.shift_size[0], -self.shift_size[1], -self.shift_size[2])
            x = torch.roll(x, shifts=shifts, dims=(1, 2, 3))
            if x_mask is not None:
                x_mask = torch.roll(x_mask, shifts=shifts, dims=(1, 2, 3))

        window_numel = self.window_size[0] * self.window_size[1] * self.window_size[2]
        x = window_partition(x, self.window_size)
        x = x.reshape(-1, window_numel, C)
        if x_mask is not None:
            x_mask = window_partition(x_mask.unsqueeze(-1), self.window_size)
            x_mask = x_mask.reshape(-1, window_numel)

        B_ = x.shape[0]
        # Single window covering the whole (padded) volume.
        full_window = self.window_size == [D, H, W]
        if full_window and x_mask is not None:
            # Drop invalid tokens entirely; results are scattered back below.
            # NOTE(review): the reshape assumes every window keeps the same number
            # of valid tokens — confirm with callers.
            pre_x = x.clone()
            x_mask_indices = x_mask.nonzero(as_tuple=True)
            x = pre_x[x_mask_indices]
            x_mask = x_mask[x_mask_indices].reshape(pre_x.shape[0], -1)

        B_, N, C = x.reshape(B_, -1, C).shape
        head_dim = C // self.num_heads

        if self.use_qknorm:
            # Cosine attention.
            qkv = self.qkv(x)
            if self.q_bias is not None:
                # The key gets a zero bias.
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
                qkv = qkv + qkv_bias
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # [B_, nH, N, head_dim]
            logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale)).exp()  # [nH, 1, 1]
            q = F.normalize(q, dim=-1) * logit_scale
            k = F.normalize(k, dim=-1)
        else:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # [B_, nH, N, head_dim]
            q = q * self.logit_scale.to(q)

        if x_mask is not None:
            # Broadcast the per-token mask to every head: (B_ * nH, N).
            x_mask = x_mask.unsqueeze(-1).unsqueeze(-1).expand(B_, N, self.num_heads, 1).permute(0, 2, 1, 3).reshape(-1, N)

        # In the full-window case invalid tokens were already removed, so the
        # token mask only feeds the rotary encoder; otherwise it masks attention.
        token_mask = None if full_window else x_mask

        if self.posembed_type == "rotaty":
            rope_mask = x_mask if full_window else None
            q = self.position_enc(q.reshape(-1, N, head_dim), rope_mask).reshape(B_, self.num_heads, -1, head_dim)
            k = self.position_enc(k.reshape(-1, N, head_dim), rope_mask).reshape(B_, self.num_heads, -1, head_dim)
            if mask is None and self.use_flash:
                # NOTE(review): the flash path does not apply token_mask.
                # Attention dropout is only active in training mode.
                dropout_p = float(self.attn_drop_rate) if self.training else 0.0
                x = flash_attn_func(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), dropout_p, 1.0)
            else:
                x = self._dense_attention(q, k, v, mask, token_mask)
        else:
            # The shift mask is applied here too (it used to be ignored).
            x = self._dense_attention(q, k, v, mask, token_mask)

        x = x.reshape(-1, N, C)
        if full_window and x_mask is not None:
            pre_x[x_mask_indices] = x.reshape(-1, C)
            x = pre_x
        x = window_reverse(x, self.window_size, D, H, W)

        if shifted:
            x = torch.roll(x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3))

        if pad_tensor is not None:
            x = x[:, :origin_D, :, :, :]
        x = self.proj(x)
        x = self.proj_drop(x)
        return x



class flash_SD_attn(nn.Module):
    r"""Shifted + dilated window multi-head self attention with LoRA-capable projections.

    ``forward`` handles inputs unpacked as (batch, H, W, C).

    Tokens are first grouped into windows of ``total_window_size``
    (= ``window_size * dilated_size`` per axis). Each such window is then cut
    into ``dilated_size`` blocks. The i-th token of every block is gathered
    into one attention group, so each group has ``prod(window_size)`` tokens.
    NOTE(review): this relies on window_partition returning row-major blocks,
    which makes each group a strided (dilated) sample — confirm.

    Args:
        dim (int): Number of input channels.
        window_size (list[int]): Attention-group size per axis.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): Learnable q/v bias (or full qkv bias without
            ``use_qknorm``). Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (list[int]): Used to normalise the swinv2
            coordinate table and passed to rope2 as ``origin_shape``.
        shift_size (list[int]): Cyclic shift. Shifting (roll + mask) is active
            when ``shift_size[-1] > 0`` and the window does not span W.
        use_qknorm (bool): Cosine attention with a learnable per-head scale.
        use_flash (bool): Use ``flash_attn_func`` when no shift mask is required
            (not on the swinv2 path).
        posembed_type (str): ``"swinv2"`` (continuous relative position bias),
            ``"rotaty"`` (rotary encoding) or anything else (none).
        dilated_size (list[int]): Dilation per axis.
        r, lora_alpha, lora_dropout, fan_in_fan_out, merge_weights: Forwarded to
            the project ``Linear`` layers.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0], shift_size=[0, 0, 0], use_qknorm=True,
                 use_flash=False, posembed_type="swinv2", dilated_size=[1,1,1], r=0, lora_alpha=1, lora_dropout=0,
                 fan_in_fan_out=False, merge_weights=True):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads
        self.shift_size = shift_size
        self.use_qknorm = use_qknorm
        self.dilated_size = dilated_size

        # Spatial extent covered by one dilated window.
        self.total_window_size = [window_size[i] * dilated_size[i] for i in range(len(window_size))]

        self.use_flash = use_flash
        self.posembed_type = posembed_type
        head_dim = dim // num_heads
        if self.use_qknorm:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
            self.qkv = Linear(dim, dim * 3, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=False)
            if qkv_bias:
                self.q_bias = nn.Parameter(torch.zeros(dim))
                self.v_bias = nn.Parameter(torch.zeros(dim))
            else:
                self.q_bias = None
                self.v_bias = None
        else:
            self.logit_scale = torch.tensor(head_dim ** -0.5)
            self.qkv = Linear(dim, dim * 3, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=qkv_bias)

        if posembed_type == "swinv2":
            # MLP that generates a continuous relative position bias.
            self.cpb_mlp = nn.Sequential(Linear(2, 512, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=True),
                                        nn.ReLU(inplace=True),
                                        Linear(512, num_heads, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights, bias=False))

            # Log-spaced relative coordinate table.
            relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
            relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
            relative_coords_table = torch.stack(
                torch.meshgrid([relative_coords_h,
                                relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
            if pretrained_window_size[0] > 0:
                relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
            else:
                relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
            relative_coords_table *= 8  # normalize to -8, 8
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                torch.abs(relative_coords_table) + 1.0) / np.log2(8)

            self.register_buffer("relative_coords_table", relative_coords_table)

            # Pair-wise relative position index for each token inside the window.
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)
        elif posembed_type == "rotaty":
            self.position_enc = rope2(window_size, head_dim, origin_shape=pretrained_window_size)

        self.attn_drop_rate = torch.tensor(attn_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = Linear(dim, dim, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, fan_in_fan_out=fan_in_fan_out, merge_weights=merge_weights)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def lora(self, mode=True):
        """Forward ``lora(mode)`` to every direct child that implements it."""
        for child in self.children():
            if hasattr(child, "lora"):
                child.lora(mode)

    def create_mask(self, x):
        """Build the additive (0 / -inf) mask for shifted, dilated windows.

        Region labels are painted along the leading spatial axes. The W slices
        collapse (the second is empty and the third covers all of W), so no
        boundary is drawn along W. The label map is then partitioned exactly
        like the features in ``forward``.
        NOTE(review): the region slices use ``window_size`` while partitioning
        uses ``total_window_size``; these only coincide when dilated_size is
        all ones — confirm intent.
        NOTE(review): the dilation branch keys on ``len(self.dilated_size)``, not
        on ``len(self.window_size)``.

        Returns:
            Tensor of shape (nGroups, N, N): 0 where the tokens share a region,
            -inf otherwise.
        """
        if len(self.window_size) == 3:
            _, T, H, W, _ = x.shape
            img_mask = torch.zeros((1, T, H, W, 1), device=x.device)
            t_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            h_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -self.shift_size[1]),
                        slice(-self.shift_size[1], None))
            w_slices = (slice(0, -self.window_size[2]),
                        slice(-self.window_size[2], 0),
                        slice(0, None))
            cnt = 0
            for t in t_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, t, h, w, :] = cnt
                        cnt += 1
        elif len(self.window_size) == 2:
            _, H, W, _ = x.shape
            img_mask = torch.zeros((1, H, W, 1), device=x.device)
            h_slices = (slice(0, -self.window_size[0]),
                        slice(-self.window_size[0], -self.shift_size[0]),
                        slice(-self.shift_size[0], None))
            w_slices = (slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], 0),
                        slice(0, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

        mask_windows = window_partition(img_mask, self.total_window_size)
        mask_windows = mask_windows.reshape(-1, *self.total_window_size, 1)
        B_ = mask_windows.shape[0]
        if len(self.dilated_size) == 3:
            mask_windows = window_partition(mask_windows, self.dilated_size).reshape(B_, -1,
                                        self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], 1).permute(
                                        0, 2, 1, 3).reshape(B_*self.dilated_size[0]*self.dilated_size[1]*self.dilated_size[2], -1)
        elif len(self.dilated_size) == 2:
            mask_windows = window_partition(mask_windows, self.dilated_size).reshape(B_, -1,
                                        self.dilated_size[0]*self.dilated_size[1], 1).permute(
                                        0, 2, 1, 3).reshape(B_*self.dilated_size[0]*self.dilated_size[1], -1)

        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nGroups, N, N]
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -torch.inf).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def _flash_attention(self, q, k, v):
        """Run ``flash_attn_func`` on (B_, nH, N, head_dim) inputs; returns (B_, N, nH, head_dim).

        Dropout is applied only in training mode, and the softmax scale is 1.0
        because q is already scaled.
        """
        dropout_p = float(self.attn_drop_rate) if self.training else 0.0
        return flash_attn_func(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), dropout_p, 1.0)

    def _dense_attention(self, q, k, v, mask=None, bias=None):
        """Explicit softmax attention on (B_, nH, N, head_dim) inputs.

        Args:
            mask: Optional additive (nW, N, N) shift mask from ``create_mask``.
            bias: Optional (nH, N, N) relative position bias added to every group.

        Returns:
            Tensor of shape (B_, N, nH, head_dim).
        """
        B_, num_heads, N, _ = q.shape
        attn = q @ k.transpose(-2, -1)
        if bias is not None:
            attn = attn + bias.to(attn).unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        return (attn @ v).transpose(1, 2)

    def forward(self, x):
        """
        Args:
            x: Input features, unpacked as (batch, H, W, C).

        Returns:
            Features after window_reverse, the reverse roll, ``proj`` and
            ``proj_drop``.
        """
        _, H, W, C = x.shape

        # Decide once so that the mask, the forward roll and the reverse roll
        # agree (the reverse roll used to test shift_size[0]).
        shifted = self.shift_size[-1] > 0 and self.window_size[-1] != W
        mask = self.create_mask(x).to(x) if shifted else None
        if shifted:
            x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))

        dilated_numel = self.dilated_size[0] * self.dilated_size[1]
        x = window_partition(x, self.total_window_size).reshape(-1, *self.total_window_size, C)
        B = x.shape[0]
        # Gather the i-th token of every dilation block into one attention group.
        x = window_partition(x, self.dilated_size).reshape(B, -1, dilated_numel, C).permute(
            0, 2, 1, 3).reshape(B * dilated_numel, -1, C)
        B_, N, C = x.shape
        head_dim = C // self.num_heads

        if self.use_qknorm:
            # Cosine attention.
            qkv = self.qkv(x)
            if self.q_bias is not None:
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
                qkv = qkv + qkv_bias
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # [B_, nH, N, head_dim]
            logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale)).exp()  # [nH, 1, 1]
            q = F.normalize(q, dim=-1) * logit_scale
            k = F.normalize(k, dim=-1)
        else:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # [B_, nH, N, head_dim]
            q = q * self.logit_scale.to(q)

        if self.posembed_type == "rotaty":
            q = self.position_enc(q.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
            k = self.position_enc(k.reshape(-1, *self.window_size, head_dim)).reshape(B_, self.num_heads, -1, head_dim)
            if mask is None and self.use_flash:
                x = self._flash_attention(q, k, v)
            else:
                x = self._dense_attention(q, k, v, mask)
        elif self.posembed_type == "swinv2":
            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
            x = self._dense_attention(q, k, v, mask, relative_position_bias)
        else:
            # Flash cannot take the additive shift mask, so fall back to dense
            # attention whenever a mask is required (the mask used to be ignored).
            if mask is None and self.use_flash:
                x = self._flash_attention(q, k, v)
            else:
                x = self._dense_attention(q, k, v, mask)

        x = x.reshape(B_, N, C)
        # Undo the dilation grouping, then the window partition.
        x = x.reshape(B, -1, N, C).permute(0, 2, 1, 3).reshape(-1, dilated_numel, C)
        x = window_reverse(x, self.dilated_size, 1, *self.total_window_size)
        x = window_reverse(x, self.total_window_size, 1, H, W)

        if shifted:
            x = torch.roll(x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
