# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from .cls_head import ClsHead


@MODELS.register_module()
class EfficientFormerClsHead(ClsHead):
    """EfficientFormer classifier head.

    A linear classifier applied to the feature of the last backbone stage.
    When ``distillation`` is enabled, a second linear classifier of the same
    shape is created and its scores are averaged with the main head's scores
    in :meth:`forward`.

    Args:
        num_classes (int): Number of categories excluding the background
            category. Must be a positive integer.
        in_channels (int): Number of channels in the input feature map.
        distillation (bool): Whether to use an additional distilled head.
            Training with the distilled head is not supported (see
            :meth:`loss`). Defaults to True.
        init_cfg (dict): The extra initialization configs. Defaults to
            ``dict(type='Normal', layer='Linear', std=0.01)``.
        *args: Extra positional arguments forwarded to :class:`ClsHead`.
        **kwargs: Extra keyword arguments forwarded to :class:`ClsHead`.

    Raises:
        ValueError: If ``num_classes`` is not positive.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 distillation: bool = True,
                 init_cfg: dict = dict(
                     type='Normal', layer='Linear', std=0.01),
                 *args,
                 **kwargs):
        # Positional ``*args`` bind before keywords regardless of where they
        # are written; listing them first just makes that explicit.
        # NOTE(review): ``init_cfg`` is a shared mutable default that is only
        # handed to the parent here; presumably the base module copies it —
        # confirm before mutating it anywhere.
        super().__init__(*args, init_cfg=init_cfg, **kwargs)
        self.in_channels = in_channels
        self.num_classes = num_classes
        # Whether the auxiliary distillation head exists and is used.
        self.dist = distillation

        if self.num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        # Main classifier: maps ``in_channels`` features to class scores.
        self.head = nn.Linear(self.in_channels, self.num_classes)
        if self.dist:
            # Auxiliary classifier with the same shape as the main head.
            self.dist_head = nn.Linear(self.in_channels, self.num_classes)

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process.

        Args:
            feats (tuple[Tensor]): Features from the backbone stages. Only
                the last item is used. Its trailing dimension must equal
                ``in_channels``.

        Returns:
            Tensor: Class scores whose trailing dimension is
            ``num_classes``. With distillation enabled, this is the mean of
            the main head's and the distilled head's scores.
        """
        pre_logits = self.pre_logits(feats)
        # The final classification head.
        cls_score = self.head(pre_logits)

        if self.dist:
            # Inference-time ensemble of the two heads.
            cls_score = (cls_score + self.dist_head(pre_logits)) / 2
        return cls_score

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensors, and each tensor is the
        feature of a backbone stage. In :obj:`EfficientFormerClsHead`, we
        just obtain the feature of the last stage.
        """
        # The EfficientFormerClsHead doesn't have other module, just return
        # after unpacking.
        return feats[-1]

    def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The last item is fed to a linear
                layer, so its trailing dimension must be ``in_channels``
                (typically the shape is ``(num_samples, in_channels)``).
            data_samples (List[DataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components, as produced
            by :meth:`ClsHead.loss`.

        Raises:
            NotImplementedError: If the head was built with
                ``distillation=True``; training the distilled variant is not
                supported.
        """
        if self.dist:
            raise NotImplementedError(
                "MMPretrain doesn't support to train"
                ' the distilled version EfficientFormer.')
        else:
            return super().loss(feats, data_samples, **kwargs)
