# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Tuple

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from .vision_transformer_head import VisionTransformerClsHead


@MODELS.register_module()
class DeiTClsHead(VisionTransformerClsHead):
    """Classification head for the distilled Vision Transformer (DeiT).

    On top of :class:`VisionTransformerClsHead`, this head registers one more
    linear layer (``head_dist``) that consumes the dist token. The returned
    classification score is the mean of the ``cls_token`` logits and the
    ``dist_token`` logits.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        hidden_dim (int, optional): Number of the dimensions for hidden layer.
            Defaults to None, which means no extra hidden layer.
        act_cfg (dict): The activation config. Only available during
            pre-training. Defaults to ``dict(type='Tanh')``.
        init_cfg (dict): The extra initialization configs. Defaults to
            ``dict(type='Constant', layer='Linear', val=0)``.
    """

    def _init_layers(self):
        """Build the parent layers, then append the dist-token linear head."""
        super()._init_layers()
        # The dist head reads the hidden projection when one is configured,
        # otherwise it reads the raw backbone channels.
        dist_in_dim = (
            self.in_channels if self.hidden_dim is None else self.hidden_dim)
        self.layers.add_module('head_dist',
                               nn.Linear(dist_in_dim, self.num_classes))

    def pre_logits(self,
                   feats: Tuple[List[torch.Tensor]]) -> Tuple[torch.Tensor]:
        """Prepare the cls and dist tokens for the final classifiers.

        Only the last entry of ``feats`` is used. It must hold either
        ``(cls_token, dist_token)`` or three items whose last two are the
        cls token and the dist token. When ``hidden_dim`` is set, each token
        is passed through ``layers.pre_logits`` and ``layers.act``.

        Returns:
            Tuple[torch.Tensor]: The ``(cls_token, dist_token)`` pair.
        """
        last_stage = feats[-1]
        if len(last_stage) == 3:
            # Backward compatibility: the previous ViT output carries one
            # extra leading item, which this head ignores.
            _, cls_feat, dist_feat = last_stage
        else:
            cls_feat, dist_feat = last_stage

        if self.hidden_dim is None:
            return cls_feat, dist_feat

        def _project(token):
            # Shared hidden projection followed by the activation.
            return self.layers.act(self.layers.pre_logits(token))

        return _project(cls_feat), _project(dist_feat)

    def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
        """Compute the averaged cls/dist classification score.

        A warning is issued when the module is in training mode.
        """
        if self.training:
            warnings.warn(
                'MMPretrain cannot train the distilled version DeiT.')
        cls_feat, dist_feat = self.pre_logits(feats)
        # Each token has its own classifier; the score is their mean.
        cls_logits = self.layers.head(cls_feat)
        dist_logits = self.layers.head_dist(dist_feat)
        return (cls_logits + dist_logits) / 2
