import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from timm.models import register_model
import math
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

class StemConv(nn.Module):
    """Two-step convolutional stem that reduces spatial size by 4x.

    Each step is a stride-2 3x3 conv (no bias), BatchNorm, then activation.
    The first step maps 3 input channels to ``o_dim // 2``; the second maps
    those to ``o_dim``.

    Args:
        o_dim: number of output channels.
        act_layer: activation class; a separate instance is created per step.
    """

    def __init__(self, o_dim, act_layer=nn.GELU):
        super().__init__()
        mid_dim = o_dim // 2
        # Registration order is preserved for state_dict / init compatibility.
        self.conv_1 = nn.Conv2d(3, mid_dim, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv_2 = nn.Conv2d(mid_dim, o_dim, kernel_size=3, stride=2, padding=1, bias=False)
        self.act_1 = act_layer()
        self.act_2 = act_layer()
        self.norm_layer_1 = nn.BatchNorm2d(mid_dim)
        self.norm_layer_2 = nn.BatchNorm2d(o_dim)

    def forward(self, x):
        steps = (
            (self.conv_1, self.norm_layer_1, self.act_1),
            (self.conv_2, self.norm_layer_2, self.act_2),
        )
        for conv, norm, act in steps:
            x = act(norm(conv(x)))
        return x

class DownSampling(nn.Module):
    """Pre-normalized stride-2 3x3 convolution between stages.

    Args:
        in_dim: input channels (the BatchNorm is applied on these).
        o_dim: output channels.
    """

    def __init__(self, in_dim, o_dim):
        super().__init__()
        # conv is registered before norm_layer, as in the original layout.
        self.conv = nn.Conv2d(in_dim, o_dim, kernel_size=3, stride=2, padding=1)
        self.norm_layer = nn.BatchNorm2d(in_dim)

    def forward(self, x):
        normed = self.norm_layer(x)
        return self.conv(normed)

class InceptionLocalRep(nn.Module):
    """Channel-split depthwise token mixer with three parallel kernels.

    The channels are split into three groups:

    * ``dir_dim`` channels get a depthwise 3x1 (vertical) conv.
    * ``dir_dim`` channels get a depthwise 1x3 (horizontal) conv.
    * The remaining ``sqr_dim`` channels get a depthwise 7x7 conv.

    The outputs are concatenated back in the same order. Spatial size and
    channel count are unchanged.

    Args:
        in_dim: number of input/output channels.
        dir_emb_rate: fraction of ``in_dim`` given to each directional branch.
    """

    def __init__(self, in_dim, dir_emb_rate=1 / 4):
        super().__init__()
        self.dir_dim = int(in_dim * dir_emb_rate)
        self.sqr_dim = int(in_dim - self.dir_dim * 2)

        dd, sd = self.dir_dim, self.sqr_dim
        self.conv_h = nn.Conv2d(dd, dd, kernel_size=(3, 1), stride=1, padding=(1, 0), groups=dd, bias=False)
        self.conv_w = nn.Conv2d(dd, dd, kernel_size=(1, 3), stride=1, padding=(0, 1), groups=dd, bias=False)
        self.conv_s = nn.Conv2d(sd, sd, kernel_size=7, stride=1, padding=3, groups=sd, bias=False)

    def forward(self, x):
        chunks = torch.split(x, [self.dir_dim, self.dir_dim, self.sqr_dim], dim=1)
        branches = (self.conv_h, self.conv_w, self.conv_s)
        return torch.cat([branch(chunk) for branch, chunk in zip(branches, chunks)], dim=1)


class SpatialConditionDecon(nn.Module):
    """Multi-ratio average pooling that turns a feature map into a short token sequence.

    Each entry of ``ratios`` defines an ``nn.AvgPool2d`` kernel. The stride
    defaults to the kernel size. The pooled maps are flattened and
    concatenated along the token axis, giving a ``(B, N, D)`` tensor. ``N`` is
    the total number of pooled positions over all ratios.

    When ``pooling_proj`` is True, each pooled map is first projected by a 1x1
    conv to ``D = int(in_dim * pooling_proj_rate)`` channels, and the sequence
    is layer-normalized. Otherwise ``D = in_dim`` and no normalization is
    applied. The activation is applied last in both cases.

    Args:
        in_dim: channels of the input feature map.
        ratios: sequence of ``(kh, kw)`` pooling kernel sizes.
        act_layer: activation class (instantiated here).
        pooling_proj: whether to project and normalize the pooled tokens.
        pooling_proj_rate: output-channel ratio of the projection.
    """

    def __init__(self, in_dim, ratios, act_layer, pooling_proj=True, pooling_proj_rate=0.5):
        super(SpatialConditionDecon, self).__init__()
        self.iter = len(ratios)
        self.poolings = nn.ModuleList()
        self.pooling_proj = pooling_proj
        self.act = act_layer()
        # Token channel width, stored on the instance; no module-level global is written.
        self.embed_dim = int(in_dim * pooling_proj_rate) if pooling_proj else in_dim

        if pooling_proj:
            self.proj = nn.ModuleList()
            self.norm_layer = nn.LayerNorm(self.embed_dim)

        for ratio in ratios:
            self.poolings.append(nn.AvgPool2d(kernel_size=(ratio[0], ratio[1])))
            if pooling_proj:
                self.proj.append(nn.Conv2d(in_dim, self.embed_dim, kernel_size=1, stride=1))

    def forward(self, x):
        if self.pooling_proj:
            pools = [proj(pool(x)) for pool, proj in zip(self.poolings, self.proj)]
        else:
            pools = [pool(x) for pool in self.poolings]
        # (B, D, h_i, w_i) -> (B, D, h_i*w_i), concat over tokens, then -> (B, N, D).
        # flatten(2) uses the real channel count, so any pooling_proj_rate works.
        pools = torch.cat([p.flatten(2) for p in pools], dim=2).transpose(1, 2)
        if self.pooling_proj:
            pools = self.norm_layer(pools)
        return self.act(pools)


class SpatialCoordinationAtt(nn.Module):
    """Channel-wise attention whose queries and keys come from pooled tokens.

    Queries and keys are produced from the token sequence returned by
    ``SpatialConditionDecon`` (multi-ratio average pooling). Values are a
    1x1-conv + BatchNorm projection of the full feature map.

    Attention is computed between channels within each head, giving a
    ``head_dim x head_dim`` map. It uses L2-normalized q/k and a learnable
    per-head temperature. The attended output is gated by a depthwise-conv
    branch on the values, then passed through a 1x1 output projection.

    Args:
        in_dim: number of input/output channels; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        pool_ratios: pooling kernel sizes forwarded to ``SpatialConditionDecon``.
        act_layer: activation class.
        attn_drop: dropout probability on the attention map.
        proj_drop: dropout probability after the output projection.
        qkv_bias: whether the q/k linear layer has a bias.
        pooling_proj: forwarded to ``SpatialConditionDecon``.
        pooling_proj_rate: forwarded to ``SpatialConditionDecon``.

    Raises:
        ValueError: if ``in_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, in_dim, num_heads, pool_ratios, act_layer, attn_drop=0., proj_drop=0., qkv_bias=True,
                 pooling_proj=True,
                 pooling_proj_rate=0.5):
        super(SpatialCoordinationAtt, self).__init__()
        if in_dim % num_heads != 0:
            raise ValueError(f"in_dim ({in_dim}) must be divisible by num_heads ({num_heads})")

        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.num_heads = num_heads
        self.head_dim = in_dim // num_heads
        # Input width must match the token width emitted by SpatialConditionDecon.
        embed_dim = int(in_dim * pooling_proj_rate) if pooling_proj else in_dim
        # A single layer produces both q and k, so it outputs 2 * in_dim features
        # in both branches. forward() reshapes with a factor of 2.
        self.qk = nn.Linear(embed_dim, 2 * in_dim, bias=qkv_bias)

        self.v = nn.Sequential(nn.Conv2d(in_dim, in_dim, kernel_size=1, stride=1, padding=0, bias=False),
                               nn.BatchNorm2d(in_dim))

        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Conv2d(in_dim, in_dim, kernel_size=1, stride=1)

        self.proj_drop = nn.Dropout(proj_drop)

        self.dconv = nn.Sequential(
            nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1, groups=in_dim, bias=False),
            nn.BatchNorm2d(in_dim),
            act_layer(),
        )

        self.SCD = SpatialConditionDecon(in_dim, pool_ratios, act_layer, pooling_proj, pooling_proj_rate)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W

        # (B, N', 2*C) -> (2, B, heads, N', head_dim)
        qk = self.qk(self.SCD(x)).reshape(B, -1, 2, self.num_heads, self.head_dim)
        q, k = qk.permute(2, 0, 3, 1, 4).unbind(0)

        v_ = self.v(x)
        # (B, C, H, W) -> (B, heads, head_dim, N), with channel c = h * head_dim + d.
        v = v_.reshape(B, self.num_heads, self.head_dim, N)

        # Channels attend to channels: pooled tokens form the feature axis.
        q = torch.nn.functional.normalize(q.transpose(-2, -1), dim=-1)  # B, heads, head_dim, N'
        k = torch.nn.functional.normalize(k.transpose(-2, -1), dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature  # B, heads, head_dim, head_dim
        attn = self.attn_drop(attn.softmax(dim=-1))

        x = (attn @ v).reshape(B, C, H, W)
        x = x + x * self.dconv(v_)  # gated by the depthwise-conv branch
        x = self.proj(x)
        return self.proj_drop(x)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'temperature'}


class SFFN(nn.Module):
    """Convolutional feed-forward network built from 1x1 convs.

    The path is: expand channels by ``exp_rate`` with a 1x1 conv, activate,
    dropout, optional depthwise 3x3 conv + activation, squeeze back with a
    1x1 conv, dropout.

    Args:
        in_dim: input/output channels.
        act_layer: activation class.
        exp_rate: hidden expansion ratio.
        drop_out_rate: dropout probability; 0 disables dropout (Identity).
        mlp_bias: whether the 1x1 convs have a bias.
        conv_mlp: insert a depthwise 3x3 conv + activation in the hidden layer.
    """

    def __init__(self, in_dim, act_layer, exp_rate=4, drop_out_rate=0, mlp_bias=True, conv_mlp=False):
        super().__init__()
        hidden_dim = in_dim * exp_rate
        self.conv_exp = nn.Conv2d(in_dim, hidden_dim, kernel_size=1, stride=1, bias=mlp_bias)
        self.conv_squ = nn.Conv2d(hidden_dim, in_dim, kernel_size=1, stride=1, bias=mlp_bias)
        self.act = act_layer()
        self.drop = nn.Identity() if not drop_out_rate > 0 else nn.Dropout(drop_out_rate)
        # No BatchNorm in the mid conv: the original author noted it severely slows inference.
        self.mid_conv = (
            nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim), act_layer())
            if conv_mlp else nn.Identity()
        )

    def forward(self, x):
        hidden = self.drop(self.act(self.conv_exp(x)))
        hidden = self.mid_conv(hidden)
        return self.drop(self.conv_squ(hidden))


class Block(nn.Module):
    """Pre-norm residual block: token mixer, then SFFN, each with a residual.

    Args:
        in_dim: channels of the feature map.
        act_layer: activation class.
        exp_rate: SFFN expansion ratio.
        dir_emb_rate: directional channel fraction for ``InceptionLocalRep``.
        drop_path_rate: stochastic-depth rate; 0 uses Identity.
        type: falsy selects ``InceptionLocalRep``; truthy selects
            ``SpatialCoordinationAtt``. The name shadows a builtin but is kept
            for keyword-argument compatibility.
        num_heads: attention heads (attention mixer only).
        pool_ratios: pooling kernel sizes (attention mixer only). The default is
            an immutable tuple to avoid a shared mutable default.
        is_first: replace the first pre-norm BatchNorm with Identity.
        mlp_bias: whether SFFN 1x1 convs have a bias.
        use_ConvMLP: add a depthwise conv inside SFFN.
    """

    def __init__(self, in_dim, act_layer, exp_rate=4, dir_emb_rate=1 / 4, drop_path_rate=0., type=0, num_heads=4,
                 pool_ratios=((1, 3), (3, 1), (1, 1)), is_first=False, mlp_bias=True, use_ConvMLP=False):
        super(Block, self).__init__()

        if not type:
            self.DBlock = InceptionLocalRep(in_dim, dir_emb_rate)
        else:
            self.DBlock = SpatialCoordinationAtt(in_dim, num_heads, pool_ratios=pool_ratios, act_layer=act_layer)

        self.norm_layer_1 = nn.Identity() if is_first else nn.BatchNorm2d(in_dim)
        self.norm_layer_2 = nn.BatchNorm2d(in_dim)

        self.PBlock = SFFN(in_dim, act_layer, exp_rate, drop_out_rate=0, mlp_bias=mlp_bias, conv_mlp=use_ConvMLP)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        # Pre-norm residual form (the original author notes it converges faster).
        x = x + self.drop_path(self.DBlock(self.norm_layer_1(x)))
        x = x + self.drop_path(self.PBlock(self.norm_layer_2(x)))
        return x


class SCFormer(nn.Module):
    """SCFormer backbone with a global-pool classifier head.

    There are four stages, each preceded by a downsampling layer: ``StemConv``
    for the first stage and ``DownSampling`` for the others. Within a stage,
    blocks alternate between two mixers. Even-indexed blocks use
    ``InceptionLocalRep`` and odd-indexed blocks use
    ``SpatialCoordinationAtt``. The final feature map is globally
    average-pooled, layer-normalized, and classified.

    Args:
        num_classes: classifier output size. When ``<= 0``, the head (and the
            distillation head, if any) is ``nn.Identity`` and ``forward``
            returns the pooled, normalized features.
        drop_path_rate: maximum stochastic-depth rate. It is spread linearly
            from 0 over all blocks.
        depths: number of blocks per stage (4 entries).
        dims: channel width per stage (4 entries).
        exp_rate: SFFN expansion ratio per stage.
        dir_emb_rate: ``InceptionLocalRep`` directional channel fraction per stage.
        num_heads: attention heads per stage.
        cmlp_idx: stages with index ``< cmlp_idx`` use a depthwise conv inside SFFN.
        act_layer: activation class used throughout.
        init_weight: ``'trunc'`` (truncated normal, std 0.02) or ``'basic'``.
        mlp_bias: whether SFFN 1x1 convs have a bias.
        distillation: add a second head. In training, ``forward`` returns both
            outputs as a tuple; in eval, it returns their average.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes=1000,
                 drop_path_rate=0.,
                 depths=(2, 2, 6, 2),
                 dims=(32, 64, 160, 256),
                 exp_rate=(8, 8, 4, 4),
                 dir_emb_rate=(1 / 4, 1 / 8, 1 / 16, 1 / 32),
                 num_heads=(4, 4, 4, 4),
                 cmlp_idx=2,
                 act_layer=nn.GELU,
                 init_weight='trunc',
                 mlp_bias=True,
                 distillation=False,
                 **kwargs):

        super(SCFormer, self).__init__()
        assert init_weight in ['trunc', 'basic']
        assert drop_path_rate >= 0 and drop_path_rate <= 1
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(StemConv(o_dim=dims[0], act_layer=act_layer))
        for i in range(3):
            self.downsample_layers.append(DownSampling(in_dim=dims[i], o_dim=dims[i + 1]))
        # Per-block stochastic-depth rates, increasing linearly from 0.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        self.stages = nn.ModuleList()

        # Pooling kernel sizes used by SpatialCoordinationAtt in each stage.
        pool_sizes = [[[12, 6], [6, 12], [8, 8]],  # s1
                      [[8, 4], [4, 8], [4, 4]],  # s2
                      [[6, 3], [3, 6], [2, 2]],  # s3
                      [[3, 1], [1, 3], [1, 1]]]  # s4

        for i in range(4):
            stage = nn.Sequential(
                *[Block(in_dim=dims[i],
                        act_layer=act_layer,
                        exp_rate=exp_rate[i],
                        dir_emb_rate=dir_emb_rate[i],
                        drop_path_rate=dp_rates[cur + j],
                        # Odd-indexed blocks use attention; even-indexed use InceptionLocalRep.
                        type=((j + 1) % 2 == 0),
                        num_heads=num_heads[i],
                        pool_ratios=pool_sizes[i],
                        is_first=j == 0 and i == 0,
                        mlp_bias=mlp_bias,
                        use_ConvMLP=(i < cmlp_idx))
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.dist = distillation
        if self.dist:
            self.dist_head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        # Handled the same way as dist_head: with no classes, return pooled features.
        self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.norm_layer = nn.LayerNorm(dims[-1])  # final LayerNorm on pooled features

        if init_weight == 'trunc':
            self.apply(self._init_weights_trunc)
        else:
            self._init_weights_std()

    def _init_weights_std(self):
        """Initialize weights in place using fan-out normal init for convs.

        Conv weights get ``N(0, sqrt(2 / (kh * kw * out_channels)))``; BatchNorm
        gets weight 1 and bias 0; Linear gets ``N(0, 0.01)`` weights and zero
        bias. Biases are only touched when they exist.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _init_weights_trunc(self, m):
        """Per-module init for ``self.apply``: truncated normal (std 0.02) for conv/linear, unit norms."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = stage(downsample(x))
        x = x.mean([-2, -1])  # global average pooling over H, W
        x = self.norm_layer(x)

        if self.dist:
            x = self.head(x), self.dist_head(x)
            if not self.training:
                x = (x[0] + x[1]) / 2
        else:
            x = self.head(x)
        return x


def _cfg(url='', **kwargs):
    """Return a default model config dict; keyword arguments override or extend it.

    Defaults: 1000 classes, 3x224x224 input, crop_pct 0.95, bicubic
    interpolation, ImageNet mean/std, and classifier attribute ``'head'``.
    """
    config = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.95,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        classifier='head',
    )
    config.update(kwargs)
    return config



@register_model
def SCFormer_XXS(num_classes=1000, drop_path_rate=0, pretrained=False, **kwargs):
    """Build SCFormer-XXS (dims 24/48/120/192, depths 2/2/4/2).

    ``pretrained`` is accepted but not used; no weights are loaded. Extra
    keyword arguments are forwarded to ``SCFormer``.
    """
    arch = dict(
        depths=[2, 2, 4, 2],
        dims=[24, 48, 120, 192],
        exp_rate=[8, 8, 4, 4],
        dir_emb_rate=[1 / 4, 1 / 4, 1 / 4, 1 / 4],
        num_heads=[1, 2, 5, 8],
        act_layer=nn.GELU,
        init_weight='trunc',
        cmlp_idx=3,
        mlp_bias=True,
    )
    model = SCFormer(num_classes=num_classes, drop_path_rate=drop_path_rate, **arch, **kwargs)
    model.default_cfg = _cfg(crop_pct=224 / 256)
    return model



@register_model
def SCFormer_XS(num_classes=1000, drop_path_rate=0, pretrained=False, **kwargs):
    """Build SCFormer-XS (dims 32/64/160/256, depths 2/2/6/2).

    ``pretrained`` is accepted but not used; no weights are loaded. Extra
    keyword arguments are forwarded to ``SCFormer``.
    """
    arch = dict(
        depths=[2, 2, 6, 2],
        dims=[32, 64, 160, 256],
        exp_rate=[8, 8, 4, 4],
        dir_emb_rate=[1 / 4, 1 / 4, 1 / 4, 1 / 4],
        num_heads=[1, 2, 5, 8],
        act_layer=nn.GELU,
        init_weight='trunc',
        cmlp_idx=2,
        mlp_bias=True,
    )
    model = SCFormer(num_classes=num_classes, drop_path_rate=drop_path_rate, **arch, **kwargs)
    model.default_cfg = _cfg(crop_pct=224 / 256)
    return model



@register_model
def SCFormer_S(num_classes=1000, drop_path_rate=0, pretrained=False, **kwargs):
    """Build SCFormer-S (dims 40/80/200/320, depths 2/2/8/2).

    ``pretrained`` is accepted but not used; no weights are loaded. Extra
    keyword arguments are forwarded to ``SCFormer``.
    """
    arch = dict(
        depths=[2, 2, 8, 2],
        dims=[40, 80, 200, 320],
        exp_rate=[4, 4, 4, 4],
        dir_emb_rate=[1 / 4, 1 / 4, 1 / 4, 1 / 4],
        num_heads=[1, 2, 5, 8],
        act_layer=nn.GELU,
        init_weight='trunc',
        cmlp_idx=2,
        mlp_bias=True,
    )
    model = SCFormer(num_classes=num_classes, drop_path_rate=drop_path_rate, **arch, **kwargs)
    model.default_cfg = _cfg(crop_pct=224 / 256)
    return model



@register_model
def SCFormer_M(num_classes=1000, drop_path_rate=0, pretrained=False, **kwargs):
    """Build SCFormer-M (dims 48/96/200/384, depths 2/2/10/4).

    ``pretrained`` is accepted but not used; no weights are loaded. Extra
    keyword arguments are forwarded to ``SCFormer``.
    """
    arch = dict(
        depths=[2, 2, 10, 4],
        dims=[48, 96, 200, 384],
        exp_rate=[4, 4, 4, 4],
        dir_emb_rate=[1 / 4, 1 / 4, 1 / 4, 1 / 4],
        num_heads=[1, 2, 5, 8],
        act_layer=nn.GELU,
        init_weight='trunc',
        cmlp_idx=2,
        mlp_bias=True,
    )
    model = SCFormer(num_classes=num_classes, drop_path_rate=drop_path_rate, **arch, **kwargs)
    model.default_cfg = _cfg(crop_pct=224 / 256)
    return model


@register_model
def SCFormer_ML(num_classes=1000, drop_path_rate=0.1, pretrained=False, **kwargs):
    """Build SCFormer-ML (dims 64/128/300/512, depths 2/4/12/4).

    ``pretrained`` is accepted but not used; no weights are loaded. Extra
    keyword arguments are forwarded to ``SCFormer``.
    """
    arch = dict(
        depths=[2, 4, 12, 4],
        dims=[64, 128, 300, 512],
        exp_rate=[4, 4, 4, 3],
        dir_emb_rate=[1 / 4, 1 / 4, 1 / 4, 1 / 4],
        num_heads=[1, 2, 5, 8],
        act_layer=nn.GELU,
        init_weight='trunc',
        cmlp_idx=2,
        mlp_bias=True,
    )
    model = SCFormer(num_classes=num_classes, drop_path_rate=drop_path_rate, **arch, **kwargs)
    model.default_cfg = _cfg(crop_pct=224 / 256)
    return model


@register_model
def SCFormer_L(num_classes=1000, drop_path_rate=0.1, pretrained=False, **kwargs):
    """Build SCFormer-L (dims 72/144/320/512, depths 4/4/16/4).

    ``pretrained`` is accepted but not used; no weights are loaded. Extra
    keyword arguments are forwarded to ``SCFormer``.
    """
    arch = dict(
        depths=[4, 4, 16, 4],
        dims=[72, 144, 320, 512],
        exp_rate=[4, 4, 4, 4],
        dir_emb_rate=[1 / 4, 1 / 4, 1 / 4, 1 / 4],
        num_heads=[1, 2, 5, 8],
        act_layer=nn.GELU,
        init_weight='trunc',
        cmlp_idx=2,
        mlp_bias=True,
    )
    model = SCFormer(num_classes=num_classes, drop_path_rate=drop_path_rate, **arch, **kwargs)
    model.default_cfg = _cfg(crop_pct=224 / 256)
    return model