import math
from typing import List
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.layers import ShapeSpec, get_norm
from detectron2.config import configurable
from ..layers.deform_conv import DFConv2d

__all__ = ["CenterNetHead"]

class Scale(nn.Module):
    def __init__(self, init_value=1.0):
        super(Scale, self).__init__()
        self.scale = nn.Parameter(torch.FloatTensor([init_value]))

    def forward(self, input):
        return input * self.scale

class CenterNetHead(nn.Module):
    @configurable
    def __init__(self, 
        # input_shape: List[ShapeSpec],
        in_channels,
        num_levels,
        *,
        num_classes=80,
        with_agn_hm=False,
        only_proposal=False,
        norm='GN',
        num_cls_convs=4,
        num_box_convs=4,
        num_share_convs=0,
        use_deformable=False,
        prior_prob=0.01):
        super().__init__()
        self.num_classes = num_classes
        self.with_agn_hm = with_agn_hm
        self.only_proposal = only_proposal
        self.out_kernel = 3

        head_configs = {
            "cls": (num_cls_convs if not self.only_proposal else 0, \
                use_deformable),
            "bbox": (num_box_convs, use_deformable),
            "share": (num_share_convs, use_deformable)}

        # in_channels = [s.channels for s in input_shape]
        # assert len(set(in_channels)) == 1, \
        #     "Each level must have the same channel!"
        # in_channels = in_channels[0]
        channels = {
            'cls': in_channels,
            'bbox': in_channels,
            'share': in_channels,
        }
        for head in head_configs:
            tower = []
            num_convs, use_deformable = head_configs[head]
            channel = channels[head]
            for i in range(num_convs):
                if use_deformable and i == num_convs - 1:
                    conv_func = DFConv2d
                else:
                    conv_func = nn.Conv2d
                tower.append(conv_func(
                        in_channels if i == 0 else channel,
                        channel, 
                        kernel_size=3, stride=1,
                        padding=1, bias=True
                ))
                if norm == 'GN' and channel % 32 != 0:
                    tower.append(nn.GroupNorm(25, channel))
                elif norm != '':
                    tower.append(get_norm(norm, channel))
                tower.append(nn.ReLU())
            self.add_module('{}_tower'.format(head),
                            nn.Sequential(*tower))

        self.bbox_pred = nn.Conv2d(
            in_channels, 4, kernel_size=self.out_kernel,
            stride=1, padding=self.out_kernel // 2
        )

        self.scales = nn.ModuleList(
            [Scale(init_value=1.0) for _ in range(num_levels)])

        for modules in [
            self.cls_tower, self.bbox_tower,
            self.share_tower,
            self.bbox_pred,
        ]:
            for l in modules.modules():
                if isinstance(l, nn.Conv2d):
                    torch.nn.init.normal_(l.weight, std=0.01)
                    torch.nn.init.constant_(l.bias, 0)
        
        torch.nn.init.constant_(self.bbox_pred.bias, 8.)
        prior_prob = prior_prob
        bias_value = -math.log((1 - prior_prob) / prior_prob)

        if self.with_agn_hm:
            self.agn_hm = nn.Conv2d(
                in_channels, 1, kernel_size=self.out_kernel,
                stride=1, padding=self.out_kernel // 2
            )
            torch.nn.init.constant_(self.agn_hm.bias, bias_value)
            torch.nn.init.normal_(self.agn_hm.weight, std=0.01)

        if not self.only_proposal:
            cls_kernel_size = self.out_kernel
            self.cls_logits = nn.Conv2d(
                in_channels, self.num_classes,
                kernel_size=cls_kernel_size, 
                stride=1,
                padding=cls_kernel_size // 2,
            )

            torch.nn.init.constant_(self.cls_logits.bias, bias_value)
            torch.nn.init.normal_(self.cls_logits.weight, std=0.01)

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = {
            # 'input_shape': input_shape,
            'in_channels': [s.channels for s in input_shape][0],
            'num_levels': len(input_shape),
            'num_classes': cfg.MODEL.CENTERNET.NUM_CLASSES,
            'with_agn_hm': cfg.MODEL.CENTERNET.WITH_AGN_HM,
            'only_proposal': cfg.MODEL.CENTERNET.ONLY_PROPOSAL,
            'norm': cfg.MODEL.CENTERNET.NORM,
            'num_cls_convs': cfg.MODEL.CENTERNET.NUM_CLS_CONVS,
            'num_box_convs': cfg.MODEL.CENTERNET.NUM_BOX_CONVS,
            'num_share_convs': cfg.MODEL.CENTERNET.NUM_SHARE_CONVS,
            'use_deformable': cfg.MODEL.CENTERNET.USE_DEFORMABLE,
            'prior_prob': cfg.MODEL.CENTERNET.PRIOR_PROB,
        }
        return ret

    def forward(self, x):
        clss = []
        bbox_reg = []
        agn_hms = []
        for l, feature in enumerate(x):
            feature = self.share_tower(feature)
            cls_tower = self.cls_tower(feature)
            bbox_tower = self.bbox_tower(feature)
            if not self.only_proposal:
                clss.append(self.cls_logits(cls_tower))
            else:
                clss.append(None)

            if self.with_agn_hm:
                agn_hms.append(self.agn_hm(bbox_tower))
            else:
                agn_hms.append(None)
            reg = self.bbox_pred(bbox_tower)
            reg = self.scales[l](reg)
            bbox_reg.append(F.relu(reg))
        
        return clss, bbox_reg, agn_hms