# Copyright (c) OpenMMLab. All rights reserved.
import itertools

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, fuse_conv_bn
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule, ModuleList, Sequential

from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer


class HybridBackbone(BaseModule):

    def __init__(
            self,
            embed_dim,
            kernel_size=3,
            stride=2,
            pad=1,
            dilation=1,
            groups=1,
            act_cfg=dict(type='HSwish'),
            conv_cfg=None,
            norm_cfg=dict(type='BN'),
            init_cfg=None,
    ):
        super(HybridBackbone, self).__init__(init_cfg=init_cfg)

        self.input_channels = [
            3, embed_dim // 8, embed_dim // 4, embed_dim // 2
        ]
        self.output_channels = [
            embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim
        ]
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.patch_embed = Sequential()

        for i in range(len(self.input_channels)):
            conv_bn = ConvolutionBatchNorm(
                self.input_channels[i],
                self.output_channels[i],
                kernel_size=kernel_size,
                stride=stride,
                pad=pad,
                dilation=dilation,
                groups=groups,
                norm_cfg=norm_cfg,
            )
            self.patch_embed.add_module('%d' % (2 * i), conv_bn)
            if i < len(self.input_channels) - 1:
                self.patch_embed.add_module('%d' % (i * 2 + 1),
                                            build_activation_layer(act_cfg))

    def forward(self, x):
        x = self.patch_embed(x)
        return x


class ConvolutionBatchNorm(BaseModule):

    def __init__(
            self,
            in_channel,
            out_channel,
            kernel_size=3,
            stride=2,
            pad=1,
            dilation=1,
            groups=1,
            norm_cfg=dict(type='BN'),
    ):
        super(ConvolutionBatchNorm, self).__init__()
        self.conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            dilation=dilation,
            groups=groups,
            bias=False)
        self.bn = build_norm_layer(norm_cfg, out_channel)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

    @torch.no_grad()
    def fuse(self):
        return fuse_conv_bn(self).conv


class LinearBatchNorm(BaseModule):

    def __init__(self, in_feature, out_feature, norm_cfg=dict(type='BN1d')):
        super(LinearBatchNorm, self).__init__()
        self.linear = nn.Linear(in_feature, out_feature, bias=False)
        self.bn = build_norm_layer(norm_cfg, out_feature)

    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x.flatten(0, 1)).reshape_as(x)
        return x

    @torch.no_grad()
    def fuse(self):
        w = self.bn.weight / (self.bn.running_var + self.bn.eps)**0.5
        w = self.linear.weight * w[:, None]
        b = self.bn.bias - self.bn.running_mean * self.bn.weight / \
            (self.bn.running_var + self.bn.eps) ** 0.5

        factory_kwargs = {
            'device': self.linear.weight.device,
            'dtype': self.linear.weight.dtype
        }
        bias = nn.Parameter(
            torch.empty(self.linear.out_features, **factory_kwargs))
        self.linear.register_parameter('bias', bias)
        self.linear.weight.data.copy_(w)
        self.linear.bias.data.copy_(b)
        return self.linear


class Residual(BaseModule):

    def __init__(self, block, drop_path_rate=0.):
        super(Residual, self).__init__()
        self.block = block
        if drop_path_rate > 0:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        x = x + self.drop_path(self.block(x))
        return x


class Attention(BaseModule):

    def __init__(
            self,
            dim,
            key_dim,
            num_heads=8,
            attn_ratio=4,
            act_cfg=dict(type='HSwish'),
            resolution=14,
    ):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2
        self.qkv = LinearBatchNorm(dim, h)
        self.proj = nn.Sequential(
            build_activation_layer(act_cfg), LinearBatchNorm(self.dh, dim))

        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))

    @torch.no_grad()
    def train(self, mode=True):
        """change the mode of model."""
        super(Attention, self).train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):  # x (B,N,C)
        B, N, C = x.shape  # 2 196 128
        qkv = self.qkv(x)  # 2 196 128
        q, k, v = qkv.view(B, N, self.num_heads, -1).split(
            [self.key_dim, self.key_dim, self.d],
            dim=3)  # q 2 196 4 16 ; k 2 196 4 16; v 2 196 4 32
        q = q.permute(0, 2, 1, 3)  # 2 4 196 16
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn = ((q @ k.transpose(-2, -1)) *
                self.scale  # 2 4 196 16 * 2 4 16 196 -> 2 4 196 196
                + (self.attention_biases[:, self.attention_bias_idxs]
                   if self.training else self.ab))
        attn = attn.softmax(dim=-1)  # 2 4 196 196 -> 2 4 196 196
        x = (attn @ v).transpose(1, 2).reshape(
            B, N,
            self.dh)  # 2 4 196 196 * 2 4 196 32 -> 2 4 196 32 -> 2 196 128
        x = self.proj(x)
        return x


class MLP(nn.Sequential):

    def __init__(self, embed_dim, mlp_ratio, act_cfg=dict(type='HSwish')):
        super(MLP, self).__init__()
        h = embed_dim * mlp_ratio
        self.linear1 = LinearBatchNorm(embed_dim, h)
        self.activation = build_activation_layer(act_cfg)
        self.linear2 = LinearBatchNorm(h, embed_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        return x


class Subsample(BaseModule):

    def __init__(self, stride, resolution):
        super(Subsample, self).__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, x):
        B, _, C = x.shape
        # B, N, C -> B, H, W, C
        x = x.view(B, self.resolution, self.resolution, C)
        x = x[:, ::self.stride, ::self.stride]
        x = x.reshape(B, -1, C)  # B, H', W', C -> B, N', C
        return x


class AttentionSubsample(nn.Sequential):

    def __init__(self,
                 in_dim,
                 out_dim,
                 key_dim,
                 num_heads=8,
                 attn_ratio=2,
                 act_cfg=dict(type='HSwish'),
                 stride=2,
                 resolution=14):
        super(AttentionSubsample, self).__init__()
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * self.num_heads
        self.attn_ratio = attn_ratio
        self.sub_resolution = (resolution - 1) // stride + 1
        h = self.dh + nh_kd
        self.kv = LinearBatchNorm(in_dim, h)

        self.q = nn.Sequential(
            Subsample(stride, resolution), LinearBatchNorm(in_dim, nh_kd))
        self.proj = nn.Sequential(
            build_activation_layer(act_cfg), LinearBatchNorm(self.dh, out_dim))

        self.stride = stride
        self.resolution = resolution
        points = list(itertools.product(range(resolution), range(resolution)))
        sub_points = list(
            itertools.product(
                range(self.sub_resolution), range(self.sub_resolution)))
        N = len(points)
        N_sub = len(sub_points)
        attention_offsets = {}
        idxs = []
        for p1 in sub_points:
            for p2 in points:
                size = 1
                offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
                          abs(p1[1] * stride - p2[1] + (size - 1) / 2))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N_sub, N))

    @torch.no_grad()
    def train(self, mode=True):
        super(AttentionSubsample, self).train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):
        B, N, C = x.shape
        k, v = self.kv(x).view(B, N, self.num_heads,
                               -1).split([self.key_dim, self.d], dim=3)
        k = k.permute(0, 2, 1, 3)  # BHNC
        v = v.permute(0, 2, 1, 3)  # BHNC
        q = self.q(x).view(B, self.sub_resolution**2, self.num_heads,
                           self.key_dim).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale + \
               (self.attention_biases[:, self.attention_bias_idxs]
                if self.training else self.ab)
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
        x = self.proj(x)
        return x


@MODELS.register_module()
class LeViT(BaseBackbone):
    """LeViT backbone.

    A PyTorch implementation of `LeViT: A Vision Transformer in ConvNet's
    Clothing for Faster Inference <https://arxiv.org/abs/2104.01136>`_

    Modified from the official implementation:
    https://github.com/facebookresearch/LeViT

    Args:
        arch (str | dict): LeViT architecture.

            If use string, choose from '128s', '128', '192', '256' and '384'.
            If use dict, it should have below keys:

            - **embed_dims** (List[int]): The embed dimensions of each stage.
            - **key_dims** (List[int]): The embed dimensions of the key in the
              attention layers of each stage.
            - **num_heads** (List[int]): The number of heads in each stage.
            - **depths** (List[int]): The number of blocks in each stage.

        img_size (int): Input image size
        patch_size (int | tuple): The patch size. Deault to 16
        attn_ratio (int): Ratio of hidden dimensions of the value in attention
            layers. Defaults to 2.
        mlp_ratio (int): Ratio of hidden dimensions in MLP layers.
            Defaults to 2.
        act_cfg (dict): The config of activation functions.
            Defaults to ``dict(type='HSwish')``.
        hybrid_backbone (callable): A callable object to build the patch embed
            module. Defaults to use :class:`HybridBackbone`.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        deploy (bool): Whether to switch the model structure to
            deployment mode. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        '128s': {
            'embed_dims': [128, 256, 384],
            'num_heads': [4, 6, 8],
            'depths': [2, 3, 4],
            'key_dims': [16, 16, 16],
        },
        '128': {
            'embed_dims': [128, 256, 384],
            'num_heads': [4, 8, 12],
            'depths': [4, 4, 4],
            'key_dims': [16, 16, 16],
        },
        '192': {
            'embed_dims': [192, 288, 384],
            'num_heads': [3, 5, 6],
            'depths': [4, 4, 4],
            'key_dims': [32, 32, 32],
        },
        '256': {
            'embed_dims': [256, 384, 512],
            'num_heads': [4, 6, 8],
            'depths': [4, 4, 4],
            'key_dims': [32, 32, 32],
        },
        '384': {
            'embed_dims': [384, 512, 768],
            'num_heads': [6, 9, 12],
            'depths': [4, 4, 4],
            'key_dims': [32, 32, 32],
        },
    }

    def __init__(self,
                 arch,
                 img_size=224,
                 patch_size=16,
                 attn_ratio=2,
                 mlp_ratio=2,
                 act_cfg=dict(type='HSwish'),
                 hybrid_backbone=HybridBackbone,
                 out_indices=-1,
                 deploy=False,
                 drop_path_rate=0,
                 init_cfg=None):
        super(LeViT, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch = self.arch_zoo[arch]
        elif isinstance(arch, dict):
            essential_keys = {'embed_dim', 'num_heads', 'depth', 'key_dim'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch = arch
        else:
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        self.embed_dims = self.arch['embed_dims']
        self.num_heads = self.arch['num_heads']
        self.key_dims = self.arch['key_dims']
        self.depths = self.arch['depths']
        self.num_stages = len(self.embed_dims)
        self.drop_path_rate = drop_path_rate

        self.patch_embed = hybrid_backbone(self.embed_dims[0])

        self.resolutions = []
        resolution = img_size // patch_size
        self.stages = ModuleList()
        for i, (embed_dims, key_dims, depth, num_heads) in enumerate(
                zip(self.embed_dims, self.key_dims, self.depths,
                    self.num_heads)):
            blocks = []
            if i > 0:
                downsample = AttentionSubsample(
                    in_dim=self.embed_dims[i - 1],
                    out_dim=embed_dims,
                    key_dim=key_dims,
                    num_heads=self.embed_dims[i - 1] // key_dims,
                    attn_ratio=4,
                    act_cfg=act_cfg,
                    stride=2,
                    resolution=resolution)
                blocks.append(downsample)
                resolution = downsample.sub_resolution
                if mlp_ratio > 0:  # mlp_ratio
                    blocks.append(
                        Residual(
                            MLP(embed_dims, mlp_ratio, act_cfg=act_cfg),
                            self.drop_path_rate))
            self.resolutions.append(resolution)
            for _ in range(depth):
                blocks.append(
                    Residual(
                        Attention(
                            embed_dims,
                            key_dims,
                            num_heads,
                            attn_ratio=attn_ratio,
                            act_cfg=act_cfg,
                            resolution=resolution,
                        ), self.drop_path_rate))
                if mlp_ratio > 0:
                    blocks.append(
                        Residual(
                            MLP(embed_dims, mlp_ratio, act_cfg=act_cfg),
                            self.drop_path_rate))

            self.stages.append(Sequential(*blocks))

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        elif isinstance(out_indices, tuple):
            out_indices = list(out_indices)
        elif not isinstance(out_indices, list):
            raise TypeError('"out_indices" must by a list, tuple or int, '
                            f'get {type(out_indices)} instead.')
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
            assert 0 <= out_indices[i] < self.num_stages, \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

        self.deploy = False
        if deploy:
            self.switch_to_deploy()

    def switch_to_deploy(self):
        if self.deploy:
            return
        fuse_parameters(self)
        self.deploy = True

    def forward(self, x):
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)  # B, C, H, W -> B, L, C
        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            B, _, C = x.shape
            if i in self.out_indices:
                out = x.reshape(B, self.resolutions[i], self.resolutions[i], C)
                out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)


def fuse_parameters(module):
    for child_name, child in module.named_children():
        if hasattr(child, 'fuse'):
            setattr(module, child_name, child.fuse())
        else:
            fuse_parameters(child)
