# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import load_checkpoint


class CpmBlock(nn.Module):
    """CpmBlock for Convolutional Pose Machine.

    A plain sequential stack of ``ConvModule`` layers. Layer ``i`` maps the
    channels produced by the previous layer (``in_channels`` for the first
    layer) to ``channels[i]`` using kernel size ``kernels[i]`` and padding
    ``(kernels[i] - 1) // 2``. Assuming ``ConvModule``'s default stride of
    1, this padding keeps the spatial size unchanged for odd kernels.

    Args:
        in_channels (int): Input channels of this block.
        channels (Sequence[int]): Output channels of each conv module.
        kernels (Sequence[int]): Kernel sizes of each conv module. Must have
            the same length as ``channels``.
        norm_cfg (dict, optional): Norm layer config forwarded unchanged to
            every ``ConvModule``. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 channels=(128, 128, 128),
                 kernels=(11, 11, 11),
                 norm_cfg=None):
        super().__init__()

        assert len(channels) == len(kernels)

        # Each conv consumes exactly what its predecessor produced.
        layer_inputs = [in_channels, *channels[:-1]]
        # The attribute name ``model`` is part of the checkpoint key layout.
        self.model = nn.Sequential(*[
            ConvModule(
                cin,
                cout,
                ksize,
                padding=(ksize - 1) // 2,
                norm_cfg=norm_cfg)
            for cin, cout, ksize in zip(layer_inputs, channels, kernels)
        ])

    def forward(self, x):
        """Model forward function: run ``x`` through the conv stack."""
        return self.model(x)


@BACKBONES.register_module()
class CPM(BaseBackbone):
    """CPM backbone.

    Convolutional Pose Machines.
    More details can be found in the `paper
    <https://arxiv.org/abs/1602.00134>`__ .

    Two parallel branches both consume the raw input. ``stem`` yields the
    stage-1 prediction directly. ``middle`` yields shared image features.
    Each later stage performs three steps:

    1. Concatenate the previous stage's prediction with a per-stage
       projection of those shared features along dim 1.
    2. Refine the result with a :class:`CpmBlock`.
    3. Map it back to ``out_channels`` maps.

    Both branches contain three stride-2 max-pools, so outputs are 1/8 of
    the input's spatial size (see the example below).

    Args:
        in_channels (int): The input channels of the CPM. Only 3 is
            accepted.
        out_channels (int): The output channels of the CPM.
        feat_channels (int): Feature channel of each CPM stage.
        middle_channels (int): Feature channel of conv after the middle stage.
        num_stages (int): Number of stages. Must be at least 1.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to ``dict(type='BN', requires_grad=True)``.

    Example:
        >>> from mmpose.models import CPM
        >>> import torch
        >>> self = CPM(3, 17)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 368, 368)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 feat_channels=128,
                 middle_channels=32,
                 num_stages=6,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Work on a private copy so the shared default dict is never touched.
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        assert in_channels == 3

        self.num_stages = num_stages
        assert self.num_stages >= 1

        num_refine_stages = num_stages - 1

        # NOTE: modules are created in the same order, and with the same
        # Sequential indices, as the reference layout. That keeps checkpoint
        # keys and construction-time weight initialisation unchanged.

        # Stage 1: shared trunk followed by a head predicting the outputs.
        self.stem = nn.Sequential(
            *self._make_trunk(in_channels, norm_cfg),
            ConvModule(128, 32, 5, padding=2, norm_cfg=norm_cfg),
            ConvModule(32, 512, 9, padding=4, norm_cfg=norm_cfg),
            ConvModule(512, 512, 1, padding=0, norm_cfg=norm_cfg),
            ConvModule(512, out_channels, 1, padding=0, act_cfg=None))

        # Image features reused by every refinement stage.
        self.middle = nn.Sequential(*self._make_trunk(in_channels, norm_cfg))

        # Refinement stages: input = previous prediction + projected features.
        self.cpm_stages = nn.ModuleList(
            CpmBlock(
                middle_channels + out_channels,
                channels=[feat_channels] * 3,
                kernels=[11] * 3,
                norm_cfg=norm_cfg) for _ in range(num_refine_stages))

        # Per-stage projection of the 128-channel ``middle`` features.
        self.middle_conv = nn.ModuleList(
            nn.Sequential(
                ConvModule(
                    128, middle_channels, 5, padding=2, norm_cfg=norm_cfg))
            for _ in range(num_refine_stages))

        # Per-stage head mapping CpmBlock features to ``out_channels`` maps.
        self.out_convs = nn.ModuleList(
            nn.Sequential(
                ConvModule(
                    feat_channels,
                    feat_channels,
                    1,
                    padding=0,
                    norm_cfg=norm_cfg),
                ConvModule(feat_channels, out_channels, 1, act_cfg=None))
            for _ in range(num_refine_stages))

    @staticmethod
    def _make_trunk(in_channels, norm_cfg):
        """Build the (9x9 conv -> stride-2 max-pool) x 3 trunk.

        The trunk is shared in structure (not weights) by ``stem`` and
        ``middle``.

        Returns:
            list[nn.Module]: Six modules, meant to be splatted into an
            ``nn.Sequential``.
        """
        layers = []
        for cin in (in_channels, 128, 128):
            layers += [
                ConvModule(cin, 128, 9, padding=4, norm_cfg=norm_cfg),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ]
        return layers

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights, loaded
                non-strictly. When None, every ``nn.Conv2d`` gets
                ``normal_init(std=0.001)``, and every BatchNorm/GroupNorm
                layer gets ``constant_init(m, 1)``. Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            for module in self.modules():
                if isinstance(module, nn.Conv2d):
                    normal_init(module, std=0.001)
                elif isinstance(module, (_BatchNorm, nn.GroupNorm)):
                    constant_init(module, 1)
            return

        if not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None')

        load_checkpoint(
            self, pretrained, strict=False, logger=get_root_logger())

    def forward(self, x):
        """Model forward function.

        Args:
            x (torch.Tensor): Input batch. Conv2d layers and the dim-1
                concatenation imply a channels-first layout.

        Returns:
            list[torch.Tensor]: ``num_stages`` predictions. The first comes
            from ``stem``; each following entry comes from one refinement
            stage.
        """
        prediction = self.stem(x)
        shared_feat = self.middle(x)
        outputs = [prediction]

        stages = zip(self.cpm_stages, self.middle_conv, self.out_convs)
        for cpm_block, feat_proj, head in stages:
            # Condition this stage on the previous stage's prediction.
            stage_in = torch.cat([prediction, feat_proj(shared_feat)], dim=1)
            prediction = head(cpm_block(stage_in))
            outputs.append(prediction)

        return outputs
