import torch
import torch.nn as nn
from annotator.uniformer.mmcv import is_tuple_of
from annotator.uniformer.mmcv.cnn import ConvModule

from annotator.uniformer.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Data flow (see ``forward``): the last selected input goes through a 1x1
    ``ConvModule`` whose output is multiplied element-wise by a
    sigmoid-activated, average-pooled branch computed from the same input
    (resized back to that input's spatial size). The result is then, for each
    earlier input from last to first, resized to that input's spatial size,
    concatenated with a 1x1 projection of it, and fused by a 1x1
    ``ConvModule``.

    Args:
        branch_channels (tuple[int]): Number of output channels of the 1x1
            projection applied to each selected input except the last one.
            Its length must equal ``len(in_channels) - 1``.
            Default: (32, 64).
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super().__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int)
        assert len(branch_channels) == len(self.in_channels) - 1
        self.branch_channels = branch_channels

        # One 1x1 projection (convs) and one fusion ConvModule (conv_ups) per
        # input except the last. Submodule names conv{i}/conv_up{i} are kept
        # unchanged so state_dict keys stay the same.
        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i, branch_ch in enumerate(branch_channels):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(self.in_channels[i], branch_ch, 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_ch,
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        # Both aspp_conv and image_pool are applied to inputs[-1] in
        # forward(), so both must take in_channels[-1] input channels.
        # A hard-coded in_channels[2] would only be correct when exactly
        # three inputs (two branches) are selected.
        top_channels = self.in_channels[-1]
        self.aspp_conv = ConvModule(
            top_channels,
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        # NOTE(review): the pooling window/stride are fixed constants,
        # presumably tuned for a particular feature-map size; AvgPool2d
        # needs the input to be at least 49x49 -- confirm for other input
        # resolutions.
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                top_channels,
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (list[Tensor]): Multi-level features; the subset used is
                chosen by ``self._transform_inputs`` (presumably according to
                ``in_index`` -- defined in ``BaseDecodeHead``).

        Returns:
            Tensor: ``self.cls_seg`` applied to the fused feature. When
            ``branch_channels`` is non-empty, the fused feature has the
            spatial size of the first selected input.
        """
        inputs = self._transform_inputs(inputs)

        x = inputs[-1]

        # Gate the projected last feature with the sigmoid-activated pooled
        # branch, resized back to x's spatial size.
        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)

        # Walk the earlier inputs from last to first: resize to the input's
        # spatial size, concatenate its 1x1 projection on the channel dim,
        # then fuse back to self.channels.
        for i in reversed(range(len(self.branch_channels))):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
