# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer, constant_init, kaiming_init
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmpose.core import WeightNormClipHook
from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class BasicTemporalBlock(nn.Module):
    """Basic block for VideoPose3D.

    The block runs a temporal convolution (``conv1``) followed by a
    kernel-size-1 convolution (``conv2``), applies dropout after each of
    them when enabled, and optionally adds a residual connection. Because
    ``conv1`` is built without padding arguments, the output is shorter
    than the input along the temporal axis; the residual is therefore taken
    from a slice of the input that lines up with the output frames.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        mid_channels (int): The output channels of conv1. Default: 1024.
        kernel_size (int): Size of the convolving kernel. Default: 3.
        dilation (int): Spacing between kernel elements. Default: 3.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications). Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use optimized TCN that designed
            specifically for single-frame batching, i.e. where batches have
            input length = receptive field, and output length = 1. This
            implementation replaces dilated convolutions with strided
            convolutions to avoid generating unused intermediate results.
            Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=1024,
                 kernel_size=3,
                 dilation=3,
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d')):
        # Protect mutable default arguments
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv

        # Number of input frames trimmed from each side when selecting the
        # residual slice in the dilated (non-strided) path.
        self.pad = (kernel_size - 1) * dilation // 2
        if use_stride_conv:
            # Strided variant: the stride takes over the role of the
            # dilation, so conv1 is built with dilation 1.
            self.stride = kernel_size
            self.causal_shift = kernel_size // 2 if causal else 0
            self.dilation = 1
        else:
            self.stride = 1
            self.causal_shift = kernel_size // 2 * dilation if causal else 0

        # Temporal convolution.
        self.conv1 = nn.Sequential(
            ConvModule(
                in_channels,
                mid_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=self.dilation,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))
        # Kernel-size-1 convolution; it does not change the temporal length.
        self.conv2 = nn.Sequential(
            ConvModule(
                mid_channels,
                out_channels,
                kernel_size=1,
                bias='auto',
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg))

        # A kernel-size-1 projection is only needed when the residual branch
        # has to change the number of channels.
        if residual and in_channels != out_channels:
            self.short_cut = build_conv_layer(conv_cfg, in_channels,
                                              out_channels, 1)
        else:
            self.short_cut = None

        # ``self.dropout`` holds the dropout module, or None when disabled.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input features. Dimension 2 is treated as the
                temporal axis (``(N, C, T)`` for the default ``Conv1d``
                configuration).

        Returns:
            torch.Tensor: Output of ``conv2``, plus the aligned slice of
            ``x`` when ``residual`` is enabled.

        Raises:
            AssertionError: If the input is too short for the residual
                slice of this block.
        """
        seq_len = x.shape[2]
        if self.use_stride_conv:
            # First input frame that is the (shifted) centre of a window.
            res_start = self.causal_shift + self.kernel_size // 2
            res_end = None
            assert res_start < seq_len, (
                f'Input length {seq_len} is too short for kernel size '
                f'{self.kernel_size} with causal shift {self.causal_shift}.')
        else:
            res_start = self.pad + self.causal_shift
            res_end = seq_len - self.pad + self.causal_shift
            assert 0 <= res_start < res_end <= seq_len, (
                f'Input length {seq_len} is incompatible with pad '
                f'{self.pad} and causal shift {self.causal_shift}.')

        out = self.conv1(x)
        if self.dropout is not None:
            out = self.dropout(out)

        out = self.conv2(out)
        if self.dropout is not None:
            out = self.dropout(out)

        if self.residual:
            if self.use_stride_conv:
                res = x[:, :, res_start::self.kernel_size]
                # When the input length is not a multiple of the kernel
                # size, the strided slice can contain a trailing frame that
                # has no complete convolution window. Drop it so the sum
                # below neither fails nor silently broadcasts.
                res = res[:, :, :out.shape[2]]
            else:
                res = x[:, :, res_start:res_end]

            if self.short_cut is not None:
                res = self.short_cut(res)
            out = out + res

        return out


@BACKBONES.register_module()
class TCN(BaseBackbone):
    """TCN backbone.

    Temporal Convolutional Networks.
    More details can be found in the
    `paper <https://arxiv.org/abs/1811.11742>`__ .

    The network consists of one expanding convolution followed by
    ``num_blocks`` :class:`BasicTemporalBlock` modules. The dilation handed
    to each block is the product of the kernel sizes of all preceding
    layers. ``forward`` returns the output of every block.

    Args:
        in_channels (int): Number of input channels, which equals to
            num_keypoints * num_features.
        stem_channels (int): Number of feature channels. Default: 1024.
        num_blocks (int): Number of basic temporal convolutional blocks.
            Default: 2.
        kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
            each basic block. The first entry is used by the expanding
            convolution, so ``num_blocks + 1`` entries are required.
            Default: ``(3, 3, 3)``.
        dropout (float): Dropout rate. Default: 0.25.
        causal (bool): Use causal convolutions instead of symmetric
            convolutions (for real-time applications).
            Default: False.
        residual (bool): Use residual connection. Default: True.
        use_stride_conv (bool): Use TCN backbone optimized for
            single-frame batching, i.e. where batches have input length =
            receptive field, and output length = 1. This implementation
            replaces dilated convolutions with strided convolutions to avoid
            generating unused intermediate results. The weights are
            interchangeable with the reference implementation. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN1d').
        max_norm (float|None): if not None, the weight of convolution layers
            will be clipped to have a maximum norm of max_norm.

    Example:
        >>> from mmpose.models import TCN
        >>> import torch
        >>> self = TCN(in_channels=34)
        >>> self.eval()
        >>> inputs = torch.rand(1, 34, 243)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 1024, 235)
        (1, 1024, 217)
    """

    def __init__(self,
                 in_channels,
                 stem_channels=1024,
                 num_blocks=2,
                 kernel_sizes=(3, 3, 3),
                 dropout=0.25,
                 causal=False,
                 residual=True,
                 use_stride_conv=False,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 max_norm=None):
        # Work on private copies so the shared default dicts stay untouched
        conv_cfg = copy.deepcopy(conv_cfg)
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        # Keep the construction arguments around for later inspection.
        self.in_channels = in_channels
        self.stem_channels = stem_channels
        self.num_blocks = num_blocks
        self.kernel_sizes = kernel_sizes
        self.causal = causal
        self.residual = residual
        self.use_stride_conv = use_stride_conv
        self.max_norm = max_norm

        # One kernel size for the expanding conv plus one per block.
        assert num_blocks == len(kernel_sizes) - 1
        assert all(ks % 2 == 1 for ks in kernel_sizes), \
            'Only odd filter widths are supported.'

        first_kernel = kernel_sizes[0]
        stem_stride = first_kernel if use_stride_conv else 1
        self.expand_conv = ConvModule(
            in_channels,
            stem_channels,
            kernel_size=first_kernel,
            stride=stem_stride,
            bias='auto',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        # Arguments that are identical for every temporal block.
        shared_block_args = dict(
            in_channels=stem_channels,
            out_channels=stem_channels,
            mid_channels=stem_channels,
            dropout=dropout,
            causal=causal,
            residual=residual,
            use_stride_conv=use_stride_conv,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        self.tcn_blocks = nn.ModuleList()
        # The dilation grows by the kernel size of each layer already built.
        current_dilation = first_kernel
        for block_idx in range(1, num_blocks + 1):
            block_kernel = kernel_sizes[block_idx]
            temporal_block = BasicTemporalBlock(
                kernel_size=block_kernel,
                dilation=current_dilation,
                **shared_block_args)
            self.tcn_blocks.append(temporal_block)
            current_dilation = current_dilation * block_kernel

        if self.max_norm is not None:
            # Apply weight norm clip to conv layers
            clip_hook = WeightNormClipHook(self.max_norm)
            conv_layers = [
                layer for layer in self.modules()
                if isinstance(layer, nn.modules.conv._ConvNd)
            ]
            for layer in conv_layers:
                clip_hook.register(layer)

        # ``self.dropout`` holds the dropout module, or None when disabled.
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    def forward(self, x):
        """Forward function.

        Runs the expanding convolution (followed by dropout when enabled)
        and then the temporal blocks in order.

        Returns:
            tuple: The output of each temporal block, in block order.
        """
        feat = self.expand_conv(x)
        if self.dropout is not None:
            feat = self.dropout(feat)

        blocks = self.tcn_blocks
        block_outputs = []
        for block_idx in range(self.num_blocks):
            feat = blocks[block_idx](feat)
            block_outputs.append(feat)

        return tuple(block_outputs)

    def init_weights(self, pretrained=None):
        """Initialize the weights.

        ``pretrained`` is forwarded to the parent implementation. Only when
        it is None are the layers initialised here: Kaiming initialisation
        for convolution layers and a constant of 1 for batch-norm layers.
        """
        super().init_weights(pretrained)
        if pretrained is not None:
            return

        for layer in self.modules():
            if isinstance(layer, nn.modules.conv._ConvNd):
                kaiming_init(layer, mode='fan_in', nonlinearity='relu')
            elif isinstance(layer, _BatchNorm):
                constant_init(layer, 1)
