# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import fvcore.nn.weight_init as weight_init
from torch import nn

from .batch_norm import FrozenBatchNorm2d, get_norm
from .wrappers import Conv2d


"""
CNN building blocks.
"""


class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and convert all BatchNorm layers to FrozenBatchNorm

        Returns:
            the block itself
        """
        for p in self.parameters():
            p.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self


class DepthwiseSeparableConv2d(nn.Module):
    """
    A kxk depthwise convolution + a 1x1 convolution.

    In :paper:`xception`, norm & activation are applied on the second conv.
    :paper:`mobilenet` uses norm & activation on both convs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        *,
        norm1=None,
        activation1=None,
        norm2=None,
        activation2=None,
    ):
        """
        Args:
            norm1, norm2 (str or callable): normalization for the two conv layers.
            activation1, activation2 (callable(Tensor) -> Tensor): activation
                function for the two conv layers.
        """
        super().__init__()
        self.depthwise = Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=not norm1,
            norm=get_norm(norm1, in_channels),
            activation=activation1,
        )
        self.pointwise = Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=not norm2,
            norm=get_norm(norm2, out_channels),
            activation=activation2,
        )

        # default initialization
        weight_init.c2_msra_fill(self.depthwise)
        weight_init.c2_msra_fill(self.pointwise)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))
