# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union

import torch
from mmcv.image import imread
from mmengine.config import Config
from mmengine.dataset import Compose, default_collate

from mmpretrain.registry import TRANSFORMS
from .base import BaseInferencer, InputType
from .model import list_models


class FeatureExtractor(BaseInferencer):
    """The inferencer for extracting features.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``FeatureExtractor.list_models()`` and you can also query it in
            :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only work if the ``model`` is a model name). Defaults to None.
        device (str, optional): Device to run inference. If None, the available
            device will be automatically used. Defaults to None.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).

    Example:
        >>> from mmpretrain import FeatureExtractor
        >>> inferencer = FeatureExtractor('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = inferencer('demo/demo.JPEG', stage='backbone')[0]
        >>> for feat in feats:
        ...     print(feat.shape)
        torch.Size([256, 56, 56])
        torch.Size([512, 28, 28])
        torch.Size([1024, 14, 14])
        torch.Size([2048, 7, 7])
    """  # noqa: E501

    def __call__(self,
                 inputs: InputType,
                 batch_size: int = 1,
                 **kwargs) -> list:
        """Call the inferencer.

        Args:
            inputs (str | array | list): The image path or array, or a list of
                images.
            batch_size (int): Batch size. Defaults to 1.
            **kwargs: Other keyword arguments accepted by the `extract_feat`
                method of the model.

        Returns:
            list: One entry per input image, in input order. Each entry is
            that image's slice of the model's ``extract_feat`` output: a
            tensor, or a (possibly nested) sequence of tensors with the same
            container type as the model output.
        """
        ori_inputs = self._inputs_to_list(inputs)
        inputs = self.preprocess(ori_inputs, batch_size=batch_size)
        preds = []
        for data in inputs:
            # ``forward`` splits each batch back into per-image results.
            preds.extend(self.forward(data, **kwargs))

        return preds

    @torch.no_grad()
    def forward(self, inputs: Union[dict, tuple], **kwargs) -> list:
        """Extract features for one collated batch and split them per sample.

        Args:
            inputs (dict | tuple): A collated batch as yielded by
                :meth:`preprocess`.
            **kwargs: Other keyword arguments passed to the model's
                ``extract_feat`` method.

        Returns:
            list: One item per sample of the batch. Tensors in the model
            output are indexed along their first dimension; sequences are
            rebuilt with their original container type.
        """
        # The second argument is the data preprocessor's ``training`` flag.
        inputs = self.model.data_preprocessor(inputs, False)['inputs']
        outputs = self.model.extract_feat(inputs, **kwargs)

        def scatter(feats, index):
            """Select sample ``index`` from every tensor in ``feats``."""
            if isinstance(feats, torch.Tensor):
                return feats[index]
            else:
                # Sequence of tensor (possibly nested): keep container type.
                return type(feats)([scatter(item, index) for item in feats])

        results = []
        for i in range(inputs.shape[0]):
            results.append(scatter(outputs, i))

        return results

    def _init_pipeline(self, cfg: Config) -> Callable:
        """Build the test pipeline from ``cfg.test_dataloader``.

        ``LoadImageFromFile`` is removed because images are already loaded
        in :meth:`preprocess`.
        """
        test_pipeline_cfg = cfg.test_dataloader.dataset.pipeline
        from mmpretrain.datasets import remove_transform

        # Image loading is finished in `self.preprocess`.
        test_pipeline_cfg = remove_transform(test_pipeline_cfg,
                                             'LoadImageFromFile')
        test_pipeline = Compose(
            [TRANSFORMS.build(t) for t in test_pipeline_cfg])
        return test_pipeline

    def preprocess(self, inputs: List[InputType], batch_size: int = 1):
        """Load, transform and batch the inputs lazily.

        Args:
            inputs (List[str | array]): Image paths or arrays, each readable
                by ``mmcv.image.imread``.
            batch_size (int): Number of images per batch. Defaults to 1.

        Yields:
            The output of ``default_collate`` for each chunk of
            ``batch_size`` transformed samples.

        Raises:
            ValueError: If ``imread`` returns None for an input. This is
                raised lazily, while the generator is being consumed.
        """

        def load_image(input_):
            img = imread(input_)
            if img is None:
                raise ValueError(f'Failed to read image {input_}.')
            return dict(
                img=img,
                img_shape=img.shape[:2],
                ori_shape=img.shape[:2],
            )

        pipeline = Compose([load_image, self.pipeline])

        chunked_data = self._get_chunk_data(map(pipeline, inputs), batch_size)
        yield from map(default_collate, chunked_data)

    def visualize(self, *args, **kwargs):
        """Not supported; always raises :class:`NotImplementedError`.

        Any arguments are accepted so callers get the intended error rather
        than a ``TypeError`` about the signature.
        """
        raise NotImplementedError(
            "The FeatureExtractor doesn't support visualization.")

    def postprocess(self, *args, **kwargs):
        """Not needed; always raises :class:`NotImplementedError`.

        Any arguments are accepted so callers get the intended error rather
        than a ``TypeError`` about the signature.
        """
        raise NotImplementedError(
            "The FeatureExtractor doesn't need postprocessing.")

    @staticmethod
    def list_models(pattern: Optional[str] = None):
        """List all available model names.

        Args:
            pattern (str | None): A wildcard pattern to match model names.

        Returns:
            List[str]: a list of model names.
        """
        return list_models(pattern=pattern)
