# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/Kevinz-code/CSRA
from typing import Tuple

import torch
import torch.nn as nn
from mmengine.model import BaseModule, ModuleList

from mmpretrain.registry import MODELS
from .multi_label_cls_head import MultiLabelClsHead


@MODELS.register_module()
class CSRAClsHead(MultiLabelClsHead):
    """Class-specific residual attention classifier head.

    Please refer to the `Residual Attention: A Simple but Effective Method for
    Multi-Label Recognition (ICCV 2021) <https://arxiv.org/abs/2108.02456>`_
    for details.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        num_heads (int): Number of residual at tensor heads.
        loss (dict): Config of classification loss.
        lam (float): Lambda that combines global average and max pooling
            scores.
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to use ``dict(type='Normal', layer='Linear', std=0.01)``.
    """
    temperature_settings = {  # softmax temperature settings
        1: [1],
        2: [1, 99],
        4: [1, 2, 4, 99],
        6: [1, 2, 3, 4, 5, 99],
        8: [1, 2, 3, 4, 5, 6, 7, 99]
    }

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 num_heads: int,
                 lam: float,
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01),
                 **kwargs):
        assert num_heads in self.temperature_settings.keys(
        ), 'The num of heads is not in temperature setting.'
        assert lam > 0, 'Lambda should be between 0 and 1.'
        super(CSRAClsHead, self).__init__(init_cfg=init_cfg, **kwargs)
        self.temp_list = self.temperature_settings[num_heads]
        self.csra_heads = ModuleList([
            CSRAModule(num_classes, in_channels, self.temp_list[i], lam)
            for i in range(num_heads)
        ])

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensor, and each tensor is the
        feature of a backbone stage. In ``CSRAClsHead``, we just obtain the
        feature of the last stage.
        """
        # The CSRAClsHead doesn't have other module, just return after
        # unpacking.
        return feats[-1]

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process."""
        pre_logits = self.pre_logits(feats)
        logit = sum([head(pre_logits) for head in self.csra_heads])
        return logit


class CSRAModule(BaseModule):
    """Basic module of CSRA with different temperature.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        T (int): Temperature setting.
        lam (float): Lambda that combines global average and max pooling
            scores.
        init_cfg (dict | optional): The extra init config of layers.
            Defaults to use dict(type='Normal', layer='Linear', std=0.01).
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 T: int,
                 lam: float,
                 init_cfg=None):

        super(CSRAModule, self).__init__(init_cfg=init_cfg)
        self.T = T  # temperature
        self.lam = lam  # Lambda
        self.head = nn.Conv2d(in_channels, num_classes, 1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        score = self.head(x) / torch.norm(
            self.head.weight, dim=1, keepdim=True).transpose(0, 1)
        score = score.flatten(2)
        base_logit = torch.mean(score, dim=2)

        if self.T == 99:  # max-pooling
            att_logit = torch.max(score, dim=2)[0]
        else:
            score_soft = self.softmax(score * self.T)
            att_logit = torch.sum(score * score_soft, dim=2)

        return base_logit + self.lam * att_logit
