import torch
from torch import nn, Tensor
import torch.nn.functional as F
from ..builder import BACKBONES
from mmcv.runner import BaseModule
import torch.nn.init as init
from einops import rearrange
from timm.layers import DropPath
from typing import List, Dict, Tuple, Union, Optional, Callable
from functools import partial
from mmcv.runner import _load_checkpoint
from mmseg.utils import get_root_logger

@BACKBONES.register_module()
class FusionFormer(BaseModule):
    """FusionFormer backbone.

    Architecture: three ``Conv_Block`` stems (the first one also owns the
    image stem that lifts ``in_chans`` channels to ``embed_dim``), followed by
    one ``BasicFuionLayer`` per entry of ``depths``. Per-stage settings
    (``num_heads``, ``qk_head_dims``, ``v_head_dims``, ``segment_frequencies``,
    ``window_bases``, ``ratio_bases``) are indexed by stage.

    ``forward`` returns ``{"outs": [...], "img_size": (H, W)}``. ``outs`` holds
    the outputs of the 2nd and 3rd stems, of every transformer stage except
    the last, and finally the LayerNorm-ed last stage reshaped to
    (B, C, H', W').
    """

    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 24,
        depths: List[int] = [4, 4],
        num_heads: List[int] = [8, 16],
        segment_frequencies: List[int] = [2, 1],
        window_bases: List[int] = [7, 7],
        ratio_bases: List[int] = [2, 1],
        qk_head_dims: List[int] = [64, 64],
        v_head_dims: List[int] = [64, 64],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        ffn_bias: bool = True,
        proj_bias: bool = True,
        drop_path_rate: float = 0.0,
        drop_path_uniform: bool = False,
        init_values = None,
        act_layer: nn.Module = nn.GELU,
        ffn_layer: str = 'mlp',
        fewer_norm: bool = False,
        kernel_norm: bool = True,
        head_enhance: bool = True,
        use_level_embed: bool = False,
        attention_sum: bool = True,
        super_res: bool = False,
        convert_norm: bool = True,
        pretrained_path: str = None,
    ):
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.embed_dim = embed_dim * 2
        self.in_chans = in_chans
        self.depths = depths
        self.num_heads = num_heads
        self.super_res = super_res

        # Input stems: always three Conv_Blocks. The first one owns the image
        # stem and runs the full 3-layer schedule; the others reuse its last
        # two entries. Channels double from one stem to the next.
        self.stems = nn.ModuleList()
        base_in_chan = embed_dim
        stem_kernel_sizes = [3, 5, 3]
        expand_ratios = [1, 4, 4]
        stem_strides = [1, 2, 1]
        self.num_stems = 3

        for idx in range(self.num_stems):
            is_first = idx == 0
            out_channels = [embed_dim * 2 ** (idx + 1)] * 2
            self.stems.append(
                Conv_Block(
                    stem=is_first,
                    in_channels=base_in_chan,
                    out_channels=[base_in_chan] + out_channels if is_first else out_channels,
                    kernel_size=stem_kernel_sizes[idx],
                    strides=stem_strides if is_first else stem_strides[1:],
                    expand_ratios=expand_ratios if is_first else expand_ratios[1:],
                    activation=act_layer,
                    fewer_norm=fewer_norm,
                )
            )
            base_in_chan = out_channels[-1]

        # Feed-forward factory used inside the transformer blocks.
        if ffn_layer == "mlp":
            ffn_layer = Mlp
        elif ffn_layer == "identity":
            def f(*args, **kwargs):
                return nn.Identity()
            ffn_layer = f
        else:
            raise NotImplementedError(f"Unsupported ffn_layer: {ffn_layer!r}")

        # Stochastic depth: a constant rate, or linearly increasing over all blocks.
        if drop_path_uniform:
            dpr = [drop_path_rate] * sum(depths)
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.layers = nn.ModuleList()
        # super_res settings: input for the BasicFuionLayer from stride 1/32 to 1/64
        # (the first stage also downsamples, with stride 2).
        downsample_layer_0 = bool(super_res)
        stride_layer_0 = 2 if super_res else 1
        for idx, depth in enumerate(depths):
            is_first = idx == 0
            layer = BasicFuionLayer(
                dim=base_in_chan,
                out_dim=base_in_chan if is_first else base_in_chan * 2,
                depth=depth,
                num_heads=num_heads[idx],
                qk_head_dim=qk_head_dims[idx],
                v_head_dim=v_head_dims[idx],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                init_values=init_values,
                drop_path=dpr[sum(depths[:idx]):sum(depths[:idx + 1])],
                segment_frequency=segment_frequencies[idx],
                window_base=window_bases[idx],
                ratio_base=ratio_bases[idx],
                act_layer=act_layer,
                norm_layer=norm_layer,
                ffn_layer=ffn_layer,
                downsample_stride=stride_layer_0 if is_first else 2,
                downsampel_kernel_size=3,
                use_downsample=downsample_layer_0 if is_first else True,
                kernel_norm=kernel_norm,
                head_enhance=head_enhance,
                use_level_embed=use_level_embed,
                attention_sum=attention_sum,
            )
            base_in_chan = layer.out_dim
            self.layers.append(layer)

        self.output_dim = base_in_chan
        self.norm = norm_layer(self.output_dim)

        self.pretrained_path = pretrained_path
        self.convert_norm = convert_norm

        if self.convert_norm:
            self.convert_bn_to_syncbn()

    def convert_bn_to_syncbn(self):
        """Replace every ``nn.BatchNorm2d`` submodule with ``nn.SyncBatchNorm`` in place."""
        def _convert(module):
            for name, child in module.named_children():
                if isinstance(child, nn.BatchNorm2d):
                    setattr(module, name, nn.SyncBatchNorm.convert_sync_batchnorm(child))
                else:
                    _convert(child)  # recursively convert submodules
            return module

        _convert(self)

    def init_weights(self):
        """Initialize conv/linear/norm parameters, then optionally load a checkpoint.

        Conv2d: Kaiming-normal (fan_out, relu) with zero bias. BatchNorm /
        SyncBatchNorm / GroupNorm: weight 1, bias 0. Linear: Xavier-uniform with
        zero bias. If ``pretrained_path`` is set, the checkpoint's
        ``state_dict_ema``, ``state_dict`` or ``model`` entry (first one found,
        else the whole checkpoint) is loaded non-strictly, and missing and
        unexpected keys are logged.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm, nn.GroupNorm)):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

        if self.pretrained_path:
            logger = get_root_logger()
            checkpoint = _load_checkpoint(self.pretrained_path, logger=logger, map_location='cpu')
            state_dict = checkpoint
            for key in ('state_dict_ema', 'state_dict', 'model'):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break

            missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)

            logger.info("Missing keys:")
            for key in missing_keys:
                logger.info("  %s", key)
            logger.info("Unexpected keys:")
            for key in unexpected_keys:
                logger.info("  %s", key)

    def forward_features(self, x: torch.Tensor):
        """Run the stems and the transformer stages.

        Args:
            x: A (B, C, H, W) batch, or a list/tuple of equally sized (C, H, W)
                images, which is stacked into a batch.

        Returns:
            dict: ``"outs"`` -- list of intermediate features; the last one is the
            normalized final stage shaped (B, C, H', W'). ``"img_size"`` --
            the input (H, W).

        Raises:
            ValueError: if the last stage does not return a 4-D (B, C, H, W) map.
        """
        if isinstance(x, (list, tuple)):
            # The convolutional stems need a single tensor, so batch the images.
            x = torch.stack(list(x), dim=0)
        _, nc, h, w = x.shape
        assert nc == self.in_chans, f"Expected {self.in_chans} input channels, got {nc}"

        outs = []
        for idx, stem in enumerate(self.stems):
            x = stem(x)
            if idx > 0:
                outs.append(x)

        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx != last:
                outs.append(x)

        if x.dim() != 4:
            # The spatial size is needed to fold the normalized tokens back into a map.
            raise ValueError(
                f"Expected a 4-D (B, C, H, W) output from the last stage, got shape {tuple(x.shape)}")
        h_out, w_out = x.shape[2:]
        x_norm = self.norm(rearrange(x, 'b c h w -> b (h w) c'))
        outs.append(rearrange(x_norm, 'b (h w) c -> b c h w', h=h_out, w=w_out).contiguous())

        return {
            "outs": outs,
            "img_size": (h, w),
        }

    def forward(self, x: torch.Tensor):
        """Alias of ``forward_features``."""
        return self.forward_features(x)



class Mlp(nn.Module):
    """Two-layer feed-forward block.

    Computes ``drop(fc2(drop(act(fc1(x)))))`` over the last dimension.

    Args:
        in_features: Size of the input features.
        hidden_features: Hidden size; ``in_features`` is used when falsy.
        out_features: Output size; ``in_features`` is used when falsy.
        act_layer: Activation class, instantiated without arguments.
        drop: Dropout probability, applied after the activation and after ``fc2``.
        bias: Whether both linear layers carry a bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        produced = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, produced, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        # The same dropout module is applied twice: after the activation and at the end.
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x


class ConvLayer(nn.Module):
    """Conv2d with optional input dropout, normalization and activation.

    Pipeline: [Dropout2d if dropout > 0] -> Conv2d -> [norm] -> [activation].
    Padding is ``(kernel_size // 2) * dilation``, so an odd kernel keeps the
    spatial size at stride 1.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Odd square kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
        groups: Convolution groups.
        use_bias: Whether the convolution has a bias.
        dropout: Probability of the input ``Dropout2d``; 0 disables it.
        norm: ``"bn2d"`` (BatchNorm2d) or ``"gn"`` (GroupNorm, needs
            ``norm_groups``); any other value, including None, disables norm.
        norm_groups: Number of groups for GroupNorm.
        act_func: Activation class instantiated without arguments, or None.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        use_bias=False,
        dropout=0,
        norm="bn2d",
        norm_groups=None,
        act_func=nn.GELU,
    ):
        super(ConvLayer, self).__init__()
        assert kernel_size % 2 != 0, "Kernel size must be odd number"
        padding = (kernel_size // 2) * dilation

        self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=padding,
            dilation=(dilation, dilation),
            groups=groups,
            bias=use_bias,
        )

        if norm == 'bn2d':
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == 'gn':
            assert norm_groups is not None, "GroupNorm requires num_groups"
            self.norm = nn.GroupNorm(num_groups=norm_groups, num_channels=out_channels, affine=True)
        else:
            self.norm = None

        self.act = act_func() if act_func is not None else None
        # All parameter initialization (conv and norm) happens in one place.
        self.init_weights()

    def init_weights(self):
        """Kaiming-normal conv weight (fan_out, relu), zero conv bias; norm weight 1, bias 0."""
        init.kaiming_normal_(self.conv.weight, mode="fan_out", nonlinearity="relu")
        if self.conv.bias is not None:
            init.constant_(self.conv.bias, 0)

        if self.norm is not None:
            init.constant_(self.norm.weight, 1.0)
            init.constant_(self.norm.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.act is not None:
            x = self.act(x)
        return x


class DSConv(nn.Module):
    """Depthwise-separable convolution: KxK depthwise conv, then 1x1 pointwise conv.

    The depthwise stage uses ``groups=in_channels`` and the given activation.
    The pointwise stage has no activation. Neither stage has a bias. Both use
    BatchNorm, except that ``fewer_norm`` drops the depthwise norm (layout
    follows the SeaFormer stems).

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Depthwise kernel size.
        stride: Depthwise stride.
        act_func: Activation class for the depthwise stage.
        fewer_norm: If True, skip normalization after the depthwise conv.
        hidden_channels: Depthwise output channels (defaults to ``in_channels``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        act_func=nn.Hardswish,
        fewer_norm=False,
        hidden_channels=None,
    ):
        super().__init__()
        mid = in_channels if hidden_channels is None else hidden_channels
        depthwise_norm = None if fewer_norm else "bn2d"

        self.depth_conv = ConvLayer(
            in_channels,
            mid,
            kernel_size,
            stride,
            groups=in_channels,
            norm=depthwise_norm,
            act_func=act_func,
            use_bias=False,
        )
        self.point_conv = ConvLayer(
            mid,
            out_channels,
            1,
            norm="bn2d",
            act_func=None,
            use_bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.point_conv(self.depth_conv(x))


class MBConv(nn.Module):
    """Inverted-bottleneck body without a skip connection.

    1x1 expansion -> KxK depthwise (carries the stride) -> 1x1 projection.
    The first two stages use ``act_func``; the projection has no activation.
    No stage has a bias. Every stage has BatchNorm, except that ``fewer_norm``
    keeps only the projection norm (layout follows the SeaFormer stems).

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Depthwise kernel size.
        stride: Depthwise stride.
        mid_channels: Expanded width; defaults to ``round(in_channels * expand_ratio)``.
        expand_ratio: Expansion factor used when ``mid_channels`` is None.
        fewer_norm: If True, drop the norms of the first two stages.
        act_func: Activation class for the first two stages.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        fewer_norm=False,
        act_func=nn.Hardswish,
    ):
        super().__init__()
        if mid_channels is None:
            mid_channels = round(in_channels * expand_ratio)
        hidden_norm = None if fewer_norm else "bn2d"

        self.inverted_conv = ConvLayer(
            in_channels,
            mid_channels,
            1,
            stride=1,
            norm=hidden_norm,
            act_func=act_func,
            use_bias=False,
        )
        self.depth_conv = ConvLayer(
            mid_channels,
            mid_channels,
            kernel_size,
            stride=stride,
            groups=mid_channels,
            norm=hidden_norm,
            act_func=act_func,
            use_bias=False,
        )
        self.point_conv = ConvLayer(
            mid_channels,
            out_channels,
            1,
            norm="bn2d",
            act_func=None,
            use_bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.inverted_conv(x)
        mixed = self.depth_conv(expanded)
        return self.point_conv(mixed)


# Same settings as the InvertedResidual
class ResidualMBConv(nn.Module):
    """MBConv (or DSConv when ``expand_ratio == 1``) with an optional identity skip.

    The skip is used only when the block keeps both the channel count and the
    resolution (``in_channels == out_channels`` and ``stride == 1``).

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Depthwise kernel size.
        stride: 1 or 2.
        expand_ratio: Expansion factor; 1 selects the depthwise-separable variant.
        activation: Activation class forwarded to the inner block.
        fewer_norm: Forwarded to the inner block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        expand_ratio: int = 2,
        activation: nn.Module=nn.GELU,
        fewer_norm: bool = False,
    ):
        super().__init__()
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.out_channels = out_channels
        assert stride in [1, 2]

        self.use_residual = stride == 1 and in_channels == out_channels

        shared = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            act_func=activation,
            fewer_norm=fewer_norm,
        )
        if expand_ratio == 1:
            self.conv = DSConv(**shared)
        else:
            self.conv = MBConv(mid_channels=int(round(in_channels * expand_ratio)), **shared)

    def forward(self, x: Tensor) -> Tensor:
        y = self.conv(x)
        return x + y if self.use_residual else y


# Same settings as the StackerMV2Block
class Conv_Block(nn.Module):
    """A stack of ``ResidualMBConv`` layers, optionally preceded by an image stem.

    Args:
        stem: If True, prepend a 3x3 stride-2 ``ConvLayer`` that maps
            ``in_chans`` image channels to ``in_channels``.
        in_channels: Channels entering the first ResidualMBConv (also the stem's
            output width).
        out_channels: Per-layer output channels.
        kernel_size: Kernel size shared by all layers.
        strides: Per-layer strides.
        expand_ratios: Per-layer expansion factors.
        activation: Activation class passed to the stem and to every layer.
        fewer_norm: Forwarded to every ResidualMBConv.
        in_chans: Number of image channels consumed by the stem (default 3).

    ``out_channels``, ``strides`` and ``expand_ratios`` must have equal length.
    They are only read, never mutated, so the shared list defaults are safe.
    """

    def __init__(
        self,
        stem: bool = False,
        in_channels: int = 16,
        out_channels: List[int] = [16, 32, 32],
        kernel_size: int = 3,
        strides: List[int] = [1, 2, 1],
        expand_ratios: List[int] = [1, 4, 3],
        activation: nn.Module = nn.ReLU,
        fewer_norm: bool = False,
        in_chans: int = 3,
    ):
        super(Conv_Block, self).__init__()
        self.stem = stem
        if stem:
            # (in_chans, kernel 3, stride 2, dilation 1, groups 1)
            self.stem_block = ConvLayer(in_chans, in_channels, 3, 2, 1, 1, act_func=activation)

        self.convs = nn.ModuleList()
        assert len(out_channels) == len(strides) == len(expand_ratios), (
            "out_channels, strides and expand_ratios must have the same length")

        for out_channel, stride, expand_ratio in zip(out_channels, strides, expand_ratios):
            self.convs.append(
                ResidualMBConv(
                    in_channels=in_channels,
                    out_channels=out_channel,
                    kernel_size=kernel_size,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    activation=activation,
                    fewer_norm=fewer_norm)
                    )
            in_channels = out_channel

    def forward(self, x: Tensor) -> Tensor:
        if self.stem:
            x = self.stem_block(x)
        for layer in self.convs:
            x = layer(x)
        return x


class LayerScale(nn.Module):
    """Learnable per-channel scaling of the last dimension: ``x * gamma``.

    Args:
        dim: Size of the scaled (last) dimension.
        init_values: Initial value(s) of ``gamma``.
        inplace: If True, multiply ``x`` in place and return it.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x: Tensor) -> Tensor:
        if not self.inplace:
            return x * self.gamma
        return x.mul_(self.gamma)


def drop_path(x, drop_prob: Optional[float] = 0.0, training: bool = False):
    """Stochastic depth: zero whole samples of ``x`` with probability ``drop_prob``.

    Surviving samples are scaled by ``1 / (1 - drop_prob)``, so the expected
    value is unchanged. With ``drop_prob == 1`` every sample is zeroed.

    Args:
        x: Tensor whose first dimension is the batch.
        drop_prob: Drop probability. 0 or None disables dropping; ``DropPath``
            passes None by default.
        training: Dropping happens only when True.

    Returns:
        ``x`` itself when inactive, otherwise a new masked tensor.
    """
    if not drop_prob or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over every other dimension.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path``; it is active only in training mode.
    This definition shadows the ``DropPath`` imported from ``timm.layers`` at
    the top of the file.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        probability, active = self.drop_prob, self.training
        return drop_path(x, probability, active)


class h_sigmoid(nn.Module):
    """Hard sigmoid: ``relu6(x + 3) / 6``, a piecewise-linear sigmoid approximation.

    Args:
        inplace: Passed to the internal ``nn.ReLU6``. It operates on the
            temporary ``x + 3``, so the caller's tensor is never modified.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        return self.relu(shifted) / 6


def window_partition(x: Tensor, window_size: Tuple[int, int]) -> Tensor:
    """Split a spatial map into non-overlapping windows.

    Supported layouts:
        (B, H, W, C)            -> (B * nH * nW, wh, ww, C)
        (B, heads, H, W, d)     -> (B * nH * nW, heads, wh, ww, d)
    where ``nH = H // wh`` and ``nW = W // ww``. Windows are ordered by batch,
    then row-major over the window grid. H and W must be divisible by the
    window size, and ``x`` must be viewable (``.view`` is used).

    Args:
        x: 4-D or 5-D tensor as above.
        window_size: (wh, ww).

    Returns:
        Tensor: The stacked windows.
    """
    wh, ww = window_size[0], window_size[1]
    if x.dim() == 4:
        b, h, w, c = x.size()
        grid = x.view(b, h // wh, wh, w // ww, ww, c)
        return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, wh, ww, c)

    assert x.dim() == 5
    b, heads, h, w, d = x.size()
    grid = x.view(b, heads, h // wh, wh, w // ww, ww, d)
    # Move the window-grid axes in front of the head axis before flattening.
    return grid.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(-1, heads, wh, ww, d)


def window_reverse(x: Tensor, wl: Tuple[int, int], bsz: int, input_resolution: Tuple[int, int]):
    """Reassemble head-last windows into a full spatial map.

    Args:
        x (Tensor): Windows of shape (B_, wh, ww, num_heads, d), where
            B_ = bsz * (H // wh) * (W // ww), ordered by batch and then
            row-major over the window grid.
        wl (Tuple[int, int]): Window length (wh, ww); must match ``x``.
        bsz (int): Batch size.
        input_resolution (Tuple[int, int]): Full map size (H, W).

    Returns:
        Tensor: Map of shape (bsz, H, W, num_heads, d).
    """
    _, wh, ww, num_heads, d = x.size()
    assert wh == wl[0] and ww == wl[1], f"Input tensor {x.shape} with {wl}"
    H, W = input_resolution
    tiles = x.view(bsz, H // wl[0], W // wl[1], wh, ww, num_heads, d)
    # Interleave grid rows with in-window rows (and grid columns with in-window columns).
    merged = tiles.permute(0, 1, 3, 2, 4, 5, 6).contiguous()
    return merged.view(bsz, H, W, num_heads, d)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must be 2-D or 3-D and equal to the trailing 2 or 3 dims of
    ``x``. Every leading dimension is set to 1.

    Args:
        freqs_cis: Tensor matching the last 2 or 3 dims of ``x``.
        x: Target tensor.

    Returns:
        ``freqs_cis`` viewed with ``x.ndim`` dimensions.

    Raises:
        ValueError: if the shapes are not compatible. (Previously a mismatch
            surfaced as an UnboundLocalError.)
    """
    ndim = x.ndim
    k = freqs_cis.ndim
    if k in (2, 3) and k <= ndim and tuple(freqs_cis.shape) == tuple(x.shape[ndim - k:]):
        shape = [d if i >= ndim - k else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)
    raise ValueError(
        f"freqs_cis of shape {tuple(freqs_cis.shape)} does not match the trailing "
        f"2 or 3 dims of x with shape {tuple(x.shape)}")


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply rotary position embeddings to queries and keys.

    The last dimension of ``xq`` / ``xk`` is read as (real, imag) pairs. The
    pairs are multiplied by the complex ``freqs_cis`` (broadcast with
    ``reshape_for_broadcast``) and written back to the original layout. The
    math runs in float32; results are cast back to the input dtypes.

    Args:
        xq: Queries; the last dim must be even.
        xk: Keys with the same pair layout as ``xq``.
        freqs_cis: Complex rotations matching the trailing dims of the pair view.

    Returns:
        Tuple of rotated (xq, xk) with the input shapes and dtypes.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Merge the trailing (pairs, 2) dims back into one. flatten(-2) is correct
    # for any rank, while flatten(3) only restored the shape of 4-D inputs.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class MultiFusionLatentAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        window_lengths: List[Union[int, tuple]] = [8, 16, 32],
        shift_sizes: Optional[List[Union[int, tuple]]] = None,
        dilated_ratios: List[int] = [1, 2, 4],
        qkv_bias: bool = True,
        proj_bias: bool = True,
        proj_drop: float = 0.,
        attn_drop: float = 0.,
        kernel_norm: bool = True,
        head_enhance: bool = True,
        use_level_embed: bool = True,
        attention_sum: bool = True,
        q_lora_rank: int = 128,
        kv_lora_rank: int = 128,
        qk_head_dim: int = 64,
        v_head_dim: int = 64,
    ):
        """Multi-level window attention with low-rank (latent) q / kv projections.

        Args:
            dim: Channel size of the input tokens; ``wo`` projects back to it.
            num_heads: Number of attention heads.
            window_lengths: Per-level window size; ``forward`` iterates over it
                together with ``dilated_ratios``.
            shift_sizes: Optional per-level cyclic shifts (stored).
            dilated_ratios: Per-level pooling factor of the window tokens.
            qkv_bias: Bias for every q / kv projection layer.
            proj_bias: Bias for the output projection ``wo``.
            proj_drop: Dropout probability of ``self.proj_drop``.
            attn_drop: Dropout probability applied to the attention weights.
            kernel_norm: If True, build ``kernel_encoder`` (see below).
            head_enhance: Stored as ``self.head_enhance``; its use is not
                visible in this section.
            use_level_embed: If True, ``scattering`` adds ``self.level_embed``
                per level.
            attention_sum: If True, ``scattering`` weights levels by their lse.
            q_lora_rank: Query bottleneck rank; 0 uses a single ``wq`` layer.
            kv_lora_rank: Rank of the shared key/value bottleneck.
            qk_head_dim: Per-head query/key size; also sets the 1/sqrt scale.
            v_head_dim: Per-head value size.
        """
        super(MultiFusionLatentAttention, self).__init__()
        self.embed_dim = dim
        self.num_heads = num_heads
        
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        # Attention logits are scaled by 1 / sqrt(qk_head_dim).
        self.scale = self.qk_head_dim ** -0.5
        
        self.shift_sizes = shift_sizes
        self.window_lengths = window_lengths
        self.dilated_ratios = dilated_ratios
        
        # Adaptive Local Tokens Fusion: kernel_encoder maps the concatenated
        # q, k, v channels (num_heads * (2 * qk_head_dim + v_head_dim)) to one
        # map per head; forward passes it through self.sigmoid.
        self.kernel_norm = kernel_norm
        self.head_enhance = head_enhance
        if self.kernel_norm:
            self.kernel_encoder = ResidualMBConv(in_channels=self.num_heads * (self.qk_head_dim * 2 + self.v_head_dim), out_channels=self.num_heads, kernel_size=3, expand_ratio=1, activation=nn.ReLU)
        
        # NOTE(review): `scattering` reads `self.level_embed` when
        # use_level_embed is True and more than one level is merged, but it is
        # not created in this __init__. Confirm it is defined elsewhere;
        # otherwise that path raises AttributeError.
        self.use_level_embed = use_level_embed
        self.attention_sum = attention_sum
        # LayerNorm over the concatenated head outputs (num_heads * v_head_dim);
        # its use is not visible in this section.
        self.inner_attn_ln = nn.LayerNorm(self.v_head_dim * self.num_heads)
        
        # Attention Linear Weights
        # Query: direct projection when q_lora_rank == 0, otherwise a low-rank
        # bottleneck wq_a -> LayerNorm -> wq_b.
        if self.q_lora_rank == 0:
            self.wq = nn.Linear(self.embed_dim, self.num_heads * self.qk_head_dim, bias=qkv_bias)
        else:
            self.wq_a = nn.Linear(self.embed_dim, self.q_lora_rank, bias=qkv_bias)
            self.q_norm = nn.LayerNorm(self.q_lora_rank)
            self.wq_b = nn.Linear(self.q_lora_rank, self.num_heads * self.qk_head_dim, bias=qkv_bias)
        # Keys and values share one bottleneck wkv_a -> LayerNorm -> wkv_b.
        # forward splits its output per head into k (qk_head_dim) and v (v_head_dim).
        self.wkv_a = nn.Linear(self.embed_dim, self.kv_lora_rank, bias=qkv_bias)
        self.kv_norm = nn.LayerNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(self.kv_lora_rank, self.num_heads * (self.qk_head_dim + self.v_head_dim), bias=qkv_bias)
        self.wo = nn.Linear(self.num_heads * self.v_head_dim, self.embed_dim, bias=proj_bias)

        # dropouts inside attention operations
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        
        # Hard sigmoid applied to the kernel_encoder output in forward.
        self.sigmoid = h_sigmoid()
        
        
    def get_attn_mask(self, wl: Tuple[int, int], dr: int, shift_size: Tuple[int, int], input_res: Tuple[int]) -> torch.Tensor:
        """Additive mask blocking attention across cyclic-shift boundaries.

        The (H, W) map is cut into 3 x 3 regions by the window length and the
        shift size, and each region gets its own label. The label map is then
        split into windows. With ``dr > 1`` the labels are averaged over
        dr x dr blocks to match the pooled tokens; shift sizes must be
        multiples of ``dr``.

        Returns:
            Float tensor (num_windows, N, N), N = (wl[0] // dr) * (wl[1] // dr),
            holding 0 where two tokens share a label and -100 elsewhere.
        """
        assert shift_size[0] % dr == 0 and shift_size[1] % dr == 0, f"shift size with {shift_size} for dr {dr}"

        H, W = input_res
        labels = torch.zeros((1, H, W, 1))
        row_regions = (slice(0, -wl[0]), slice(-wl[0], -shift_size[0]), slice(-shift_size[0], None))
        col_regions = (slice(0, -wl[1]), slice(-wl[1], -shift_size[1]), slice(-shift_size[1], None))
        regions = [(rows, cols) for rows in row_regions for cols in col_regions]
        for region_id, (rows, cols) in enumerate(regions):
            labels[:, rows, cols, :] = region_id

        # (num_windows, wl[0], wl[1], 1)
        windows = window_partition(labels, wl)
        if dr > 1:
            blocks = windows.squeeze(-1).view(-1, wl[0] // dr, dr, wl[1] // dr, dr)
            blocks = blocks.permute(0, 1, 3, 2, 4).contiguous()
            # (num_windows, N, dr * dr) -> mean label per pooled token
            token_labels = blocks.flatten(1, 2).flatten(2, 3).mean(dim=-1, keepdim=False)
        else:
            token_labels = windows.view(-1, wl[0] * wl[1])

        pair_diff = token_labels.unsqueeze(1) - token_labels.unsqueeze(2)
        return pair_diff.masked_fill(pair_diff != 0, float(-100.0)).masked_fill(pair_diff == 0, float(0.0))
    
    
    def dense_window_to_sparse(self,
                               x: Tensor,
                               dr: int):
        """Average-pool every window over non-overlapping dr x dr blocks.

        Args:
            x (Tensor): q/k/v windows of shape (B_, heads, wh, ww, d); wh and
                ww must be divisible by ``dr``.
            dr (int): dilated ratio (pooling factor).
        Returns:
            Tensor: (B_, heads, (wh // dr) * (ww // dr), d), pooled tokens in
            row-major order; B_ = B * num_windows, d = head_dim.
        """
        n_windows, heads, wh, ww, d = x.shape
        blocks = x.view(n_windows, heads, wh // dr, dr, ww // dr, dr, d)
        # Put the two in-block axes next to each other so they can be merged.
        blocks = blocks.permute(0, 1, 2, 4, 3, 5, 6).contiguous()
        blocks = blocks.view(n_windows, heads, -1, dr * dr, d)
        return blocks.mean(dim=-2)
    
    
    def gathering(self,
                  x: Tensor,
                  dr: int,
                  wl: Tuple[int, int],
                  shift_size: Optional[Tuple[int, int]] = None,
                  kernel_norm: Tensor = None):
        """Shift, gate, window and (optionally) pool q/k/v for one level.

        Args:
            x (Tensor): q/k/v of shape (B, heads, H, W, d).
            dr (int): dilated ratio; values > 1 average-pool each window over
                dr x dr blocks.
            wl (Tuple[int]): window length (wh, ww).
            shift_size (tuple(int)): optional cyclic shift, applied as a
                negative roll over H and W.
            kernel_norm (Tensor or None): (B, heads, H, W) per-token weights,
                multiplied in before windowing.
        Returns:
            Tensor: (B_, heads, (wh // dr) * (ww // dr), d) with
            B_ = B * num_windows, d = head_dim.
        """
        if shift_size is None:
            source = x
        else:
            source = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(2, 3))

        if kernel_norm is not None:
            source = source * kernel_norm.unsqueeze(-1)

        # (B, heads, H, W, d) -> (B_, heads, wh, ww, d)
        windows = window_partition(source, wl)
        if dr > 1:
            return self.dense_window_to_sparse(windows, dr)
        assert dr == 1
        return windows.flatten(2, 3)
    
    
    def sparse_to_dense_window(self,
                               out: Tensor, lse: Tensor,
                               bsz: int, dr: int,
                               wl: Tuple[int, int],
                               input_res: Tuple[int],
                               kernel_norm: Optional[Tensor] = None):
        """Upsample pooled window results back to full resolution and un-window them.

        Each pooled token is repeated over its dr x dr block (nearest
        upsampling). The windows are then reassembled with ``window_reverse``.

        Args:
            out: (B_, heads, (wl[0] // dr) * (wl[1] // dr), d)
            lse: (B_, heads, (wl[0] // dr) * (wl[1] // dr))
            bsz: batch size.
            dr: dilated ratio.
            wl: window length (wh, ww).
            input_res: full map size (H, W).
            kernel_norm: Optional (bsz, heads, H, W) gate multiplied into both results.
        Returns:
            out: (bsz, H, W, heads, d)
            lse: (bsz, H, W, heads, 1)
        """
        B_, num_heads, fused_seg_len, _ = out.shape
        fused_wh, fused_ww = wl[0] // dr, wl[1] // dr
        assert fused_seg_len == fused_wh * fused_ww

        def to_window_layout(t: Tensor) -> Tensor:
            # (B_, heads, fused_wh * fused_ww, c) -> (B_, wl[0], wl[1], heads, c)
            c = t.shape[-1]
            if dr > 1:
                t = t.view(B_, num_heads, fused_wh, 1, fused_ww, 1, c)
                t = t.expand(-1, -1, -1, dr, -1, dr, -1).contiguous()
            t = t.view(B_, num_heads, wl[0], wl[1], c)
            return t.permute(0, 2, 3, 1, 4).contiguous()

        # -> (bsz, H, W, num_heads, d) and (bsz, H, W, num_heads, 1)
        out = window_reverse(to_window_layout(out), wl, bsz, input_res)
        lse = window_reverse(to_window_layout(lse.unsqueeze(-1)), wl, bsz, input_res)

        if kernel_norm is not None:
            gate = kernel_norm.permute(0, 2, 3, 1).contiguous().unsqueeze(-1)
            out = out * gate
            lse = lse * gate

        return out, lse
        

    def scattering(self,
                   outs: List[Tensor],
                   lses: List[Tensor],
                   bsz: int,
                   window_lengths: List[Tuple[int]],
                   input_res: Tuple[int],
                   kernel_norms: Optional[List[Tensor]] = None,
                   shift_sizes: Optional[List[Tuple[int]]] = None):
        """Merge per-level window attention results into one token sequence.

        For each level ``idx``, ``sparse_to_dense_window`` (with
        ``self.dilated_ratios[idx]`` and ``window_lengths[idx]``) restores the
        dense (bsz, H, W, heads, d) layout. The cyclic shift is then undone and
        the spatial dims are flattened. When there are several levels:

        * if ``attention_sum`` is set, each level is weighted by a softmax over
          levels of its lse, computed under ``no_grad``;
        * if ``use_level_embed`` is set, a per-level embedding is added;
        * the levels are then averaged.

        Args:
            outs: per-level outputs, (B_, heads, (wh // dr) * (ww // dr), d).
            lses: per-level log-sum-exp values, (B_, heads, (wh // dr) * (ww // dr)).
            bsz: batch size.
            window_lengths: per-level (wh, ww).
            input_res: (H, W).
            kernel_norms: optional per-level (bsz, heads, H, W) gates. None (the
                default) or a None entry disables gating for that level.
            shift_sizes: optional per-level shifts; None entries mean no shift.
        Returns:
            out: (bsz, H * W, heads * d)
            all_lses: list of per-level (bsz, H * W, heads, 1) tensors.
        """
        assert len(outs) == len(lses)
        all_outs, all_lses = [], []

        for idx, (o, lse) in enumerate(zip(outs, lses)):
            # Only index kernel_norms when it was provided; its default is None.
            kernel_norm = None if kernel_norms is None else kernel_norms[idx]
            # -> o: (bsz, H, W, num_heads, d), lse: (bsz, H, W, num_heads, 1)
            o, lse = self.sparse_to_dense_window(out=o, lse=lse,
                                                 bsz=bsz, dr=self.dilated_ratios[idx],
                                                 wl=window_lengths[idx],
                                                 input_res=input_res,
                                                 kernel_norm=kernel_norm)
            # reverse cyclic shift
            if shift_sizes is not None:
                shift_size = shift_sizes[idx]
                if shift_size is not None:
                    o = torch.roll(o, shifts=shift_size, dims=(1, 2))
                    lse = torch.roll(lse, shifts=shift_size, dims=(1, 2))

            # (bsz, H * W, num_heads, d) and (bsz, H * W, num_heads, 1)
            all_outs.append(o.flatten(1, 2))
            all_lses.append(lse.flatten(1, 2))

        if len(all_outs) > 1:
            if self.attention_sum:
                # Softmax over levels of each token's lse; the weights are
                # treated as constants (no gradient flows through them).
                with torch.no_grad():
                    max_lse = torch.stack(all_lses, dim=0).max(dim=0)[0]
                    exp_lses = [torch.exp(lse - max_lse) for lse in all_lses]
                    lse_sum = torch.stack(exp_lses, dim=0).sum(dim=0)
                    exp_lses = [lse / lse_sum for lse in exp_lses]

            out = 0
            for idx, o in enumerate(all_outs):
                if self.attention_sum:
                    o = o * exp_lses[idx].type_as(o)
                if self.use_level_embed:
                    # NOTE(review): `level_embed` is not created in __init__;
                    # confirm it is defined elsewhere.
                    o_level = self.level_embed[:, idx].unsqueeze(1).expand(bsz, -1, -1, -1)
                    o = o + o_level
                # NOTE(review): levels are also divided by len(outs) even though
                # the attention_sum weights already sum to 1 -- confirm intended.
                out = out + o / len(outs)
        else:
            out = all_outs[-1]

        out = out.flatten(2, 3)

        return out, all_lses
    

    def window_attention_ops(self,
                             q: Tensor, k: Tensor, v: Tensor,
                             attn_mask: Optional[Tensor] = None):
        """Scaled dot-product attention inside each window.

        Args:
            q, k: (B_, heads, N, d_qk); v: (B_, heads, N, d_v), where
                B_ = batch * num_windows with windows innermost.
            attn_mask: optional additive mask (num_windows, wh * ww, wh * ww),
                broadcast over batch and heads.
        Returns:
            out: (B_, heads, N, d_v)
            lse: (B_, heads, N), log-sum-exp of the scaled (and masked)
                attention logits, i.e. the log of each query's softmax
                normalizer.
        """
        B_, n_heads, N, _ = q.shape
        logits = torch.einsum('bhnd,bhmd->bhnm', q, k) * self.scale

        if attn_mask is not None:
            num_windows = attn_mask.shape[0]
            logits = logits.view(B_ // num_windows, num_windows, n_heads, N, N)
            logits = logits + attn_mask.unsqueeze(1).unsqueeze(0)
            logits = logits.view(B_, n_heads, N, N)

        attn_weights = F.softmax(logits, dim=-1)
        attn_probs = self.attn_drop(attn_weights)
        out = torch.einsum('bhnm,bhmd->bhnd', attn_probs, v)

        # `scattering` merges levels by a softmax over their lse, which needs the
        # log-normalizer of the logits. A logsumexp over the (dropped-out)
        # probabilities is ~log(N + 1) and barely depends on the scores.
        lse = torch.logsumexp(logits, dim=-1)

        return out, lse
    
    
    def forward(self, x: Tensor, freqs_cis: Tensor = None, input_res: Tuple[int] = None,
                return_weights: bool = False):
        """Multi-window latent attention over an (H, W) token grid.

        Pipeline:
        1. Project ``x`` to queries and low-rank key/values.
        2. Optionally build a sigmoid gating map from them (``self.kernel_norm``).
        3. Apply rotary embeddings.
        4. For each configured ``(window_length, dilated_ratio)`` pair, run
           ``window_attention_ops`` on the tokens produced by ``self.gathering``.
        5. Merge the per-window results with ``self.scattering``.
        6. Apply the optional ``self.inner_attn_ln``, then ``self.wo`` and
           ``self.proj_drop``.

        Args:
            x: unpacked as ``(bsz, seq_len, C)``. The ``view`` calls below require
                ``seq_len == H * W``.
            freqs_cis: rotary table handed to ``apply_rotary_emb``. Rotary is
                skipped when None.
            input_res: ``(H, W)`` of the token grid. Required (asserted).
            return_weights: when True, also return the ``all_lses`` value
                produced by ``self.scattering``.

        Returns:
            The projected attention output, or ``(output, all_lses)`` when
            ``return_weights`` is True.
        """
        assert input_res is not None
        bsz, seq_len, _ = x.shape
        H, W = input_res
        
        # Query projection: direct (q_lora_rank == 0) or low-rank wq_a -> q_norm -> wq_b.
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        # Joint key/value projection through the low-rank wkv_a -> kv_norm -> wkv_b path.
        kv = self.wkv_b(self.kv_norm(self.wkv_a(x)))
        
        if self.kernel_norm:
            # Treat concatenated q|kv channels as a (bsz, channels, H, W) map and
            # encode it into a sigmoid-squashed gating map.
            qkv = torch.cat((q, kv), dim=-1).transpose(1, 2).contiguous()
            qkv = qkv.view(bsz, -1, H, W)
            kernel_feat = self.kernel_encoder(qkv) # (bsz, num_heads, H, W)
            kernel_feat = self.sigmoid(kernel_feat)

        # -> (bsz, num_heads, H * W, head_dim)
        q = q.view(bsz, seq_len, self.num_heads, self.qk_head_dim).permute(0, 2, 1, 3).contiguous()
        kv = kv.view(bsz, seq_len, self.num_heads, self.qk_head_dim + self.v_head_dim).permute(0, 2, 1, 3).contiguous()
        k, v = torch.split(kv, [self.qk_head_dim, self.v_head_dim], dim=-1)
        
        if freqs_cis is not None:
            # freqs_cis: (num_heads, H * W, head_dim // 2)
            q, k = apply_rotary_emb(q, k, freqs_cis)
        
        # Restore the 2-D grid so windows can be cut out spatially.
        q = q.view(bsz, self.num_heads, H, W, self.qk_head_dim)
        k = k.view(bsz, self.num_heads, H, W, self.qk_head_dim)
        v = v.view(bsz, self.num_heads, H, W, self.v_head_dim)
        
        # Per-level results, consumed together by self.scattering.
        outs = []
        lses = []
        kernel_norms = []
        shift_sizes = []
        suited_windows = []
        
        for index, (wl, dr) in enumerate(zip(self.window_lengths, self.dilated_ratios)): # (wh, ww)
            wh, ww = wl
            # A window larger than the grid's shorter side, or one that does not
            # tile H / W evenly, is replaced by one global (H, W) window without
            # shift. The matching check at the end of the loop stops after it.
            if min(wl) > min(input_res) or H % wh != 0 or W % ww != 0:
                wl = (H, W)
                shift_size = None
            else:
                if self.shift_sizes is not None:
                    shift_size = self.shift_sizes[index] # (shift_h, shift_w)
                    # No shift once the window already spans the shorter side.
                    if min(wl) >= min(input_res):
                        shift_size = None
                else:
                    shift_size = None
                    
            suited_windows.append(wl)
            
            # The gating map and shift mask are only built for dilated levels
            # (dr > 1) with kernel_norm enabled. Every other level passes None
            # for both.
            # NOTE(review): a dr == 1 level can still carry a non-None shift_size
            # here yet gets attn_mask_i = None. Confirm that self.gathering /
            # window_attention_ops intend shifted windows without a mask.
            if self.kernel_norm and dr > 1:
                # (1, num_heads, H, W) with values of 0 or 1
                if self.head_enhance:
                    # Add a per-head 0/1 dilation-phase pattern (see get_dilated_window_index).
                    heads_dilated_index = get_dilated_window_index(window_size=input_res, num_heads=self.num_heads, ratio=dr).unsqueeze(0).to(x.device)
                    # kernel_feat: (bsz, num_heads, H, W)
                    kernel_feat_i = kernel_feat + heads_dilated_index
                else:
                    kernel_feat_i = kernel_feat
                    
                if shift_size is not None:
                    # Roll the gating map by -shift. This presumably mirrors the cyclic
                    # shift applied inside self.gathering (scattering rolls back by
                    # +shift); confirm there.
                    kernel_feat_i = torch.roll(kernel_feat_i, shifts=(-shift_size[0], -shift_size[1]), dims=(2, 3))
                    attn_mask_i = self.get_attn_mask(wl, dr, shift_size, input_res).to(x.device)
                else:
                    attn_mask_i = None
            else:
                kernel_feat_i = None
                attn_mask_i = None
            
            # gathering process
            # Presumably regroups q/k/v into the (dilated, possibly shifted) windows of
            # size wl in the (B_, heads, N, d) layout window_attention_ops expects.
            qi = self.gathering(q, dr, wl, shift_size, kernel_feat_i)
            ki = self.gathering(k, dr, wl, shift_size, kernel_feat_i)
            vi = self.gathering(v, dr, wl, shift_size, kernel_feat_i)
            
            out, lse = self.window_attention_ops(
                qi, ki, vi, attn_mask_i)
            
            outs.append(out)
            lses.append(lse)
            kernel_norms.append(kernel_feat_i)
            shift_sizes.append(shift_size)
            
            # Same condition as the fallback above: the global window already
            # covers the whole grid, so larger levels are skipped.
            if min(wl) > min(input_res) or H % wh != 0 or W % ww != 0:
                break
            
        # scattering process to the original resolution    
        all_out, all_lses = self.scattering(outs=outs, lses=lses, 
                                            bsz=bsz, window_lengths=suited_windows,
                                            input_res=input_res, kernel_norms=kernel_norms, shift_sizes=shift_sizes)
        
        # inner_LN for the attention sum 
        if self.inner_attn_ln is not None:
            all_out = self.inner_attn_ln(all_out)
        
        all_out = self.wo(all_out)
        all_out = self.proj_drop(all_out)
        
        if return_weights:
            return all_out, all_lses
        else:
            return all_out


def get_dilated_window_index(window_size: Union[int, Tuple[int, int]], num_heads: int, ratio: int) -> torch.Tensor:
    """Build per-head 0/1 masks, each selecting one phase of a ``ratio``-dilated grid.

    A grid of size (H, W) has ``ratio**2`` dilation phases (i, j). Phase (i, j)
    selects the cells ``[i::ratio, j::ratio]`` and is stored in channel
    ``m = i * ratio + j``. These ``ratio**2`` channels are tiled along the head
    axis, so head ``h`` receives phase ``h % ratio**2``.

    Args:
        window_size: ``H`` for a square grid, or an ``(H, W)`` pair (tuple or list).
        num_heads: number of heads; must be a multiple of ``ratio**2``.
        ratio: dilation ratio.

    Returns:
        ``torch.int`` tensor of shape (num_heads, H, W) holding 0/1 values.
    """
    assert num_heads % (ratio**2) == 0, f'num_heads must be divisible by ratio^2 : {num_heads % (ratio**2)}'
    if isinstance(window_size, int):
        H = W = window_size
    else:
        # Lists such as [H, W] are accepted as well as tuples.
        assert isinstance(window_size, (tuple, list)) and len(window_size) == 2
        H, W = window_size

    # Phase id of every cell: (row % ratio) * ratio + (col % ratio).
    row_phase = torch.arange(H) % ratio
    col_phase = torch.arange(W) % ratio
    cell_phase = row_phase.unsqueeze(1) * ratio + col_phase.unsqueeze(0)  # (H, W)
    phase_ids = torch.arange(ratio**2).view(-1, 1, 1)                      # (ratio**2, 1, 1)
    index_matrices = (cell_phase.unsqueeze(0) == phase_ids).to(torch.int)  # (ratio**2, H, W)

    index_matrices = index_matrices.repeat(num_heads // (ratio**2), 1, 1)
    assert index_matrices.shape[0] == num_heads
    return index_matrices


class MultiFusionBlock(nn.Module):
    """Pre-norm transformer block: multi-window attention + FFN, each as a residual.

    Works on 4-D feature maps ``(bsz, dim, H, W)``. The map is flattened to
    ``(bsz, H * W, dim)`` tokens, passed through ``attn_class`` and ``ffn_layer``
    (each wrapped with optional LayerScale and stochastic depth), and reshaped
    back to ``(bsz, dim, H, W)``.

    Args (non-obvious ones):
        window_lengths: per-level window sizes (int or (h, w) tuple). ``None``
            means ``[8, 16, 32]``.
        dilated_ratios: per-level dilation ratios. ``None`` means ``[1, 2, 4]``.
        shift_window: when True, every level is shifted by half its window size.
        init_values: LayerScale init; falsy disables LayerScale.
        drop_path: stochastic-depth rate. Above 0.1, whole samples are dropped
            via ``drop_add_residual_stochastic_depth``; otherwise ``DropPath`` is used.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        qk_head_dim: int,
        v_head_dim: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        window_lengths: Optional[List[Union[int, tuple]]] = None,
        dilated_ratios: Optional[List[int]] = None,
        shift_window: bool = False,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = MultiFusionLatentAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        kernel_norm: bool = True,
        head_enhance: bool = True,
        use_level_embed: bool = True,
        attention_sum: bool = False,
    ):
        super().__init__()
        # Fresh lists per instance instead of shared mutable defaults.
        if window_lengths is None:
            window_lengths = [8, 16, 32]
        if dilated_ratios is None:
            dilated_ratios = [1, 2, 4]

        self.dim = dim
        self.window_lengths = window_lengths
        self.dilated_ratios = dilated_ratios
        self.shift_window = shift_window

        # Half-window cyclic shift per level, or None when shifting is disabled.
        if shift_window:
            self.shift_sizes = []
            for wl in window_lengths:
                wl_h, wl_w = wl if isinstance(wl, tuple) else (wl, wl)
                self.shift_sizes.append((wl_h // 2, wl_w // 2))
        else:
            self.shift_sizes = None

        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim = dim,
            num_heads = num_heads,
            window_lengths = window_lengths,
            shift_sizes = self.shift_sizes,
            dilated_ratios = self.dilated_ratios,
            qkv_bias = qkv_bias,
            proj_bias = proj_bias,
            attn_drop = attn_drop,
            proj_drop = drop,
            kernel_norm = kernel_norm,
            head_enhance = head_enhance,
            use_level_embed = use_level_embed,
            attention_sum = attention_sum,
            q_lora_rank = dim // 4,
            kv_lora_rank = dim // 4,
            qk_head_dim = qk_head_dim,
            v_head_dim = v_head_dim,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features = dim,
            hidden_features = mlp_hidden_dim,
            out_features = dim,
            act_layer = act_layer,
            drop = drop,
            bias = ffn_bias,
        )

        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, freqs_cis: Tensor) -> Tensor:
        """Apply the block to a ``(bsz, dim, H, W)`` map; returns the same shape.

        ``freqs_cis`` is this block's rotary table and is forwarded to the attention.
        """
        assert len(x.shape) == 4, f"Input feature must be 4D tensor for fusion window process -> {x.shape}"
        _, in_dim, H, W = x.shape
        assert in_dim == self.dim
        assert freqs_cis is not None, "freqs_cis must be provided"
        input_res = (H, W)
        # (bsz, dim, H, W) -> (bsz, H * W, dim) token layout used by attn and mlp
        x = x.flatten(2, 3).transpose(1, 2)

        def attn_residual_func(x: Tensor, freqs_cis: Tensor, input_res: Tuple[int]) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), freqs_cis, input_res))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # Large drop rates: compute each branch only for the kept samples.
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=partial(attn_residual_func, freqs_cis=freqs_cis, input_res=input_res),
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, freqs_cis=freqs_cis, input_res=input_res))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, freqs_cis=freqs_cis, input_res=input_res)
            x = x + ffn_residual_func(x)

        # rearrange for the next block's input_res
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Per-sample stochastic depth: add a rescaled residual to a random batch subset.

    ``max(int(b * (1 - sample_drop_ratio)), 1)`` randomly chosen samples (rows of
    the leading axis) are fed through ``residual_func``. The result is multiplied
    by ``b / subset_size`` and added onto those rows only. The remaining samples
    pass through unchanged, and the residual is never computed for them.

    Args:
        x: 3-D tensor ``(b, n, d)``.
        residual_func: maps a ``(k, n, d)`` subset to a tensor with the same
            number of elements per sample.
        sample_drop_ratio: fraction of samples whose residual is dropped.

    Returns:
        Tensor with the shape and dtype of ``x``.
    """
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    chosen_rows = torch.randperm(batch, device=x.device)[:keep]

    branch = residual_func(x[chosen_rows]).flatten(1).to(dtype=x.dtype)
    rescale = batch / keep

    summed = torch.index_add(x.flatten(1), 0, chosen_rows, branch, alpha=rescale)
    return summed.view_as(x)


def init_random_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
    """Create per-head 2-D RoPE frequency vectors, optionally randomly rotated.

    Each head draws one angle (uniform in [0, 2*pi) when ``rotate``, else 0). Its
    x-frequencies are ``[mag * cos(a), mag * cos(a + pi/2)]`` and its y-frequencies
    are ``[mag * sin(a), mag * sin(a + pi/2)]``, where
    ``mag = 1 / theta ** (arange(0, dim, 4)[:dim // 4] / dim)``.

    Returns:
        Tensor of shape ``(2, num_heads, 2 * (dim // 4))``. Index 0 holds the
        x-axis frequencies and index 1 the y-axis ones. The last dim equals
        ``dim // 2`` when ``dim`` is a multiple of 4.
    """
    magnitudes = 1 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    rows_x, rows_y = [], []
    for _ in range(num_heads):
        # One angle draw per head keeps the RNG consumption order per head.
        if rotate:
            angle = torch.rand(1) * 2 * torch.pi
        else:
            angle = torch.zeros(1)
        quarter_turn = torch.pi/2 + angle
        rows_x.append(torch.cat([magnitudes * torch.cos(angle), magnitudes * torch.cos(quarter_turn)], dim=-1))
        rows_y.append(torch.cat([magnitudes * torch.sin(angle), magnitudes * torch.sin(quarter_turn)], dim=-1))
    return torch.stack([torch.stack(rows_x, dim=0), torch.stack(rows_y, dim=0)], dim=0)


def compute_mixed_cis(freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int):
    """Turn (learnable) 2-D RoPE frequencies into unit complex rotations per token.

    Args:
        freqs: ``(2, depth, num_heads * half)``. Index 0 holds x-axis and index 1
            y-axis frequencies for each of the ``depth`` blocks.
        t_x, t_y: ``(N,)`` column / row coordinate of every token.
        num_heads: number of heads packed into the last dim of ``freqs``.

    Returns:
        complex tensor of shape ``(depth, num_heads, N, half)`` with
        ``out[d, h, n, k] = exp(i * (t_x[n] * fx[d, h, k] + t_y[n] * fy[d, h, k]))``.
    """
    N = t_x.shape[0]
    depth = freqs.shape[1]
    # Angles grow with the coordinate range, so this must not run in float16:
    # disable CUDA autocast and compute explicitly in float32. The explicit cast
    # also lets half-precision parameters work instead of failing in the matmul.
    if hasattr(torch, "autocast"):
        autocast_off = torch.autocast(device_type="cuda", enabled=False)
    else:  # older torch without the device-generic autocast
        autocast_off = torch.cuda.amp.autocast(enabled=False)
    with autocast_off:
        fx, fy = freqs[0].float(), freqs[1].float()
        tx, ty = t_x.float(), t_y.float()
        # (N, 1) @ (depth, 1, num_heads * half) -> (depth, N, num_heads * half)
        # -> (depth, N, num_heads, half) -> permute -> (depth, num_heads, N, half)
        freqs_x = (tx.unsqueeze(-1) @ fx.unsqueeze(-2)).view(depth, N, num_heads, -1).permute(0, 2, 1, 3)
        freqs_y = (ty.unsqueeze(-1) @ fy.unsqueeze(-2)).view(depth, N, num_heads, -1).permute(0, 2, 1, 3)
        freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y)

    return freqs_cis


def init_t_xy(end_x: int, end_y: int):
    """Return float32 (column, row) coordinates of an ``end_y x end_x`` grid in row-major order.

    ``t_x`` cycles ``0 .. end_x - 1`` once per row. ``t_y`` repeats each row index
    ``end_x`` times. Both have length ``end_x * end_y``.
    """
    columns = torch.arange(end_x, dtype=torch.float32)
    rows = torch.arange(end_y, dtype=torch.float32)
    t_x = columns.repeat(end_y)
    t_y = rows.repeat_interleave(end_x)
    return t_x, t_y


class BasicFuionLayer(nn.Module):
    """One backbone stage: optional downsampling followed by ``depth`` MultiFusionBlocks.

    Window configuration: level ``i < segment_frequency`` uses a square window of
    side ``window_base * 2**i``. Its dilation ratio comes from ``[1, 2, 2, 4]``
    (``ratio_base == 1``) or ``[2, 2, 4, 4]`` (``ratio_base == 2``). Blocks alternate
    between regular and shifted windows (every second block is shifted), which is
    why ``depth`` must be even. Each block gets its own learnable 2-D RoPE
    frequencies, stored in ``self.freqs``.

    Args (non-obvious ones):
        dim: channels of the stage input.
        out_dim: channels the blocks run at. Without downsampling the blocks still
            expect ``out_dim`` channels, so ``dim`` must equal ``out_dim`` then.
        drop_path: per-block drop-path rates (length ``depth``). ``None``
            (default) means 0.0 for every block.
        downsample_stride / downsampel_kernel_size: stride / kernel of the
            ``ResidualMBConv`` downsampler (``AvgPool2d`` stride when
            ``merging_downsample``).
        use_downsample: when False no downsampling is applied (stride forced to 1).
    """
    def __init__(
        self,
        dim: int,
        out_dim: int,
        depth: int,
        num_heads: int,
        qk_head_dim: int = 64,
        v_head_dim: int = 64,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: Optional[List[float]] = None,
        segment_frequency: int = 3,
        window_base: int = 7,
        ratio_base: int = 1,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = MultiFusionLatentAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        downsample_stride: int = 2,
        downsampel_kernel_size: int = 3,
        use_downsample: bool = True,
        merging_downsample: bool = False,
        kernel_norm: bool = True,
        head_enhance: bool = True,
        use_level_embed: bool = True,
        attention_sum: bool = False,
    ):
        super().__init__()

        self.dim = dim
        self.out_dim = out_dim
        self.depth = depth
        assert depth % 2 == 0, "depth must be even for shift window operation"
        # The old default [0.0] could never satisfy len(drop_path) == depth for an
        # even depth; None now means "no drop path" for every block.
        if drop_path is None:
            drop_path = [0.0] * depth

        self.window_lengths = []
        self.dilated_ratios = []

        # downsample layer
        self.downsample_stride = downsample_stride
        self.merging_downsample = merging_downsample
        if use_downsample:
            assert downsample_stride > 1
            if self.merging_downsample:
                assert out_dim == dim, "out_dim must be equal to dim for merging downsample"
                self.downsample = nn.AvgPool2d(kernel_size=downsample_stride, stride=downsample_stride)
            else:
                self.downsample = ResidualMBConv(
                    in_channels=dim,
                    out_channels=out_dim,
                    kernel_size=downsampel_kernel_size,
                    stride=downsample_stride,
                    expand_ratio=4,
                )
        else:
            self.downsample_stride = 1
            self.downsample = None

        # prepare for the window_lengths and dilated_ratios
        assert segment_frequency <= 4, "segment frequency must be at most 4 for computational efficiency"
        assert ratio_base in [1, 2]
        window_length = [window_base * (2 ** i) for i in range(segment_frequency)]
        dilated_ratio = [1, 2, 2, 4][:segment_frequency] if ratio_base == 1 else [2, 2, 4, 4][:segment_frequency]
        for wl, dr in zip(window_length, dilated_ratio):
            self.window_lengths.append((wl, wl))
            self.dilated_ratios.append(dr)

        # prepare the 2D RoPE metric: one (2, num_heads, qk_head_dim // 2) table per block
        freqs = [
            init_random_2d_freqs(dim=qk_head_dim, num_heads=num_heads, theta=100.0, rotate=True)
            for _ in range(depth)
        ]
        # (2, depth, num_heads, head_dim // 2) -> (2, depth, num_heads * head_dim // 2)
        freqs = torch.stack(freqs, dim=1).view(2, depth, -1)
        self.freqs = nn.Parameter(freqs.clone(), requires_grad=True)
        # compute the complex cis: -> (depth, num_heads, N, head_dim // 2)
        self.compute_cis = partial(compute_mixed_cis, num_heads=num_heads)

        # building Fusion Blocks
        assert len(drop_path) == depth, "Inconsistent drop path size"
        self.blocks = nn.ModuleList([
            MultiFusionBlock(
                dim=out_dim,
                num_heads=num_heads,
                qk_head_dim=qk_head_dim,
                v_head_dim=v_head_dim,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop=drop,
                attn_drop=attn_drop,
                init_values=init_values,
                drop_path=drop_path[i],
                window_lengths=self.window_lengths,
                dilated_ratios=self.dilated_ratios,
                shift_window=(i + 1) % 2 == 0,
                act_layer=act_layer,
                norm_layer=norm_layer,
                attn_class=attn_class,
                ffn_layer=ffn_layer,
                kernel_norm=kernel_norm,
                head_enhance=head_enhance,
                use_level_embed=use_level_embed,
                attention_sum=attention_sum,
            ) for i in range(depth)
        ])

    def forward(self, x: Tensor) -> Tensor:
        """Downsample ``x`` (bsz, dim, H, W) if configured, then run all blocks."""
        assert len(x.shape) == 4, f"Input feature must be 4D tensor for fusion window process -> {x.shape}"

        _, in_dim, _, _ = x.shape
        assert in_dim == self.dim, f"Input feature wrong shape {x.shape} with {self.dim}"

        # downsample operation
        if self.downsample is not None:
            x = self.downsample(x)

        # Take the grid size from the actual downsampled map rather than
        # predicting H_in // stride: padded strided convs give ceil, not floor,
        # on odd sizes, and the RoPE table must match the real token count.
        H, W = x.shape[-2:]
        t_x, t_y = init_t_xy(end_x=W, end_y=H)
        # (depth, num_heads, H * W, head_dim // 2)
        freqs_cis = self.compute_cis(self.freqs, t_x.to(x.device), t_y.to(x.device))

        # Self-Attention Fusion Blocks, each with its own rotary table
        for blk, blk_freqs_cis in zip(self.blocks, freqs_cis):
            x = blk(x, freqs_cis=blk_freqs_cis)

        return x.contiguous()
