import math
from itertools import chain
from typing import Sequence

import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.cnn.utils.weight_init import constant_init, trunc_normal_init

from ..builder import BACKBONES
from ..utils import ChannelMultiheadAttention, PositionEncodingFourier
from .convnext import ConvNeXtBlock
from .base_backbone import BaseBackbone


class SDTAEncoder(BaseModule):
    """Implementation of split depth-wise transpose attention (SDTA) encoder.

    Modified from https://github.com/mmaaz60/EdgeNeXt

    The input feature map is split along the channel dimension into groups of
    ``ceil(in_channel / scales)`` channels. The first ``num_convs`` groups are
    processed by depth-wise 3x3 convolutions, where every group after the
    first is summed with the previous group's convolved output before its own
    convolution. The remaining group (if any) is passed through unchanged.
    The result then goes through a channel self-attention branch and an
    inverted-bottleneck MLP branch, each optionally scaled by a learnable
    layer scale.

    Args:
        in_channel (int): Number of input channels.
        drop_path_rate (float): Stochastic depth dropout rate.
            Defaults to 0.
        layer_scale_init_value (float): Initial value of layer scale. If it
            is not positive, layer scale is disabled. Defaults to 1e-6.
        mlp_ratio (int): Number of channels ratio in the MLP.
            Defaults to 4.
        use_pos_emb (bool): Whether to use position encoding.
            Defaults to True.
        num_heads (int): Number of heads in the multihead attention.
            Defaults to 8.
        qkv_bias (bool): Whether to use bias in the multihead attention.
            Defaults to True.
        attn_drop (float): Dropout rate of the attention.
            Defaults to 0.
        proj_drop (float): Dropout rate of the projection.
            Defaults to 0.
        norm_cfg (dict): Dictionary to construct normalization layer. The
            norm layers are applied to channel-last tensors.
            Defaults to ``dict(type='LN')``.
        act_cfg (dict): Dictionary to construct activation layer.
            Defaults to ``dict(type='GELU')``.
        scales (int): Number of scales (channel groups). Defaults to 1.
        init_cfg (dict, optional): Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channel,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 mlp_ratio=4,
                 use_pos_emb=True,
                 num_heads=8,
                 qkv_bias=True,
                 attn_drop=0.,
                 proj_drop=0.,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 scales=1,
                 init_cfg=None):
        super(SDTAEncoder, self).__init__(init_cfg=init_cfg)
        # Width of every channel group; the last group may be narrower.
        conv_channels = int(math.ceil(in_channel / scales))
        self.conv_channels = conv_channels
        # With a single scale the only group is convolved; otherwise the last
        # group is left untouched and the other ``scales - 1`` are convolved.
        self.num_convs = scales if scales == 1 else scales - 1

        self.conv_modules = ModuleList()
        for _ in range(self.num_convs):
            self.conv_modules.append(
                nn.Conv2d(
                    conv_channels,
                    conv_channels,
                    kernel_size=3,
                    padding=1,
                    groups=conv_channels))

        self.pos_embed = PositionEncodingFourier(
            embed_dims=in_channel) if use_pos_emb else None

        self.norm_csa = build_norm_layer(norm_cfg, in_channel)[1]
        self.gamma_csa = nn.Parameter(
            layer_scale_init_value * torch.ones(in_channel),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.csa = ChannelMultiheadAttention(
            embed_dims=in_channel,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop)

        self.norm = build_norm_layer(norm_cfg, in_channel)[1]
        self.pointwise_conv1 = nn.Linear(in_channel, mlp_ratio * in_channel)
        self.act = build_activation_layer(act_cfg)
        self.pointwise_conv2 = nn.Linear(mlp_ratio * in_channel, in_channel)
        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones(in_channel),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        """Apply the SDTA encoder.

        Args:
            x (Tensor): Input of shape (B, C, H, W) with ``C == in_channel``.

        Returns:
            Tensor: Output with the same shape as ``x``.
        """
        shortcut = x

        # Split depth-wise convolutions over channel groups.
        spx = torch.split(x, self.conv_channels, dim=1)
        feats = []
        sp = None
        for i, conv in enumerate(self.conv_modules):
            # Every group after the first also receives the previous group's
            # convolved output.
            sp = spx[i] if i == 0 else sp + spx[i]
            sp = conv(sp)
            feats.append(sp)
        # Keep the trailing unconvolved group when there is one. With
        # ``scales == 1`` the single group has already been convolved, so
        # indexing ``spx[self.num_convs]`` would be out of range.
        if len(spx) > self.num_convs:
            feats.append(spx[self.num_convs])
        x = torch.cat(feats, dim=1)

        # Channel self-attention on (B, H*W, C) tokens.
        B, C, H, W = x.shape
        x = x.reshape(B, C, H * W).permute(0, 2, 1)
        if self.pos_embed is not None:
            pos_encoding = self.pos_embed((B, H, W))
            # Bring the encoding into the same (B, H*W, C) token layout.
            pos_encoding = pos_encoding.reshape(B, -1,
                                                x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding

        attn = self.csa(self.norm_csa(x))
        if self.gamma_csa is not None:
            attn = self.gamma_csa * attn
        x = x + self.drop_path(attn)
        x = x.reshape(B, H, W, C)

        # Inverted bottleneck on channel-last features.
        x = self.norm(x)
        x = self.pointwise_conv1(x)
        x = self.act(x)
        x = self.pointwise_conv2(x)

        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)

        return shortcut + self.drop_path(x)


@BACKBONES.register_module()
class EdgeNeXt(BaseBackbone):
    """EdgeNeXt.

    A PyTorch implementation of : `EdgeNeXt: Efficiently
    Amalgamated CNN-Transformer Architecture for Mobile Vision Applications
    <https://arxiv.org/abs/2206.10589>`_

    Inspiration from https://github.com/mmaaz60/EdgeNeXt

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architectures in ``EdgeNeXt.arch_settings``.
            And if dict, it should include exactly the following three keys:

            - channels (list[int]): The number of channels at each stage.
            - depths (list[int]): The number of blocks at each stage.
            - num_heads (list[int]): The number of heads at each stage.

            Defaults to 'xxsmall'.
        in_channels (int): The number of input channels.
            Defaults to 3.
        global_blocks (list[int]): The number of global (SDTA) blocks placed
            at the end of each stage. Defaults to [0, 1, 1, 1].
        global_block_type (list[str]): The type of global blocks at each
            stage, each one of ``'None'`` or ``'SDTA'``. The values are only
            validated; every global block is built as an SDTA encoder.
            Defaults to ['None', 'SDTA', 'SDTA', 'SDTA'].
        drop_path_rate (float): Stochastic depth dropout rate. The per-block
            rate increases linearly from 0 to this value over all blocks.
            Defaults to 0.
        layer_scale_init_value (float): Initial value of layer scale.
            Defaults to 1e-6.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution in the ConvNeXt blocks. Defaults to True.
        mlp_ratio (int): The number of channel ratio in the MLP layers of the
            SDTA blocks. Defaults to 4.
        conv_kernel_sizes (list[int]): The kernel size passed to the ConvNeXt
            blocks (via ``dw_conv_cfg``) at each stage.
            Defaults to [3, 5, 7, 9].
        use_pos_embd_csa (list[bool]): Whether to use positional embedding in
            Channel Self-Attention at each stage.
            Defaults to [False, True, False, False].
        use_pos_embd_global (bool): Whether to add a positional embedding to
            the output of the first stage. Defaults to False.
        d2_scales (list[int]): The number of channel groups used for SDTA at
            each stage. Defaults to [2, 2, 3, 4].
        norm_cfg (dict): The config of normalization layer.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        out_indices (Sequence | int): Output from which stages. Negative
            values count from the last stage.
            Defaults to -1, means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        gap_before_final_norm (bool): Whether to globally average the feature
            map before the final norm layer. Defaults to True.
        act_cfg (dict): The config of activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict, optional): Config for initialization.
            Defaults to None.
        **kwargs: Extra keyword arguments; they are silently ignored.
    """
    arch_settings = {
        'xxsmall': {  # parameters: 1.3M
            'channels': [24, 48, 88, 168],
            'depths': [2, 2, 6, 2],
            'num_heads': [4, 4, 4, 4]
        },
        'xsmall': {  # parameters: 2.3M
            'channels': [32, 64, 100, 192],
            'depths': [3, 3, 9, 3],
            'num_heads': [4, 4, 4, 4]
        },
        'small': {  # parameters: 5.6M
            'channels': [48, 96, 160, 304],
            'depths': [3, 3, 9, 3],
            'num_heads': [8, 8, 8, 8]
        },
        'base': {  # parameters: 18.51M
            'channels': [80, 160, 288, 584],
            'depths': [3, 3, 9, 3],
            'num_heads': [8, 8, 8, 8]
        },
    }

    def __init__(self,
                 arch='xxsmall',
                 in_channels=3,
                 global_blocks=[0, 1, 1, 1],
                 global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 linear_pw_conv=True,
                 mlp_ratio=4,
                 conv_kernel_sizes=[3, 5, 7, 9],
                 use_pos_embd_csa=[False, True, False, False],
                 use_pos_embd_global=False,
                 d2_scales=[2, 2, 3, 4],
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 act_cfg=dict(type='GELU'),
                 init_cfg=None,
                 **kwargs):
        super(EdgeNeXt, self).__init__(init_cfg=init_cfg)

        # Resolve the architecture. The chosen settings shadow the class-level
        # ``arch_settings`` table on this instance.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in self.arch_settings, \
                f'Arch {arch} is not in default archs ' \
                f'{set(self.arch_settings)}'
            self.arch_settings = self.arch_settings[arch]
        elif isinstance(arch, dict):
            essential_keys = {'channels', 'depths', 'num_heads'}
            assert set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch
        else:
            raise TypeError(
                f'"arch" must be a str or a dict, but got {type(arch)}')

        self.channels = self.arch_settings['channels']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.use_pos_embd_global = use_pos_embd_global

        for g in global_block_type:
            assert g in ['None',
                         'SDTA'], f'Global block type {g} is not supported'

        self.num_stages = len(self.depths)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        # Work on a copy: never mutate the caller's sequence, and accept
        # immutable sequences such as tuples.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
            assert 0 <= out_indices[i] < self.num_stages, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm

        if self.use_pos_embd_global:
            self.pos_embed = PositionEncodingFourier(
                embed_dims=self.channels[0])
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]

        self.downsample_layers = ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(in_channels, self.channels[0], kernel_size=4, stride=4),
            build_norm_layer(norm_cfg, self.channels[0])[1],
        )
        self.downsample_layers.append(stem)

        self.stages = ModuleList()
        block_idx = 0
        for i in range(self.num_stages):
            depth = self.depths[i]
            channels = self.channels[i]

            if i >= 1:
                downsample_layer = nn.Sequential(
                    build_norm_layer(norm_cfg, self.channels[i - 1])[1],
                    nn.Conv2d(
                        self.channels[i - 1],
                        channels,
                        kernel_size=2,
                        stride=2,
                    ))
                self.downsample_layers.append(downsample_layer)

            # Same for every ConvNeXt block of this stage.
            dw_conv_cfg = dict(
                kernel_size=conv_kernel_sizes[i],
                padding=conv_kernel_sizes[i] // 2,
            )
            stage_blocks = []
            for j in range(depth):
                # The last ``global_blocks[i]`` blocks are SDTA encoders.
                if j > depth - global_blocks[i] - 1:
                    stage_blocks.append(
                        SDTAEncoder(
                            in_channel=channels,
                            drop_path_rate=dpr[block_idx + j],
                            layer_scale_init_value=layer_scale_init_value,
                            mlp_ratio=mlp_ratio,
                            scales=d2_scales[i],
                            use_pos_emb=use_pos_embd_csa[i],
                            num_heads=self.num_heads[i],
                            act_cfg=act_cfg,
                        ))
                else:
                    stage_blocks.append(
                        ConvNeXtBlock(
                            in_channels=channels,
                            dw_conv_cfg=dw_conv_cfg,
                            norm_cfg=norm_cfg,
                            act_cfg=act_cfg,
                            linear_pw_conv=linear_pw_conv,
                            drop_path_rate=dpr[block_idx + j],
                            layer_scale_init_value=layer_scale_init_value,
                        ))
            block_idx += depth

            stage_blocks = Sequential(*stage_blocks)
            self.stages.append(stage_blocks)

            if i in self.out_indices:
                # After global average pooling the features are (B, C), so a
                # plain LN is used instead of ``norm_cfg``.
                out_norm_cfg = dict(type='LN') if self.gap_before_final_norm \
                    else norm_cfg
                norm_layer = build_norm_layer(out_norm_cfg, channels)[1]
                self.add_module(f'norm{i}', norm_layer)

    def init_weights(self, pretrained=None):
        """Initialize the weights.

        When neither ``pretrained`` nor ``init_cfg`` is given, ``nn.Linear``
        layers get a truncated-normal initialization (std=0.02) and
        LayerNorm / BatchNorm layers get weight 1 and bias 0.
        """
        # NOTE(review): this forwards ``pretrained`` to the parent and so
        # assumes BaseBackbone.init_weights accepts it. mmcv's
        # BaseModule.init_weights takes no arguments; confirm.
        super(EdgeNeXt, self).init_weights(pretrained)

        if pretrained is None and self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=0.02)
                elif isinstance(
                        m, (nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
                    constant_init(m, val=1, bias=0)

    def _freeze_stages(self):
        """Freeze the first ``frozen_stages`` stages and their downsamplers.

        The frozen modules are switched to eval mode and their parameters
        stop receiving gradients.
        """
        for i in range(self.frozen_stages):
            downsample_layer = self.downsample_layers[i]
            stage = self.stages[i]
            downsample_layer.eval()
            stage.eval()
            for param in chain(downsample_layer.parameters(),
                               stage.parameters()):
                param.requires_grad = False

    def forward(self, x):
        """Run the backbone.

        Args:
            x (Tensor): Input images, presumably (B, in_channels, H, W).

        Returns:
            list[Tensor]: One normalized output per stage in ``out_indices``.
            With ``gap_before_final_norm`` each output is globally averaged
            and flattened to (B, C); otherwise it is the normalized feature
            map.
        """
        outs = []
        for i, stage in enumerate(self.stages):
            x = self.downsample_layers[i](x)
            x = stage(x)
            if i == 0 and self.pos_embed is not None:
                B, _, H, W = x.shape
                # Out-of-place so the stage output is not modified in place.
                x = x + self.pos_embed((B, H, W))

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(norm_layer(gap.flatten(1)))
                else:
                    # The output of LayerNorm2d may be discontiguous, which
                    # may cause some problem in the downstream tasks
                    outs.append(norm_layer(x).contiguous())

        return outs

    def train(self, mode=True):
        """Set train/eval mode while keeping frozen stages in eval mode."""
        super(EdgeNeXt, self).train(mode)
        self._freeze_stages()
        # Match ``nn.Module.train``, which returns the module itself.
        return self
