# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.models import ImageClassifier as _ImageClassifier


@MODELS.register_module(force=True)
class ImageClassifier(_ImageClassifier):
    """Image classifiers for supervised classification task.

    Extends the upstream ``ImageClassifier`` so that the backbone may return
    a ``(features, aux_losses)`` pair, where ``aux_losses`` is a dict of
    auxiliary loss terms. Those terms are merged into the dict returned by
    :meth:`loss`.
    """

    def extract_feat(self, inputs, stage='neck'):
        """Extract features from the input tensor with shape (N, C, ...).

        Args:
            inputs (torch.Tensor): The batch of inputs.
            stage (str): Stage at which to stop. One of ``'backbone'``,
                ``'neck'`` (default) or ``'pre_logits'``.

        Returns:
            The raw backbone output when ``stage == 'backbone'`` (including
            any auxiliary-loss dict it carries). For the other stages:
            ``(features, aux_losses)`` in training mode and only
            ``features`` otherwise; ``aux_losses`` is ``None`` when the
            backbone did not produce any.
        """
        assert stage in ['backbone', 'neck', 'pre_logits'], \
            (f'Invalid output stage "{stage}", please choose from "backbone", '
             '"neck" and "pre_logits"')

        if stage == 'backbone':
            # Returned untouched: no auxiliary-loss splitting at this stage.
            return self.backbone(inputs)

        x, loss = self._extract_feat_with_aux_loss(inputs, stage)
        return (x, loss) if self.training else x

    def _extract_feat_with_aux_loss(self, inputs, stage='neck'):
        """Run backbone (+ neck, + pre_logits) and always return ``(x, loss)``.

        Unlike :meth:`extract_feat`, the return shape does not depend on
        ``self.training``. Callers that need the auxiliary losses, such as
        :meth:`loss`, therefore work in both train and eval mode.

        Args:
            inputs (torch.Tensor): The batch of inputs.
            stage (str): ``'neck'`` or ``'pre_logits'``.

        Returns:
            tuple: ``(features, aux_losses)``. ``aux_losses`` is ``None``
            when the backbone output is not a tuple ending in a dict.
        """
        x = self.backbone(inputs)

        # A backbone may append a dict of auxiliary losses to its output.
        # The length check avoids an IndexError on an empty tuple.
        if isinstance(x, tuple) and len(x) > 0 and isinstance(x[-1], dict):
            assert len(x) == 2, \
                ('A backbone output ending in a loss dict must be a '
                 '(features, losses) pair')
            x, loss = x
        else:
            loss = None

        if self.with_neck:
            x = self.neck(x)
        if stage == 'neck':
            return x, loss

        assert self.with_head and hasattr(self.head, 'pre_logits'), \
            "No head or the head doesn't implement `pre_logits` method."
        return self.head.pre_logits(x), loss

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample]): The annotation data of
                every samples.

        Returns:
            dict[str, Tensor]: A dictionary of loss components. It contains
            ``'cls_loss'``, taken from the head's ``'loss'`` entry (other
            entries returned by the head are not forwarded). It also
            contains every auxiliary loss term produced by the backbone.
            An auxiliary term whose key is also ``'cls_loss'`` overwrites
            the head's value.
        """
        # Call the helper rather than ``extract_feat``: the latter returns a
        # bare feature object outside training mode, which would break this
        # two-way unpacking when losses are computed in eval mode.
        feats, aux_loss = self._extract_feat_with_aux_loss(inputs)
        cls_loss = self.head.loss(feats, data_samples)
        losses = dict(cls_loss=cls_loss['loss'])
        if aux_loss is not None:
            losses.update(aux_loss)
        return losses
