import copy

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.utils import _BatchNorm

from mmaction.models.common import LFB
from mmaction.utils import get_root_logger

# mmdet is an optional dependency. `mmdet_imported` records whether its
# SHARED_HEADS registry could be imported, so later module-level code can
# decide whether to use `MMDET_SHARED_HEADS`.
try:
    from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
    mmdet_imported = True
except (ImportError, ModuleNotFoundError):
    mmdet_imported = False


class NonLocalLayer(nn.Module):
    """Non-local layer used in `FBONonLocal` is a variation of the vanilla non-
    local block.

    Short-term features are projected to `theta`, long-term features are
    projected to `phi` and `g`. The softmax of `theta^T @ phi` (taken over the
    long-term axis) weights `g`, and the result is mapped back to
    `st_feat_channels` by `out_conv`.

    Args:
        st_feat_channels (int): Channels of short-term features.
        lt_feat_channels (int): Channels of long-term features.
        latent_channels (int): Channels of latent features.
        num_st_feat (int): Number of short-term roi features.
        num_lt_feat (int): Number of long-term roi features.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(latent_channels)`. Default: True.
        pre_activate (bool): Whether to apply the activation (ReLU,
            optionally preceded by LayerNorm) before `out_conv`. If False,
            LayerNorm is applied after `out_conv` instead and no ReLU is
            applied inside this layer. Default: True.
        pre_activate_with_ln (bool): Whether to apply LayerNorm before the
            ReLU when `pre_activate` is True. Default: True.
        conv_cfg (Dict | None): The config dict for convolution layers. If
            not specified, `dict(type='Conv3d')` is used.
            Default: None.
        norm_cfg (Dict | None): The config dict for normalization layers.
            Default: None.
        dropout_ratio (float, optional): Probability of dropout layer.
            Default: 0.2.
        zero_init_out_conv (bool): Whether to use zero initialization for
            out_conv. Default: False.
    """

    def __init__(self,
                 st_feat_channels,
                 lt_feat_channels,
                 latent_channels,
                 num_st_feat,
                 num_lt_feat,
                 use_scale=True,
                 pre_activate=True,
                 pre_activate_with_ln=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 dropout_ratio=0.2,
                 zero_init_out_conv=False):
        super().__init__()
        if conv_cfg is None:
            conv_cfg = dict(type='Conv3d')
        self.st_feat_channels = st_feat_channels
        self.lt_feat_channels = lt_feat_channels
        self.latent_channels = latent_channels
        self.num_st_feat = num_st_feat
        self.num_lt_feat = num_lt_feat
        self.use_scale = use_scale
        self.pre_activate = pre_activate
        self.pre_activate_with_ln = pre_activate_with_ln
        self.dropout_ratio = dropout_ratio
        self.zero_init_out_conv = zero_init_out_conv

        # 1x1 projection of the short-term features (`theta`).
        self.st_feat_conv = ConvModule(
            self.st_feat_channels,
            self.latent_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        # 1x1 projection of the long-term features (`phi`).
        self.lt_feat_conv = ConvModule(
            self.lt_feat_channels,
            self.latent_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        # 1x1 projection of the long-term features that gets aggregated (`g`).
        self.global_conv = ConvModule(
            self.lt_feat_channels,
            self.latent_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        # The LayerNorm sits before `out_conv` (latent channels) when
        # pre-activating, otherwise after it (short-term channels).
        if pre_activate:
            self.ln = nn.LayerNorm([latent_channels, num_st_feat, 1, 1])
        else:
            self.ln = nn.LayerNorm([st_feat_channels, num_st_feat, 1, 1])

        self.relu = nn.ReLU()

        self.out_conv = ConvModule(
            self.latent_channels,
            self.st_feat_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        if self.dropout_ratio > 0:
            self.dropout = nn.Dropout(self.dropout_ratio)

    def init_weights(self, pretrained=None):
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Args:
            pretrained (str | None): Checkpoint to load (non-strictly). If
                None, Conv3d layers get kaiming init, batch-norm layers are
                set to 1 and, if `zero_init_out_conv` is set, the output
                projection is zeroed. Default: None.

        Raises:
            TypeError: If `pretrained` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {pretrained}')
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)
            if self.zero_init_out_conv:
                # `ConvModule` is a wrapper: its parameters live on the
                # wrapped `conv` / norm sub-layers, so passing the wrapper
                # itself to `constant_init` would leave them untouched.
                # Zero the last parametrised layer so the block outputs 0.
                out_conv = self.out_conv
                zero_target = out_conv.conv
                if out_conv.with_norm and getattr(out_conv.norm, 'weight',
                                                  None) is not None:
                    zero_target = out_conv.norm
                constant_init(zero_target, 0, bias=0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, st_feat, lt_feat):
        """Attend from short-term features over long-term features.

        Args:
            st_feat (torch.Tensor): Short-term features. After
                `st_feat_conv` they must be viewable as
                `(n, latent_channels, num_st_feat)`.
            lt_feat (torch.Tensor): Long-term features. After the
                projections they must be viewable as
                `(n, latent_channels, num_lt_feat)`.

        Returns:
            torch.Tensor: Output of `out_conv` (plus optional LayerNorm and
            dropout) applied to the aggregated features of shape
            `(n, latent_channels, num_st_feat, 1, 1)`.
        """
        n, c = st_feat.size(0), self.latent_channels
        num_st_feat, num_lt_feat = self.num_st_feat, self.num_lt_feat

        theta = self.st_feat_conv(st_feat)
        theta = theta.view(n, c, num_st_feat)

        phi = self.lt_feat_conv(lt_feat)
        phi = phi.view(n, c, num_lt_feat)

        g = self.global_conv(lt_feat)
        g = g.view(n, c, num_lt_feat)

        # (n, num_st_feat, c), (n, c, num_lt_feat)
        # -> (n, num_st_feat, num_lt_feat)
        theta_phi = torch.matmul(theta.permute(0, 2, 1), phi)
        if self.use_scale:
            theta_phi /= c**0.5

        # Normalise over the long-term axis.
        p = theta_phi.softmax(dim=-1)

        # (n, c, num_lt_feat), (n, num_lt_feat, num_st_feat)
        # -> (n, c, num_st_feat, 1, 1)
        out = torch.matmul(g, p.permute(0, 2, 1)).view(n, c, num_st_feat, 1, 1)

        # If need to activate it before out_conv, use relu here, otherwise
        # use relu outside the non local layer.
        if self.pre_activate:
            if self.pre_activate_with_ln:
                out = self.ln(out)
            out = self.relu(out)

        out = self.out_conv(out)

        if not self.pre_activate:
            out = self.ln(out)
        if self.dropout_ratio > 0:
            out = self.dropout(out)

        return out


class FBONonLocal(nn.Module):
    """Non local feature bank operator.

    Projects short-term and long-term features to a common latent width and
    then refines the short-term features through a stack of residual
    `NonLocalLayer` blocks.

    Args:
        st_feat_channels (int): Channels of short-term features.
        lt_feat_channels (int): Channels of long-term features.
        latent_channels (int): Channels of latent features.
        num_st_feat (int): Number of short-term roi features.
        num_lt_feat (int): Number of long-term roi features.
        num_non_local_layers (int): Number of non-local layers, which is
            at least 1. Default: 2.
        st_feat_dropout_ratio (float): Probability of dropout layer for
            short-term features. Default: 0.2.
        lt_feat_dropout_ratio (float): Probability of dropout layer for
            long-term features. Default: 0.2.
        pre_activate (bool): Whether to use the activation function before
            upsampling in non local layers. Default: True.
        zero_init_out_conv (bool): Whether to use zero initialization for
            out_conv in NonLocalLayer. Default: False.
    """

    def __init__(self,
                 st_feat_channels,
                 lt_feat_channels,
                 latent_channels,
                 num_st_feat,
                 num_lt_feat,
                 num_non_local_layers=2,
                 st_feat_dropout_ratio=0.2,
                 lt_feat_dropout_ratio=0.2,
                 pre_activate=True,
                 zero_init_out_conv=False):
        super().__init__()
        assert num_non_local_layers >= 1, (
            'At least one non_local_layer is needed.')

        # Keep the constructor arguments around for later inspection.
        self.st_feat_channels = st_feat_channels
        self.lt_feat_channels = lt_feat_channels
        self.latent_channels = latent_channels
        self.num_st_feat = num_st_feat
        self.num_lt_feat = num_lt_feat
        self.num_non_local_layers = num_non_local_layers
        self.st_feat_dropout_ratio = st_feat_dropout_ratio
        self.lt_feat_dropout_ratio = lt_feat_dropout_ratio
        self.pre_activate = pre_activate
        self.zero_init_out_conv = zero_init_out_conv

        # 1x1 projections of both inputs into the latent space.
        self.st_feat_conv = nn.Conv3d(
            st_feat_channels, latent_channels, kernel_size=1)
        self.lt_feat_conv = nn.Conv3d(
            lt_feat_channels, latent_channels, kernel_size=1)

        # Dropout modules only exist when their ratio is positive.
        if st_feat_dropout_ratio > 0:
            self.st_feat_dropout = nn.Dropout(st_feat_dropout_ratio)
        if lt_feat_dropout_ratio > 0:
            self.lt_feat_dropout = nn.Dropout(lt_feat_dropout_ratio)

        # Without pre-activation the ReLU is applied after each residual sum.
        if not pre_activate:
            self.relu = nn.ReLU()

        # Sub-layers are registered under 1-based names; only the names are
        # kept in `non_local_layers`.
        self.non_local_layers = [
            f'non_local_layer_{number}'
            for number in range(1, num_non_local_layers + 1)
        ]
        for name in self.non_local_layers:
            block = NonLocalLayer(
                latent_channels,
                latent_channels,
                latent_channels,
                num_st_feat,
                num_lt_feat,
                pre_activate=pre_activate,
                zero_init_out_conv=zero_init_out_conv)
            self.add_module(name, block)

    def init_weights(self, pretrained=None):
        """Load a checkpoint or initialise the projections and sub-layers.

        Args:
            pretrained (str | None): Checkpoint to load (non-strictly). If
                None, both projection convs get kaiming init and every
                non-local layer initialises itself. Default: None.

        Raises:
            TypeError: If `pretrained` is neither a str nor None.
        """
        if pretrained is not None and not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None')

        if pretrained is not None:
            load_checkpoint(
                self, pretrained, strict=False, logger=get_root_logger())
            return

        for conv in (self.st_feat_conv, self.lt_feat_conv):
            kaiming_init(conv)
        for name in self.non_local_layers:
            getattr(self, name).init_weights(pretrained=pretrained)

    def forward(self, st_feat, lt_feat):
        """Fuse short-term with long-term features.

        Args:
            st_feat (torch.Tensor): Short-term features.
            lt_feat (torch.Tensor): Long-term features.

        Returns:
            torch.Tensor: Short-term features after the last residual
            non-local layer.
        """
        # Project (and optionally drop out) the short-term features.
        st_feat = self.st_feat_conv(st_feat)
        if self.st_feat_dropout_ratio > 0:
            st_feat = self.st_feat_dropout(st_feat)

        # Same for the long-term features.
        lt_feat = self.lt_feat_conv(lt_feat)
        if self.lt_feat_dropout_ratio > 0:
            lt_feat = self.lt_feat_dropout(lt_feat)

        apply_relu = not self.pre_activate
        for name in self.non_local_layers:
            layer = getattr(self, name)
            # Residual connection around each non-local layer.
            fused = st_feat + layer(st_feat, lt_feat)
            if apply_relu:
                fused = self.relu(fused)
            st_feat = fused

        return fused


class FBOAvg(nn.Module):
    """Avg pool feature bank operator.

    Average-pools the long-term features; the short-term features are
    accepted only to keep the same call signature as the other operators.
    """

    def __init__(self):
        super().__init__()
        # Reduce the first pooled dimension to size 1; `None` leaves the
        # remaining two dimensions at their input size.
        self.avg_pool = nn.AdaptiveAvgPool3d(output_size=(1, None, None))

    def init_weights(self, pretrained=None):
        """Do nothing: FBOAvg has no parameters to be initialized."""

    def forward(self, st_feat, lt_feat):
        """Return the average-pooled `lt_feat`; `st_feat` is unused."""
        return self.avg_pool(lt_feat)


class FBOMax(nn.Module):
    """Max pool feature bank operator.

    Max-pools the long-term features; the short-term features are accepted
    only to keep the same call signature as the other operators.
    """

    def __init__(self):
        super().__init__()
        # Reduce the first pooled dimension to size 1; `None` leaves the
        # remaining two dimensions at their input size.
        self.max_pool = nn.AdaptiveMaxPool3d(output_size=(1, None, None))

    def init_weights(self, pretrained=None):
        """Do nothing: FBOMax has no parameters to be initialized."""

    def forward(self, st_feat, lt_feat):
        """Return the max-pooled `lt_feat`; `st_feat` is unused."""
        return self.max_pool(lt_feat)


class FBOHead(nn.Module):
    """Feature Bank Operator Head.

    Add feature bank operator for the spatiotemporal detection model to fuse
    short-term features and long-term features.

    Args:
        lfb_cfg (Dict): The config dict for LFB which is used to sample
            long-term features.
        fbo_cfg (Dict): The config dict for feature bank operator (FBO). The
            type of fbo is also in the config dict and supported fbo type is
            `fbo_dict`. A missing `type` falls back to 'non_local'. The
            dict passed in is not modified.
        temporal_pool_type (str): The temporal pool type. Choices are 'avg' or
            'max'. Default: 'avg'.
        spatial_pool_type (str): The spatial pool type. Choices are 'avg' or
            'max'. Default: 'max'.
    """

    fbo_dict = {'non_local': FBONonLocal, 'avg': FBOAvg, 'max': FBOMax}

    def __init__(self,
                 lfb_cfg,
                 fbo_cfg,
                 temporal_pool_type='avg',
                 spatial_pool_type='max'):
        super().__init__()
        # Work on private copies: popping 'type' from the caller's dict
        # would make a second head built from the same config silently
        # fall back to the default operator.
        self.lfb_cfg = copy.deepcopy(lfb_cfg)
        self.fbo_cfg = copy.deepcopy(fbo_cfg)
        fbo_type = self.fbo_cfg.pop('type', 'non_local')

        assert fbo_type in FBOHead.fbo_dict
        assert temporal_pool_type in ['max', 'avg']
        assert spatial_pool_type in ['max', 'avg']

        self.lfb = LFB(**self.lfb_cfg)
        self.fbo = self.fbo_dict[fbo_type](**self.fbo_cfg)

        # Pool by default
        if temporal_pool_type == 'avg':
            self.temporal_pool = nn.AdaptiveAvgPool3d((1, None, None))
        else:
            self.temporal_pool = nn.AdaptiveMaxPool3d((1, None, None))
        if spatial_pool_type == 'avg':
            self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1))
        else:
            self.spatial_pool = nn.AdaptiveMaxPool3d((None, 1, 1))

    def init_weights(self, pretrained=None):
        """Initialize the weights in the module.

        Delegates to the feature bank operator's `init_weights`.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Default: None.
        """
        self.fbo.init_weights(pretrained=pretrained)

    def sample_lfb(self, rois, img_metas):
        """Sample long-term features for each ROI feature.

        Args:
            rois (torch.Tensor): ROIs whose first column holds the index of
                the sample (into `img_metas`) each ROI belongs to.
            img_metas (list[dict]): Per-sample meta info; `img_key` is used
                to look up the long-term features in the feature bank.

        Returns:
            torch.Tensor: Stacked long-term features with two trailing
            singleton dimensions appended.
        """
        # Column 0 of each roi is the index of its sample in the batch.
        inds = rois[:, 0].type(torch.int64).tolist()
        lt_feat = torch.stack(
            [self.lfb[img_metas[ind]['img_key']] for ind in inds], dim=0)
        # [N, lfb_channels, window_size * max_num_feat_per_step]
        lt_feat = lt_feat.permute(0, 2, 1).contiguous()
        return lt_feat.unsqueeze(-1).unsqueeze(-1)

    def forward(self, x, rois, img_metas):
        """Concatenate pooled short-term features with the FBO output.

        Args:
            x (torch.Tensor): Short-term ROI features to be pooled.
            rois (torch.Tensor): ROIs; see `sample_lfb`.
            img_metas (list[dict]): Per-sample meta info; see `sample_lfb`.

        Returns:
            torch.Tensor: Pooled short-term features and feature bank
            operator output concatenated along the channel dimension.
        """
        # [N, C, 1, 1, 1]
        st_feat = self.temporal_pool(x)
        st_feat = self.spatial_pool(st_feat)
        identity = st_feat

        # [N, C, window_size * num_feat_per_step, 1, 1]
        lt_feat = self.sample_lfb(rois, img_metas).to(st_feat.device)

        fbo_feat = self.fbo(st_feat, lt_feat)

        out = torch.cat([identity, fbo_feat], dim=1)
        return out


# Register FBOHead in mmdet's SHARED_HEADS registry, but only when mmdet
# could be imported (the registry name is undefined otherwise).
if mmdet_imported:
    MMDET_SHARED_HEADS.register_module()(FBOHead)
