# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone


def conv_bn(in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups,
            dilation=1,
            norm_cfg=dict(type='BN')):
    """Construct a sequential conv and bn.

    Args:
        in_channels (int): Dimension of input features.
        out_channels (int): Dimension of output features.
        kernel_size (int): kernel_size of the convolution.
        stride (int): stride of the convolution.
        padding (int): stride of the convolution.
        groups (int): groups of the convolution.
        dilation (int): dilation of the convolution. Default to 1.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default to  ``dict(type='BN', requires_grad=True)``.

    Returns:
        nn.Sequential(): A conv layer and a batch norm layer.
    """
    if padding is None:
        padding = kernel_size // 2
    result = nn.Sequential()
    result.add_module(
        'conv',
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False))
    result.add_module('bn', build_norm_layer(norm_cfg, out_channels)[1])
    return result


def conv_bn_relu(in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups,
                 dilation=1):
    """Construct a sequential conv, bn and relu.

    Args:
        in_channels (int): Dimension of input features.
        out_channels (int): Dimension of output features.
        kernel_size (int): kernel_size of the convolution.
        stride (int): stride of the convolution.
        padding (int): stride of the convolution.
        groups (int): groups of the convolution.
        dilation (int): dilation of the convolution. Default to 1.

    Returns:
        nn.Sequential(): A conv layer, batch norm layer and a relu function.
    """

    if padding is None:
        padding = kernel_size // 2
    result = conv_bn(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        dilation=dilation)
    result.add_module('nonlinear', nn.ReLU())
    return result


def fuse_bn(conv, bn):
    """Fuse the parameters in a branch with a conv and bn.

    Args:
        conv (nn.Conv2d): The convolution module to fuse.
        bn (nn.BatchNorm2d): The batch normalization to fuse.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
        fusing the parameters of conv and bn in one branch.
        The first element is the weight and the second is the bias.
    """
    kernel = conv.weight
    running_mean = bn.running_mean
    running_var = bn.running_var
    gamma = bn.weight
    beta = bn.bias
    eps = bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)
    return kernel * t, beta - running_mean * gamma / std


class ReparamLargeKernelConv(BaseModule):
    """Super large kernel implemented by with large convolutions.

    Input: Tensor with shape [B, C, H, W].
    Output: Tensor with shape [B, C, H, W].

    Args:
        in_channels (int): Dimension of input features.
        out_channels (int): Dimension of output features.
        kernel_size (int): kernel_size of the large convolution.
        stride (int): stride of the large convolution.
        groups (int): groups of the large convolution.
        small_kernel (int): kernel_size of the small convolution.
        small_kernel_merged (bool): Whether to switch the model structure to
            deployment mode (merge the small kernel to the large kernel).
            Default to  False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups,
                 small_kernel,
                 small_kernel_merged=False,
                 init_cfg=None):
        super(ReparamLargeKernelConv, self).__init__(init_cfg)
        self.kernel_size = kernel_size
        self.small_kernel = small_kernel
        self.small_kernel_merged = small_kernel_merged
        # We assume the conv does not change the feature map size,
        # so padding = k//2.
        # Otherwise, you may configure padding as you wish,
        # and change the padding of small_conv accordingly.
        padding = kernel_size // 2
        if small_kernel_merged:
            self.lkb_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=1,
                groups=groups,
                bias=True)
        else:
            self.lkb_origin = conv_bn(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=1,
                groups=groups)
            if small_kernel is not None:
                assert small_kernel <= kernel_size
                self.small_conv = conv_bn(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=small_kernel,
                    stride=stride,
                    padding=small_kernel // 2,
                    groups=groups,
                    dilation=1)

    def forward(self, inputs):
        if hasattr(self, 'lkb_reparam'):
            out = self.lkb_reparam(inputs)
        else:
            out = self.lkb_origin(inputs)
            if hasattr(self, 'small_conv'):
                out += self.small_conv(inputs)
        return out

    def get_equivalent_kernel_bias(self):
        eq_k, eq_b = fuse_bn(self.lkb_origin.conv, self.lkb_origin.bn)
        if hasattr(self, 'small_conv'):
            small_k, small_b = fuse_bn(self.small_conv.conv,
                                       self.small_conv.bn)
            eq_b += small_b
            #   add to the central part
            eq_k += nn.functional.pad(
                small_k, [(self.kernel_size - self.small_kernel) // 2] * 4)
        return eq_k, eq_b

    def merge_kernel(self):
        """Switch the model structure from training mode to deployment mode."""
        if self.small_kernel_merged:
            return
        eq_k, eq_b = self.get_equivalent_kernel_bias()
        self.lkb_reparam = nn.Conv2d(
            in_channels=self.lkb_origin.conv.in_channels,
            out_channels=self.lkb_origin.conv.out_channels,
            kernel_size=self.lkb_origin.conv.kernel_size,
            stride=self.lkb_origin.conv.stride,
            padding=self.lkb_origin.conv.padding,
            dilation=self.lkb_origin.conv.dilation,
            groups=self.lkb_origin.conv.groups,
            bias=True)

        self.lkb_reparam.weight.data = eq_k
        self.lkb_reparam.bias.data = eq_b
        self.__delattr__('lkb_origin')
        if hasattr(self, 'small_conv'):
            self.__delattr__('small_conv')

        self.small_kernel_merged = True


class ConvFFN(BaseModule):
    """Mlp implemented by with 1*1 convolutions.

    Input: Tensor with shape [B, C, H, W].
    Output: Tensor with shape [B, C, H, W].

    Args:
        in_channels (int): Dimension of input features.
        internal_channels (int): Dimension of hidden features.
        out_channels (int): Dimension of output features.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default to  ``dict(type='BN', requires_grad=True)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 internal_channels,
                 out_channels,
                 drop_path,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(ConvFFN, self).__init__(init_cfg)
        self.drop_path = DropPath(
            drop_prob=drop_path) if drop_path > 0. else nn.Identity()
        self.preffn_bn = build_norm_layer(norm_cfg, in_channels)[1]
        self.pw1 = conv_bn(
            in_channels=in_channels,
            out_channels=internal_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1)
        self.pw2 = conv_bn(
            in_channels=internal_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1)
        self.nonlinear = build_activation_layer(act_cfg)

    def forward(self, x):
        out = self.preffn_bn(x)
        out = self.pw1(out)
        out = self.nonlinear(out)
        out = self.pw2(out)
        return x + self.drop_path(out)


class RepLKBlock(BaseModule):
    """RepLKBlock for RepLKNet backbone.

    Args:
        in_channels (int): The input channels of the block.
        dw_channels (int): The intermediate channels of the block,
            i.e., input channels of the large kernel convolution.
        block_lk_size (int): size of the super large kernel. Defaults: 31.
        small_kernel (int): size of the parallel small kernel. Defaults: 5.
        drop_path (float): Stochastic depth rate. Defaults: 0.
        small_kernel_merged (bool): Whether to switch the model structure to
            deployment mode (merge the small kernel to the large kernel).
            Default to  False.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default to  ``dict(type='BN', requires_grad=True)``.
        act_cfg (dict): Config dict for activation layer.
            Default to  ``dict(type='ReLU')``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default to  None
    """

    def __init__(self,
                 in_channels,
                 dw_channels,
                 block_lk_size,
                 small_kernel,
                 drop_path,
                 small_kernel_merged=False,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(RepLKBlock, self).__init__(init_cfg)
        self.pw1 = conv_bn_relu(in_channels, dw_channels, 1, 1, 0, groups=1)
        self.pw2 = conv_bn(dw_channels, in_channels, 1, 1, 0, groups=1)
        self.large_kernel = ReparamLargeKernelConv(
            in_channels=dw_channels,
            out_channels=dw_channels,
            kernel_size=block_lk_size,
            stride=1,
            groups=dw_channels,
            small_kernel=small_kernel,
            small_kernel_merged=small_kernel_merged)
        self.lk_nonlinear = build_activation_layer(act_cfg)
        self.prelkb_bn = build_norm_layer(norm_cfg, in_channels)[1]
        self.drop_path = DropPath(
            drop_prob=drop_path) if drop_path > 0. else nn.Identity()
        # print('drop path:', self.drop_path)

    def forward(self, x):
        out = self.prelkb_bn(x)
        out = self.pw1(out)
        out = self.large_kernel(out)
        out = self.lk_nonlinear(out)
        out = self.pw2(out)
        return x + self.drop_path(out)


class RepLKNetStage(BaseModule):
    """
    generate RepLKNet blocks for a stage
    return: RepLKNet blocks

    Args:
        channels (int): The input channels of the stage.
        num_blocks (int): The number of blocks of the stage.
        stage_lk_size (int): size of the super large kernel. Defaults: 31.
        drop_path (float): Stochastic depth rate. Defaults: 0.
        small_kernel (int): size of the parallel small kernel. Defaults: 5.
        dw_ratio (float): The intermediate channels
            expansion ratio of the block. Defaults: 1.
        ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default to  False.
        small_kernel_merged (bool): Whether to switch the model structure to
            deployment mode (merge the small kernel to the large kernel).
            Default to  False.
        norm_intermediate_features (bool): Construct and config norm layer
            or not.
            Using True will normalize the intermediate features for
            downstream dense prediction tasks.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default to  ``dict(type='BN', requires_grad=True)``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default to  None
    """

    def __init__(
            self,
            channels,
            num_blocks,
            stage_lk_size,
            drop_path,
            small_kernel,
            dw_ratio=1,
            ffn_ratio=4,
            with_cp=False,  # train with torch.utils.checkpoint to save memory
            small_kernel_merged=False,
            norm_intermediate_features=False,
            norm_cfg=dict(type='BN'),
            init_cfg=None):
        super(RepLKNetStage, self).__init__(init_cfg)
        self.with_cp = with_cp
        blks = []
        for i in range(num_blocks):
            block_drop_path = drop_path[i] if isinstance(drop_path,
                                                         list) else drop_path
            #   Assume all RepLK Blocks within a stage share the same lk_size.
            #   You may tune it on your own model.
            replk_block = RepLKBlock(
                in_channels=channels,
                dw_channels=int(channels * dw_ratio),
                block_lk_size=stage_lk_size,
                small_kernel=small_kernel,
                drop_path=block_drop_path,
                small_kernel_merged=small_kernel_merged)
            convffn_block = ConvFFN(
                in_channels=channels,
                internal_channels=int(channels * ffn_ratio),
                out_channels=channels,
                drop_path=block_drop_path)
            blks.append(replk_block)
            blks.append(convffn_block)
        self.blocks = nn.ModuleList(blks)
        if norm_intermediate_features:
            self.norm = build_norm_layer(norm_cfg, channels)[1]
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        for blk in self.blocks:
            if self.with_cp:
                x = checkpoint.checkpoint(blk, x)  # Save training memory
            else:
                x = blk(x)
        return x


@MODELS.register_module()
class RepLKNet(BaseBackbone):
    """RepLKNet backbone.

    A PyTorch impl of :
    `Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs
    <https://arxiv.org/abs/2203.06717>`_

    Args:
        arch (str | dict): The parameter of RepLKNet.
            If it's a dict, it should contain the following keys:

            - large_kernel_sizes (Sequence[int]):
                Large kernel size in each stage.
            - layers (Sequence[int]): Number of blocks in each stage.
            - channels (Sequence[int]): Number of channels in each stage.
            - small_kernel (int): size of the parallel small kernel.
            - dw_ratio (float): The intermediate channels
                expansion ratio of the block.
        in_channels (int): Number of input image channels. Default to  3.
        ffn_ratio (float): Mlp expansion ratio. Defaults to 4.
        out_indices (Sequence[int]): Output from which stages.
            Default to  (3, ).
        strides (Sequence[int]): Strides of the first block of each stage.
            Default to  (2, 2, 2, 2).
        dilations (Sequence[int]): Dilation of each stage.
            Default to  (1, 1, 1, 1).
        frozen_stages (int): Stages to be frozen
            (all param fixed). -1 means not freezing any parameters.
            Default to  -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Default to None.
        norm_cfg (dict): The config dict for norm layers.
            Default to  ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Default to  ``dict(type='ReLU')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default to False.
        deploy (bool): Whether to switch the model structure to deployment
            mode. Default to False.
        norm_intermediate_features (bool): Construct and
            config norm layer or not.
            Using True will normalize the intermediate features
            for downstream dense prediction tasks.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    arch_settings = {
        '31B':
        dict(
            large_kernel_sizes=[31, 29, 27, 13],
            layers=[2, 2, 18, 2],
            channels=[128, 256, 512, 1024],
            small_kernel=5,
            dw_ratio=1),
        '31L':
        dict(
            large_kernel_sizes=[31, 29, 27, 13],
            layers=[2, 2, 18, 2],
            channels=[192, 384, 768, 1536],
            small_kernel=5,
            dw_ratio=1),
        'XL':
        dict(
            large_kernel_sizes=[27, 27, 27, 13],
            layers=[2, 2, 18, 2],
            channels=[256, 512, 1024, 2048],
            small_kernel=None,
            dw_ratio=1.5),
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 ffn_ratio=4,
                 out_indices=(3, ),
                 strides=(2, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False,
                 drop_path_rate=0.3,
                 small_kernel_merged=False,
                 norm_intermediate_features=False,
                 norm_eval=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(RepLKNet, self).__init__(init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        assert len(arch['layers']) == len(
            arch['channels']) == len(strides) == len(dilations)
        assert max(out_indices) < len(arch['layers'])

        self.arch = arch
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.strides = strides
        self.dilations = dilations
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.with_cp = with_cp
        self.drop_path_rate = drop_path_rate
        self.small_kernel_merged = small_kernel_merged
        self.norm_eval = norm_eval
        self.norm_intermediate_features = norm_intermediate_features

        self.out_indices = out_indices

        base_width = self.arch['channels'][0]
        self.norm_intermediate_features = norm_intermediate_features
        self.num_stages = len(self.arch['layers'])
        self.stem = nn.ModuleList([
            conv_bn_relu(
                in_channels=in_channels,
                out_channels=base_width,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=1),
            conv_bn_relu(
                in_channels=base_width,
                out_channels=base_width,
                kernel_size=3,
                stride=1,
                padding=1,
                groups=base_width),
            conv_bn_relu(
                in_channels=base_width,
                out_channels=base_width,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1),
            conv_bn_relu(
                in_channels=base_width,
                out_channels=base_width,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=base_width)
        ])
        # stochastic depth. We set block-wise drop-path rate.
        # The higher level blocks are more likely to be dropped.
        # This implementation follows Swin.
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate,
                                             sum(self.arch['layers']))
        ]
        self.stages = nn.ModuleList()
        self.transitions = nn.ModuleList()
        for stage_idx in range(self.num_stages):
            layer = RepLKNetStage(
                channels=self.arch['channels'][stage_idx],
                num_blocks=self.arch['layers'][stage_idx],
                stage_lk_size=self.arch['large_kernel_sizes'][stage_idx],
                drop_path=dpr[sum(self.arch['layers'][:stage_idx]
                                  ):sum(self.arch['layers'][:stage_idx + 1])],
                small_kernel=self.arch['small_kernel'],
                dw_ratio=self.arch['dw_ratio'],
                ffn_ratio=ffn_ratio,
                with_cp=with_cp,
                small_kernel_merged=small_kernel_merged,
                norm_intermediate_features=(stage_idx in out_indices))
            self.stages.append(layer)
            if stage_idx < len(self.arch['layers']) - 1:
                transition = nn.Sequential(
                    conv_bn_relu(
                        self.arch['channels'][stage_idx],
                        self.arch['channels'][stage_idx + 1],
                        1,
                        1,
                        0,
                        groups=1),
                    conv_bn_relu(
                        self.arch['channels'][stage_idx + 1],
                        self.arch['channels'][stage_idx + 1],
                        3,
                        stride=2,
                        padding=1,
                        groups=self.arch['channels'][stage_idx + 1]))
                self.transitions.append(transition)

    def forward_features(self, x):
        x = self.stem[0](x)
        for stem_layer in self.stem[1:]:
            if self.with_cp:
                x = checkpoint.checkpoint(stem_layer, x)  # save memory
            else:
                x = stem_layer(x)

        #   Need the intermediate feature maps
        outs = []
        for stage_idx in range(self.num_stages):
            x = self.stages[stage_idx](x)
            if stage_idx in self.out_indices:
                outs.append(self.stages[stage_idx].norm(x))
                # For RepLKNet-XL normalize the features
                # before feeding them into the heads
            if stage_idx < self.num_stages - 1:
                x = self.transitions[stage_idx](x)
        return outs

    def forward(self, x):
        x = self.forward_features(x)
        return tuple(x)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(RepLKNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def switch_to_deploy(self):
        for m in self.modules():
            if hasattr(m, 'merge_kernel'):
                m.merge_kernel()
        self.small_kernel_merged = True
