import torch
from mmcv.cnn import ContextBlock

from ..builder import HEADS
from .fcn_head import FCNHead


@HEADS.register_module()
class GCHead(FCNHead):
    """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.

    This head is the implementation of `GCNet
    <https://arxiv.org/abs/1904.11492>`_.

    A ``ContextBlock`` is applied between ``self.convs[0]`` and
    ``self.convs[1]`` of the parent ``FCNHead``.

    Args:
        ratio (float): Multiplier of channels ratio; forwarded unchanged to
            ``ContextBlock``. Default: 1/4.
        pooling_type (str): The pooling type of context aggregation.
            Options are 'att', 'avg'. Default: 'att'.
        fusion_types (tuple[str]): The fusion type for feature fusion.
            Options are 'channel_add', 'channel_mul'.
            Default: ('channel_add',)
        **kwargs: Remaining keyword arguments, forwarded to
            ``FCNHead.__init__``. ``num_convs`` must not be supplied here
            because this head always passes ``num_convs=2`` itself.
    """

    def __init__(self,
                 ratio=1 / 4.,
                 pooling_type='att',
                 fusion_types=('channel_add', ),
                 **kwargs):
        # ``forward`` indexes ``self.convs[0]`` and ``self.convs[1]``
        # explicitly, so the number of convs is pinned to two.
        super(GCHead, self).__init__(num_convs=2, **kwargs)
        self.ratio = ratio
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        # The block works on ``self.channels`` channels, i.e. the attribute
        # set up by the parent constructor above.
        self.gc_block = ContextBlock(
            in_channels=self.channels,
            ratio=self.ratio,
            pooling_type=self.pooling_type,
            fusion_types=self.fusion_types)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs: Features handed to ``self._transform_inputs``
                (presumably the multi-level backbone outputs -- confirm
                against the parent decode head).

        Returns:
            The result of ``self.cls_seg`` applied to the context-enhanced
            features.
        """
        x = self._transform_inputs(inputs)
        # conv -> global context block -> conv
        output = self.convs[0](x)
        output = self.gc_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            # Fuse the transformed input with the refined features along
            # dim 1 before classification.
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
