# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Tuple

import torch

from mmaction.registry import MODELS
from mmaction.utils.typing import OptSampleList
from .base import BaseRecognizer


@MODELS.register_module()
class MMRecognizer3D(BaseRecognizer):
    """Multi-modal 3D recognizer model framework."""

    def extract_feat(self,
                     inputs: Dict[str, torch.Tensor],
                     stage: str = 'backbone',
                     data_samples: OptSampleList = None,
                     test_mode: bool = False) -> Tuple:
        """Extract features.

        Args:
            inputs (dict[str, torch.Tensor]): The multi-modal input data.
            stage (str): Which stage to output the feature.
                Defaults to ``'backbone'``.
            data_samples (list[:obj:`ActionDataSample`], optional): Action data
                samples, which are only needed in training. Defaults to None.
            test_mode (bool): Whether in test mode. Defaults to False.

        Returns:
                tuple[torch.Tensor]: The extracted features.
                dict: A dict recording the kwargs for downstream
                    pipeline.
        """
        # [N, num_views, C, T, H, W] ->
        # [N * num_views, C, T, H, W]
        for m, m_data in inputs.items():
            m_data = m_data.reshape((-1, ) + m_data.shape[2:])
            inputs[m] = m_data

        # Record the kwargs required by `loss` and `predict`
        loss_predict_kwargs = dict()

        x = self.backbone(**inputs)
        if stage == 'backbone':
            return x, loss_predict_kwargs

        if self.with_cls_head and stage == 'head':
            x = self.cls_head(x, **loss_predict_kwargs)
            return x, loss_predict_kwargs
