# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F

from cvpods.layers import Conv2d, ShapeSpec, get_norm
from cvpods.modeling.nn_utils import weight_init


"""
Registry for box heads, which make box predictions from per-region features.

The registered object will be called with `obj(cfg, input_shape)`.
"""


class FastRCNNConvFCHead(nn.Module):
    """
    A head with several 3x3 conv layers (each followed by norm & relu) and
    several fc layers (each followed by relu).

    The conv layers are applied first, in creation order, followed by the fc
    layers. The shape of the head's output is exposed via `output_size`.
    """

    def __init__(self, cfg, input_shape: ShapeSpec, param_dicts=None):
        """
        Args:
            cfg: config object. `cfg.MODEL.ROI_BOX_HEAD` is read only when
                `param_dicts` is None.
            input_shape (ShapeSpec): shape of the input features; its
                `channels`, `height` and `width` attributes are read.
            param_dicts: optional object used in place of
                `cfg.MODEL.ROI_BOX_HEAD`; it must provide the same attributes.

        The following attributes are parsed from config:
            NUM_CONV, NUM_FC: the number of conv/fc layers
            CONV_DIM, FC_DIM: the dimension of the conv/fc layers
            NORM: normalization for the conv layers
        """
        super().__init__()

        if param_dicts is None:
            param_dicts = cfg.MODEL.ROI_BOX_HEAD

        num_conv = param_dicts.NUM_CONV
        conv_dim = param_dicts.CONV_DIM
        num_fc = param_dicts.NUM_FC
        fc_dim = param_dicts.FC_DIM
        norm = param_dicts.NORM
        assert num_conv + num_fc > 0, "at least one conv or fc layer is required"

        # Tracks the output shape while layers are stacked: a (C, H, W) tuple
        # during the conv stage, replaced by the int `fc_dim` once an fc layer
        # has been added.
        self._output_size = (
            input_shape.channels,
            input_shape.height,
            input_shape.width,
        )

        # Layers are registered through `add_module` (under "conv1", "conv2",
        # ...) and additionally kept in a plain list that `forward` iterates.
        self.conv_norm_relus = []
        for k in range(num_conv):
            conv = Conv2d(
                self._output_size[0],
                conv_dim,
                kernel_size=3,
                padding=1,
                bias=not norm,  # no conv bias when a norm is configured
                norm=get_norm(norm, conv_dim),
                activation=F.relu,
            )
            self.add_module("conv{}".format(k + 1), conv)
            self.conv_norm_relus.append(conv)
            # kernel_size=3 with padding=1 keeps H and W; only C changes.
            self._output_size = (conv_dim, self._output_size[1], self._output_size[2])

        self.fcs = []
        for k in range(num_fc):
            # np.prod returns a NumPy scalar; cast so that `in_features` is a
            # plain Python int.
            in_features = int(np.prod(self._output_size))
            fc = nn.Linear(in_features, fc_dim)
            self.add_module("fc{}".format(k + 1), fc)
            self.fcs.append(fc)
            self._output_size = fc_dim

        # Initialization runs after all layers exist, conv layers first.
        for layer in self.conv_norm_relus:
            weight_init.c2_msra_fill(layer)

        for layer in self.fcs:
            weight_init.c2_xavier_fill(layer)

    def forward(self, x):
        """
        Apply the conv layers, then (if any) the fc layers, to `x`.

        Args:
            x (Tensor): input features; presumably shaped (N, C, H, W) in
                agreement with the `input_shape` given at construction --
                TODO confirm against callers.

        Returns:
            Tensor: the output of the last layer. When fc layers exist, an
            input with more than 2 dims is flattened from dim 1 before the
            first fc layer, and each fc layer is followed by relu.
        """
        for layer in self.conv_norm_relus:
            x = layer(x)
        if self.fcs:
            if x.dim() > 2:
                x = torch.flatten(x, start_dim=1)
            for layer in self.fcs:
                x = F.relu(layer(x))
        return x

    @property
    def output_size(self):
        """
        Output size of the head: a `(channels, height, width)` tuple when the
        head has no fc layers, otherwise the int `FC_DIM`.
        """
        return self._output_size
