# =============================================================================
# Efficient PyTorch implementation of DuSA2 (Blockify Strategy 2) for ViTs; it runs even faster than the version reported in the paper.
# =============================================================================

import math
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
import torch.utils.checkpoint as checkpoint
import numpy as np
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models._features_fx import register_notrace_function

def _cfg(url='', **kwargs):
    """Build a timm-style default config dict for the models in this file.

    Args:
        url: location of pretrained weights ('' means none are available).
        **kwargs: entries that override (or extend) the defaults below.

    Returns:
        A new dict; overridden keys keep their default position.
    """
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        # The first conv of DuSATransformer is the Conv2d at index 0 of
        # `stage1_conv_embed`; this model has no `patch_embed` module.
        'first_conv': 'stage1_conv_embed.0', 'classifier': 'head',
        **kwargs
    }

# Data/eval configs keyed by input resolution; the model factories below attach
# one of them to each model as `model.default_cfg`.
default_cfgs = dict(
    cswin_224=_cfg(),
    cswin_384=_cfg(input_size=(3, 384, 384), crop_pct=1.0),
)

class Mlp(nn.Module):
    """Two-layer feed-forward net: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: input channel count.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability, used after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        width_out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class DuSA(nn.Module):
    """Dual-level self-attention (DuSA, blockify strategy 2) over a square token grid.

    ``forward`` takes ``qkv`` where ``qkv[0]``, ``qkv[1]``, ``qkv[2]`` are q, k, v,
    each of shape (B, N, C) with N a perfect square. If ``split_size ** 2 < N`` the
    grid is tiled into ``split_size x split_size`` windows and attention runs at
    two levels:

    1. inside every window, with a depth-wise conv of V (``lepe``) added;
    2. across windows, where each window's flattened (split_size**2 * head_dim)
       features form one token and the level-1 output is used as the values.

    Otherwise plain multi-head attention over all N tokens is used (plus ``lepe``).

    Args:
        dim: channel count C of q, k and v.
        split_size: window side length.
        num_heads: number of attention heads; must divide ``dim``.
        attn_drop: dropout probability on the attention weights; applied only
            in training mode.
    """

    def __init__(self, dim, split_size = 7, num_heads = 8, attn_drop = 0.):
        super().__init__()
        self.dim = dim
        self.split_size = split_size
        self.num_heads = num_heads
        # Depth-wise 3x3 conv over V; its output is the `lepe` term added to the attention output.
        self.get_v  = nn.Conv2d(dim, dim, kernel_size = 3, stride = 1, padding = 1, groups = dim)
        self.attn_drop = nn.Dropout(attn_drop)

    def get_lepe(self, x, func):
        """Split V into heads and compute the conv positional term ``lepe``.

        Args:
            x: V tokens of shape (B, N, C) on a square grid.
            func: module applied to V laid out as images (B', C, h, w); ``forward``
                passes ``self.get_v``.

        Returns:
            ``(v, lepe)``, both of shape (B', num_heads, h * w, C // num_heads).
            When the grid is tiled, B' = B * windows and h = w = split_size (the
            conv sees each window separately); otherwise B' = B and h = w = sqrt(N).
        """
        B, N, C = x.shape
        H = W = int(np.sqrt(N))
        if self.split_size**2 < N:
            H_sp, W_sp = self.split_size, self.split_size
            x = x.view(B, H // H_sp, H_sp, W // W_sp, W_sp, C)
            x = x.permute(0, 1, 3, 5, 2, 4).reshape(-1, C, H_sp, W_sp)
        else:
            H_sp, W_sp = H, W
            x = x.transpose(-1, -2).reshape(B, C, H_sp, W_sp)
        lepe = func(x)
        lepe = lepe.view(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).transpose(-1, -2)
        x = x.view(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).transpose(-1, -2)
        if self.split_size**2 < N:
            x = x.contiguous()
        return x, lepe

    def blockify(self, x):
        """Tile q or k of shape (B, N, C) into ``split_size x split_size`` windows.

        With E = C // num_heads and T = grid_height * grid_width windows, returns:
            x1: (B * T, num_heads, split_size**2, E) for attention inside windows;
            x2: (B, num_heads, T, split_size**2 * E) for attention across windows,
                each window flattened into one token;
            grid_height, grid_width: number of windows along each axis.
        """
        B, N, C  = x.shape
        E = C // self.num_heads
        H = int(math.sqrt(N))
        W = H
        grid_height = H // self.split_size
        grid_width  = W // self.split_size
        x  = x.view(B, grid_height, self.split_size, grid_width, self.split_size, self.num_heads, E).permute(0, 1, 3, 5, 2, 4, 6).reshape(B, grid_height * grid_width, self.num_heads, self.split_size * self.split_size, E)
        x1 = x.view(B * grid_height * grid_width, self.num_heads, -1, E)
        x2 = x.transpose(1, 2).reshape(B, self.num_heads, grid_height * grid_width, -1)
        return x1, x2, grid_height, grid_width

    def mhreshape(self, x):
        """Split (B, N, C) into heads, giving (B, num_heads, N, C // num_heads)."""
        B, N, C = x.shape
        x = x.view(-1, N, self.num_heads, C // self.num_heads).transpose(1, 2)
        return x

    def _dropout_p(self):
        """Attention dropout probability to pass to SDPA for the current mode."""
        # F.scaled_dot_product_attention applies dropout whenever dropout_p > 0,
        # independent of module.train()/eval(), so eval mode must pass 0.
        return self.attn_drop.p if self.training else 0.

    def forward(self, qkv):
        """Return attention output of shape (B, N, C) for q, k, v = qkv[0..2]."""
        B, N, C = qkv[0].shape
        q, k, v = qkv[0], qkv[1], qkv[2]
        drop_p = self._dropout_p()
        if self.split_size**2 < N:
            q1, q2, grid_height, grid_width = self.blockify(q)
            k1, k2, _, _                    = self.blockify(k)
            v1, lepe = self.get_lepe(v, self.get_v)
            # Level 1: attention inside every window.
            h1 = F.scaled_dot_product_attention(q1, k1, v1, dropout_p = drop_p) + lepe
            h1 = deblockify1(h1, B, self.split_size)
            # Level 2: attention across windows, using the level-1 output as values.
            h2 = F.scaled_dot_product_attention(q2, k2, h1, dropout_p = drop_p)
            x  = deblockify2(h2, B, self.split_size, grid_height, grid_width)
        else:
            q = self.mhreshape(q)
            k = self.mhreshape(k)
            v, lepe = self.get_lepe(v, self.get_v)
            h = F.scaled_dot_product_attention(q, k, v, dropout_p = drop_p) + lepe
            # h's memory layout is not guaranteed to be (B, N, heads, E), so merging
            # heads with `view` can fail; `reshape` copies only when needed.
            x = h.transpose(1, 2).reshape(B, N, C)
        return x

class AttnBlock_DuSA(nn.Module):
    """Pre-norm transformer block: DuSA attention and an MLP, each in a residual branch.

    Args:
        dim: channel count of the incoming tokens.
        reso: side length of the square token grid (tokens must number reso ** 2).
        num_heads: attention heads.
        split_size: DuSA window side length.
        mlp_ratio: MLP hidden width relative to ``dim``.
        qkv_bias: whether the q/k/v projection has a bias.
        drop: dropout inside the MLP.
        attn_drop: attention-weight dropout inside DuSA.
        drop_path: stochastic-depth rate (``0`` disables it).
        act_layer: MLP activation class.
        norm_layer: normalisation class for both pre-norms.
        stage_ind: stage index; for stages 1 and 2 the block asserts that the grid
            is larger than one ``split_size`` window.
    """

    def __init__(self, dim, reso, num_heads,
                 split_size=7, mlp_ratio=4., qkv_bias=False,
                 drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, stage_ind=0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.patches_resolution = reso
        self.split_size = split_size
        self.mlp_ratio = mlp_ratio
        self.norm1 = norm_layer(dim)
        # One projection producing q, k and v concatenated along the channel axis.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.branch_num = 1
        self.stage_ind = stage_ind
        # A ModuleList (of one) registers the attention parameters under `attns.0.*`.
        self.attns = nn.ModuleList(
            DuSA(dim, split_size=split_size, num_heads=num_heads, attn_drop=attn_drop)
            for _ in range(self.branch_num)
        )
        self.proj = nn.Linear(dim, dim)
        hidden_dim = int(dim * mlp_ratio)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_dim, out_features=dim,
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Apply the block to tokens of shape (B, reso ** 2, dim)."""
        side = self.patches_resolution
        batch, tokens, channels = x.shape
        assert tokens == side * side
        normed = self.norm1(x)
        # (B, N, 3C) -> (3, B, N, C): q, k, v stacked along the leading axis.
        qkv = self.qkv(normed).reshape(batch, -1, 3, channels).permute(2, 0, 1, 3)
        if 0 < self.stage_ind < 3:
            assert self.split_size ** 2 < tokens
        attn_out = self.proj(self.attns[0](qkv))
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))

@register_notrace_function
def deblockify1(x, B, split_size: int):
    """Regroup per-window attention output so that each window becomes one token.

    Decorated with timm's ``register_notrace_function`` so timm's FX feature
    extraction does not trace into it.

    Args:
        x: tensor of shape (B * T, heads, split_size ** 2, E), T windows per image.
        B: batch size.
        split_size: window side length.

    Returns:
        Tensor of shape (B, heads, T, split_size ** 2 * E).
    """
    rows, heads, _, head_dim = x.shape
    windows = rows // B
    per_image = x.view(B, windows, heads, split_size * split_size, head_dim)
    return per_image.transpose(1, 2).reshape(B, heads, windows, -1)

@register_notrace_function
def deblockify2(x, B, split_size: int, grid_height: int, grid_width: int):
    """Scatter window-level attention output back into a (B, N, C) token sequence.

    Decorated with timm's ``register_notrace_function`` so timm's FX feature
    extraction does not trace into it.

    Args:
        x: tensor of shape (B, heads, grid_height * grid_width, split_size ** 2 * E).
        B: batch size.
        split_size: window side length.
        grid_height, grid_width: number of windows along each axis.

    Returns:
        Tensor of shape (B, grid_height * split_size * grid_width * split_size,
        heads * E), tokens in row-major order over the full grid.
    """
    _, heads, _, _ = x.shape
    tiled = x.view(B, heads, grid_height, grid_width, split_size, split_size, -1)
    # -> (B, gh, ss, gw, ss, heads, E): interleave window and in-window indices per axis.
    spatial = tiled.permute(0, 2, 4, 3, 5, 1, 6)
    num_tokens = grid_height * split_size * grid_width * split_size
    return spatial.reshape(B, num_tokens, -1)

class Merge_Block(nn.Module):
    """Token-merging layer: 3x3 stride-2 conv over the token grid, then normalisation.

    Maps tokens (B, H*W, dim) on a square grid to (B, ceil(H/2) * ceil(W/2), dim_out).

    Args:
        dim: input channel count.
        dim_out: output channel count.
        norm_layer: normalisation class applied over ``dim_out`` channels.
    """

    def __init__(self, dim, dim_out, norm_layer=nn.LayerNorm):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=2, padding=1)
        self.norm = norm_layer(dim_out)

    def forward(self, x):
        batch, num_tokens, channels = x.shape
        side = int(np.sqrt(num_tokens))
        grid = x.transpose(-2, -1).reshape(batch, channels, side, side)
        merged = self.conv(grid)
        tokens = merged.flatten(2).transpose(-2, -1)
        return self.norm(tokens)

class DuSATransformer(nn.Module):
    """Four-stage hierarchical vision transformer built from ``AttnBlock_DuSA`` blocks.

    A strided conv embeds the image into a square token grid of side
    ``(img_size + 2 * padding - kernel_size_conv) // stride + 1``. Each later stage
    starts with a ``Merge_Block`` and works on a grid of half the side (``// 2``)
    with twice the channels. Final features are normalised, mean-pooled over
    tokens and fed to a linear head.

    Args:
        img_size: input image side length (the embedding is built for square
            inputs of exactly this size).
        in_chans: input image channels.
        num_classes: classifier outputs; ``<= 0`` makes the head ``nn.Identity``.
        embed_dim: channel width of stage 1 (doubled at every merge).
        depth: number of blocks in each of the four stages.
        split_size: DuSA window side length for each stage.
        num_heads: attention heads for each stage.
        mlp_ratio: MLP hidden width relative to the stage width.
        qkv_bias: whether the q/k/v projections have a bias.
        drop_rate: dropout inside the MLPs.
        attn_drop_rate: attention-weight dropout.
        drop_path_rate: largest stochastic-depth rate, ramped linearly over all blocks.
        norm_layer: normalisation class used in the blocks and for the final norm.
        use_chk: run every block through ``torch.utils.checkpoint``.
        stride, padding, kernel_size_conv: parameters of the embedding conv.
    """

    def __init__(self, img_size = 224, in_chans = 3, num_classes = 1000, embed_dim = 64, depth = (2, 4, 18, 1), split_size = (7, 7, 14, 7),
                 num_heads = (2, 4, 8, 16), mlp_ratio = 4., qkv_bias = True, drop_rate = 0., attn_drop_rate = 0.,
                 drop_path_rate = 0.2, norm_layer = nn.LayerNorm, use_chk = False, stride = 4, padding = 2, kernel_size_conv = 7):
        super().__init__()
        self.use_chk = use_chk
        self.num_classes = num_classes
        self.num_features = embed_dim
        # Output side length of the embedding conv.
        temp_h = (img_size + padding * 2 - (kernel_size_conv - 1) - 1) // stride + 1
        self.stage1_conv_embed = nn.Sequential(
            nn.Conv2d(in_chans, self.num_features, kernel_size_conv, stride, padding),
            Rearrange('b c h w -> b (h w) c', h = temp_h, w = temp_h),
            nn.LayerNorm(self.num_features)
        )
        # Stochastic-depth rate for every block, increasing linearly across the network.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]

        def make_stage(dim, reso, stage_heads, stage_split, num_blocks, dpr_offset, stage_ind):
            # One stage: `num_blocks` blocks sharing width, grid size and window size.
            return nn.ModuleList([
                AttnBlock_DuSA(
                    dim = dim, reso = reso, num_heads = stage_heads, split_size = stage_split,
                    mlp_ratio = mlp_ratio, qkv_bias = qkv_bias, drop = drop_rate, attn_drop = attn_drop_rate,
                    drop_path = dpr[dpr_offset + i], norm_layer = norm_layer, stage_ind = stage_ind)
                for i in range(num_blocks)
            ])

        # Modules are created in the same order as before so parameter order and
        # RNG consumption during initialisation are unchanged.
        curr_dim = self.num_features
        self.stage1 = make_stage(curr_dim, temp_h, num_heads[0], split_size[0], depth[0], 0, 1)
        self.merge1 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage2 = make_stage(curr_dim, temp_h // 2, num_heads[1], split_size[1], depth[1], sum(depth[:1]), 2)
        self.merge2 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage3 = make_stage(curr_dim, temp_h // 4, num_heads[2], split_size[2], depth[2], sum(depth[:2]), 3)
        self.merge3 = Merge_Block(curr_dim, curr_dim * 2)
        curr_dim = curr_dim * 2
        self.stage4 = make_stage(curr_dim, temp_h // 8, num_heads[3], split_size[3], depth[-1], sum(depth[:-1]), 4)
        self.norm = norm_layer(curr_dim)
        self.out_dim = curr_dim
        self.head = nn.Linear(self.out_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
        # nn.Identity (num_classes <= 0) has no weight to initialise.
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std = 0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Trunc-normal(std=0.02) Linear weights with zero bias; unit-weight, zero-bias norms."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std = 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            # Norm layers built without affine parameters have weight/bias set to None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        """Return the classification head."""
        return self.head

    def reset_classifier(self, num_classes, global_pool = ''):
        """Replace the classification head when ``num_classes`` changes.

        ``global_pool`` is accepted but unused. The new head is placed on the same
        device as the model's existing parameters; ``num_classes <= 0`` gives
        ``nn.Identity``.
        """
        if self.num_classes == num_classes:
            return
        print('reset head to', num_classes)
        # Capture the device before swapping heads (was hard-coded to .cuda()).
        device = next(self.parameters()).device
        self.num_classes = num_classes
        self.head = nn.Linear(self.out_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
        self.head = self.head.to(device)
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std = 0.02)
            nn.init.constant_(self.head.bias, 0)

    def _run_blocks(self, blocks, x):
        """Apply ``blocks`` in order, with activation checkpointing when ``use_chk`` is set."""
        for blk in blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_chk else blk(x)
        return x

    def forward_features(self, x):
        """Return normalised features mean-pooled over tokens, shape (B, out_dim)."""
        x = self._run_blocks(self.stage1, self.stage1_conv_embed(x))
        for merge, blocks in zip((self.merge1, self.merge2, self.merge3),
                                 (self.stage2, self.stage3, self.stage4)):
            x = self._run_blocks(blocks, merge(x))
        x = self.norm(x)
        return torch.mean(x, dim = 1)

    def forward(self, x):
        """Return classifier outputs (or pooled features when the head is Identity)."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

# CSWin-DuSA series
@register_model
def CSWin_DuSA_64_24181_tiny_224(pretrained=False, **kwargs):
    """CSWin-DuSA tiny (embed_dim 64, depths 2-4-18-1) for 224x224 inputs.

    Accepts the ``pretrained`` flag that ``timm.create_model`` passes to every
    registered entrypoint. Extra keyword arguments override the defaults below.

    Raises:
        NotImplementedError: if ``pretrained`` is true (the config has no weight URL).
    """
    if pretrained:
        raise NotImplementedError('pretrained weights are not available for CSWin_DuSA_64_24181_tiny_224')
    # Bookkeeping kwargs timm's create_model may forward; the model does not accept them.
    for key in ('pretrained_cfg', 'pretrained_cfg_overlay', 'cache_dir'):
        kwargs.pop(key, None)
    model_args = dict(embed_dim=64, depth=[2, 4, 18, 1], split_size=[14, 14, 14, 7],
                      num_heads=[2, 4, 8, 16], mlp_ratio=4., drop_path_rate=0.2)
    model_args.update(kwargs)
    model = DuSATransformer(**model_args)
    model.default_cfg = default_cfgs['cswin_224']
    return model

@register_model
def CSWin_DuSA_96_36292_base_224(pretrained=False, **kwargs):
    """CSWin-DuSA base (embed_dim 96, depths 3-6-29-2) for 224x224 inputs.

    Accepts the ``pretrained`` flag that ``timm.create_model`` passes to every
    registered entrypoint. Extra keyword arguments override the defaults below.

    Raises:
        NotImplementedError: if ``pretrained`` is true (the config has no weight URL).
    """
    if pretrained:
        raise NotImplementedError('pretrained weights are not available for CSWin_DuSA_96_36292_base_224')
    # Bookkeeping kwargs timm's create_model may forward; the model does not accept them.
    for key in ('pretrained_cfg', 'pretrained_cfg_overlay', 'cache_dir'):
        kwargs.pop(key, None)
    model_args = dict(embed_dim=96, depth=[3, 6, 29, 2], split_size=[14, 14, 14, 7],
                      num_heads=[4, 8, 16, 32], mlp_ratio=4., drop_path_rate=0.5)
    model_args.update(kwargs)
    model = DuSATransformer(**model_args)
    model.default_cfg = default_cfgs['cswin_224']
    return model

@register_model
def CSWin_DuSA_96_36292_base_384(pretrained=False, **kwargs):
    """CSWin-DuSA base (embed_dim 96, depths 3-6-29-2) for 384x384 inputs.

    Accepts the ``pretrained`` flag that ``timm.create_model`` passes to every
    registered entrypoint. Extra keyword arguments override the defaults below.

    Raises:
        NotImplementedError: if ``pretrained`` is true (the config has no weight URL).
    """
    if pretrained:
        raise NotImplementedError('pretrained weights are not available for CSWin_DuSA_96_36292_base_384')
    # Bookkeeping kwargs timm's create_model may forward; the model does not accept them.
    for key in ('pretrained_cfg', 'pretrained_cfg_overlay', 'cache_dir'):
        kwargs.pop(key, None)
    model_args = dict(img_size=384, embed_dim=96, depth=[3, 6, 29, 2], split_size=[24, 24, 24, 12],
                      num_heads=[4, 8, 16, 32], mlp_ratio=4., drop_path_rate=0.7)
    model_args.update(kwargs)
    model = DuSATransformer(**model_args)
    model.default_cfg = default_cfgs['cswin_384']
    return model

# Swin-DuSA series
@register_model
def Swin_DuSA_96_2262_tiny_224(pretrained=False, **kwargs):
    """Swin-DuSA tiny (embed_dim 96, depths 2-2-6-2) for 224x224 inputs.

    Accepts the ``pretrained`` flag that ``timm.create_model`` passes to every
    registered entrypoint. Extra keyword arguments override the defaults below.
    Attaches the 224 data config, like the CSWin-DuSA factories do.

    Raises:
        NotImplementedError: if ``pretrained`` is true (the config has no weight URL).
    """
    if pretrained:
        raise NotImplementedError('pretrained weights are not available for Swin_DuSA_96_2262_tiny_224')
    # Bookkeeping kwargs timm's create_model may forward; the model does not accept them.
    for key in ('pretrained_cfg', 'pretrained_cfg_overlay', 'cache_dir'):
        kwargs.pop(key, None)
    model_args = dict(embed_dim=96, depth=[2, 2, 6, 2], split_size=[14, 14, 14, 7],
                      num_heads=[3, 6, 12, 24], mlp_ratio=4., drop_path_rate=0.2)
    model_args.update(kwargs)
    model = DuSATransformer(**model_args)
    model.default_cfg = default_cfgs['cswin_224']
    return model

@register_model
def Swin_DuSA_128_22182_base_224(pretrained=False, **kwargs):
    """Swin-DuSA base (embed_dim 128, depths 2-2-18-2) for 224x224 inputs.

    Accepts the ``pretrained`` flag that ``timm.create_model`` passes to every
    registered entrypoint. Extra keyword arguments override the defaults below.
    Attaches the 224 data config, like the CSWin-DuSA factories do.

    Raises:
        NotImplementedError: if ``pretrained`` is true (the config has no weight URL).
    """
    if pretrained:
        raise NotImplementedError('pretrained weights are not available for Swin_DuSA_128_22182_base_224')
    # Bookkeeping kwargs timm's create_model may forward; the model does not accept them.
    for key in ('pretrained_cfg', 'pretrained_cfg_overlay', 'cache_dir'):
        kwargs.pop(key, None)
    model_args = dict(embed_dim=128, depth=[2, 2, 18, 2], split_size=[14, 14, 14, 7],
                      num_heads=[4, 8, 16, 32], mlp_ratio=4., drop_path_rate=0.5)
    model_args.update(kwargs)
    model = DuSATransformer(**model_args)
    model.default_cfg = default_cfgs['cswin_224']
    return model

@register_model
def Swin_DuSA_128_22182_base_384(pretrained=False, **kwargs):
    """Swin-DuSA base (embed_dim 128, depths 2-2-18-2) for 384x384 inputs.

    Accepts the ``pretrained`` flag that ``timm.create_model`` passes to every
    registered entrypoint. Extra keyword arguments override the defaults below.
    Attaches the 384 data config so data-config resolution does not fall back
    to 224 inputs, which this model's fixed-size embedding cannot accept.

    Raises:
        NotImplementedError: if ``pretrained`` is true (the config has no weight URL).
    """
    if pretrained:
        raise NotImplementedError('pretrained weights are not available for Swin_DuSA_128_22182_base_384')
    # Bookkeeping kwargs timm's create_model may forward; the model does not accept them.
    for key in ('pretrained_cfg', 'pretrained_cfg_overlay', 'cache_dir'):
        kwargs.pop(key, None)
    model_args = dict(img_size=384, embed_dim=128, depth=[2, 2, 18, 2], split_size=[24, 24, 24, 12],
                      num_heads=[4, 8, 16, 32], mlp_ratio=4., drop_path_rate=0.5)
    model_args.update(kwargs)
    model = DuSATransformer(**model_args)
    model.default_cfg = default_cfgs['cswin_384']
    return model