# Copyright (c) OpenMMLab. All rights reserved.
import copy
import fnmatch
import os.path as osp
import re
import warnings
from os import PathLike
from pathlib import Path
from typing import List, Tuple, Union

from mmengine.config import Config
from modelindex.load_model_index import load
from modelindex.models.Model import Model


class ModelHub:
    """A hub to host the meta information of all pre-defined models.

    Model names are stored lower-cased, so conflict detection during
    registration and all lookups (:meth:`get`, :meth:`has`) are
    case-insensitive.
    """
    # Lower-cased model name -> metainfo parsed from a model-index file.
    _models_dict = {}
    # Whether the model-index bundled with mmpretrain was already registered.
    __mmpretrain_registered = False

    @classmethod
    def register_model_index(cls,
                             model_index_path: Union[str, PathLike],
                             config_prefix: Union[str, PathLike, None] = None):
        """Parse the model-index file and register all models.

        Args:
            model_index_path (str | PathLike): The path of the model-index
                file.
            config_prefix (str | PathLike | None): The prefix of all config
                file paths in the model-index file. If None, the directory
                of each model's own metafile is used. Defaults to None.

        Raises:
            ValueError: If a model name (compared case-insensitively) has
                already been registered.
        """
        model_index = load(str(model_index_path))
        model_index.build_models_with_collections()

        for metainfo in model_index.models:
            model_name = metainfo.name.lower()
            # The dict is keyed by the lower-cased name, so the conflict
            # check has to use the lower-cased name as well; checking the
            # raw name would let mixed-case duplicates overwrite silently.
            if model_name in cls._models_dict:
                raise ValueError(
                    'The model name {} is conflict in {} and {}.'.format(
                        model_name, osp.abspath(metainfo.filepath),
                        osp.abspath(cls._models_dict[model_name].filepath)))
            metainfo.config = cls._expand_config_path(metainfo, config_prefix)
            cls._models_dict[model_name] = metainfo

    @classmethod
    def get(cls, model_name):
        """Get the model's metainfo by the model name.

        Args:
            model_name (str): The name of model (case-insensitive).

        Returns:
            modelindex.models.Model: A deep copy of the metainfo of the
            specified model. If the registered config is a path string, it
            is loaded into a ``Config`` object on the returned copy.

        Raises:
            ValueError: If no model with the given name is registered.
        """
        cls._register_mmpretrain_models()
        # Work on a copy so that the lazily loaded config below does not
        # replace the path string stored in the registry.
        metainfo = copy.deepcopy(cls._models_dict.get(model_name.lower()))
        if metainfo is None:
            raise ValueError(
                f'Failed to find model "{model_name}". please use '
                '`mmpretrain.list_models` to get all available names.')
        if isinstance(metainfo.config, str):
            # lazy load config
            metainfo.config = Config.fromfile(metainfo.config)
        return metainfo

    @staticmethod
    def _expand_config_path(metainfo: 'Model',
                            config_prefix: Union[str, PathLike, None] = None):
        """Resolve the config path of a model to an absolute path.

        Args:
            metainfo (Model): The metainfo whose ``config`` should be
                resolved.
            config_prefix (str | PathLike | None): The directory that
                relative config paths are joined to. If None, the directory
                of ``metainfo.filepath`` is used. Defaults to None.

        Returns:
            str | None: ``metainfo.config`` unchanged if it is None or
            already absolute, otherwise the absolute joined path.
        """
        if config_prefix is None:
            config_prefix = osp.dirname(metainfo.filepath)

        if metainfo.config is None or osp.isabs(metainfo.config):
            config_path: str = metainfo.config
        else:
            config_path = osp.abspath(osp.join(config_prefix, metainfo.config))

        return config_path

    @classmethod
    def _register_mmpretrain_models(cls):
        """Register the model-index bundled with mmpretrain (only once)."""
        # register models in mmpretrain
        if not cls.__mmpretrain_registered:
            from importlib_metadata import distribution
            root = distribution('mmpretrain').locate_file('mmpretrain')
            model_index_path = root / '.mim' / 'model-index.yml'
            ModelHub.register_model_index(
                model_index_path, config_prefix=root / '.mim')
            cls.__mmpretrain_registered = True

    @classmethod
    def has(cls, model_name):
        """Whether a model name is in the ModelHub.

        The comparison is case-insensitive for string names, matching
        :meth:`get`. This method only inspects the models registered so far;
        it does not trigger the registration of the bundled model-index.

        Args:
            model_name (str): The name of model.

        Returns:
            bool: True if the name is registered.
        """
        if isinstance(model_name, str):
            # Registered keys are always lower-cased.
            model_name = model_name.lower()
        return model_name in cls._models_dict


def get_model(model: Union[str, Config],
              pretrained: Union[str, bool] = False,
              device=None,
              device_map=None,
              offload_folder=None,
              url_mapping: Union[Tuple[str, str], None] = None,
              **kwargs):
    """Get a pre-defined model or create a model from config.

    Args:
        model (str | Config): The name of model, the config file path or a
            config instance.
        pretrained (bool | str): When use name to specify model, you can
            use ``True`` to load the pre-defined pretrained weights. And you
            can also use a string to specify the path or link of weights to
            load. Defaults to False.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        device_map (str | dict | None): A map that specifies where each
            submodule should go. It doesn't need to be refined to each
            parameter/buffer name, once a given module name is inside, every
            submodule of it will be sent to the same device. You can use
            `device_map="auto"` to automatically generate the device map.
            Defaults to None.
        offload_folder (str | None): If the `device_map` contains any value
            `"disk"`, the folder where we will offload weights.
        url_mapping (Tuple[str, str], optional): The mapping of pretrained
            checkpoint link. For example, load checkpoint from a local dir
            instead of download by ``('https://.*/', './checkpoint')``.
            Defaults to None.
        **kwargs: Other keyword arguments of the model config.

    Returns:
        mmengine.model.BaseModel: The result model.

    Raises:
        TypeError: If ``model`` is neither a model name, a config file path
            nor a ``Config`` instance.

    Examples:
        Get a ResNet-50 model and extract images feature:

        >>> import torch
        >>> from mmpretrain import get_model
        >>> inputs = torch.rand(16, 3, 224, 224)
        >>> model = get_model('resnet50_8xb32_in1k', pretrained=True, backbone=dict(out_indices=(0, 1, 2, 3)))
        >>> feats = model.extract_feat(inputs)
        >>> for feat in feats:
        ...     print(feat.shape)
        torch.Size([16, 256])
        torch.Size([16, 512])
        torch.Size([16, 1024])
        torch.Size([16, 2048])

        Get Swin-Transformer model with pre-trained weights and inference:

        >>> from mmpretrain import get_model, inference_model
        >>> model = get_model('swin-base_16xb64_in1k', pretrained=True)
        >>> result = inference_model(model, 'demo/demo.JPEG')
        >>> print(result['pred_class'])
        'sea snake'
    """  # noqa: E501
    if device_map is not None:
        # Imported lazily; check the requirements up front so that a missing
        # dependency is reported before the model is built.
        from .utils import dispatch_model
        dispatch_model._verify_require()

    metainfo = None
    if isinstance(model, Config):
        # Copy so that merging ``kwargs`` below does not mutate the caller's
        # config object.
        config = copy.deepcopy(model)
        if pretrained is True and 'load_from' in config:
            pretrained = config.load_from
    elif isinstance(model, (str, PathLike)) and Path(model).suffix == '.py':
        config = Config.fromfile(model)
        if pretrained is True and 'load_from' in config:
            pretrained = config.load_from
    elif isinstance(model, str):
        metainfo = ModelHub.get(model)
        config = metainfo.config
        if pretrained is True and metainfo.weights is not None:
            pretrained = metainfo.weights
    else:
        # ``config`` is not bound on this path, so report the type of the
        # offending ``model`` argument itself.
        raise TypeError('model must be a name, a path or a Config object, '
                        f'but got {type(model)}')

    if pretrained is True:
        # ``True`` could not be resolved to a concrete checkpoint above.
        warnings.warn('Unable to find pre-defined checkpoint of the model.')
        pretrained = None
    elif pretrained is False:
        pretrained = None

    if kwargs:
        config.merge_from_dict({'model': kwargs})
    # Fall back to the top-level ``data_preprocessor`` of the config when the
    # model config does not define its own.
    config.model.setdefault('data_preprocessor',
                            config.get('data_preprocessor', None))

    from mmengine.registry import DefaultScope

    from mmpretrain.registry import MODELS
    with DefaultScope.overwrite_default_scope('mmpretrain'):
        model = MODELS.build(config.model)

    dataset_meta = {}
    if pretrained:
        # Mapping the weights to GPU may cause unexpected video memory leak
        # which refers to https://github.com/open-mmlab/mmdetection/pull/6405
        from mmengine.runner import load_checkpoint
        if url_mapping is not None:
            pretrained = re.sub(url_mapping[0], url_mapping[1], pretrained)
        checkpoint = load_checkpoint(model, pretrained, map_location='cpu')
        checkpoint_meta = checkpoint.get('meta', {})
        if 'dataset_meta' in checkpoint_meta:
            # mmpretrain 1.x
            dataset_meta = checkpoint['meta']['dataset_meta']
        elif 'CLASSES' in checkpoint_meta:
            # mmcls 0.x
            dataset_meta = {'classes': checkpoint['meta']['CLASSES']}

    if len(dataset_meta) == 0 and 'test_dataloader' in config:
        # No meta from the checkpoint: use the ``METAINFO`` class attribute
        # of the test dataset class, if it has one.
        from mmpretrain.registry import DATASETS
        dataset_class = DATASETS.get(config.test_dataloader.dataset.type)
        dataset_meta = getattr(dataset_class, 'METAINFO', {})

    if device_map is not None:
        model = dispatch_model(
            model, device_map=device_map, offload_folder=offload_folder)
    elif device is not None:
        model.to(device)

    model._dataset_meta = dataset_meta  # save the dataset meta
    model._config = config  # save the config in the model
    model._metainfo = metainfo  # save the metainfo in the model
    model.eval()
    return model


def init_model(config, checkpoint=None, device=None, **kwargs):
    """Build a classifier from a config (deprecated compatibility shim).

    This simply forwards to :func:`get_model`; new code should call
    :func:`get_model` directly.

    Args:
        config (str | :obj:`mmengine.Config`): The config object or the
            path of a config file.
        checkpoint (str, optional): Path of the checkpoint to load. When it
            is None, no weights are loaded. Defaults to None.
        device (str | torch.device | None): The device the model is moved
            to. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to :func:`get_model`.

    Returns:
        nn.Module: The constructed model.
    """
    # ``checkpoint`` takes the place of ``pretrained`` in ``get_model``.
    built_model = get_model(config, checkpoint, device, **kwargs)
    return built_model


def list_models(pattern=None, exclude_patterns=None, task=None) -> List[str]:
    """List all models available in MMPretrain.

    Model names are registered lower-cased, so ``pattern`` and
    ``exclude_patterns`` are lower-cased before matching; name matching is
    therefore case-insensitive on every platform.

    Args:
        pattern (str | None): A wildcard pattern to match model names. Names
            with any postfix after the pattern are matched as well.
            Defaults to None.
        exclude_patterns (list | None): A list of wildcard patterns to
            exclude names from the matched names. Defaults to None.
        task (str | None): The evaluation task of the model. Use ``'null'``
            to select models without any recorded results. Defaults to None.

    Returns:
        List[str]: a sorted list of model names.

    Examples:
        List all models:

        >>> from mmpretrain import list_models
        >>> list_models()

        List ResNet-50 models on ImageNet-1k dataset:

        >>> from mmpretrain import list_models
        >>> list_models('resnet*in1k')
        ['resnet50_8xb32_in1k',
         'resnet50_8xb32-fp16_in1k',
         'resnet50_8xb256-rsb-a1-600e_in1k',
         'resnet50_8xb256-rsb-a2-300e_in1k',
         'resnet50_8xb256-rsb-a3-100e_in1k']

        List Swin-Transformer models trained from scratch and exclude
        Swin-Transformer-V2 models:

        >>> from mmpretrain import list_models
        >>> list_models('swin', exclude_patterns=['swinv2', '*-pre'])
        ['swin-base_16xb64_in1k',
         'swin-base_3rdparty_in1k',
         'swin-base_3rdparty_in1k-384',
         'swin-large_8xb8_cub-384px',
         'swin-small_16xb64_in1k',
         'swin-small_3rdparty_in1k',
         'swin-tiny_16xb64_in1k',
         'swin-tiny_3rdparty_in1k']

        List all EVA models for image classification task.

        >>> from mmpretrain import list_models
        >>> list_models('eva', task='Image Classification')
        ['eva-g-p14_30m-in21k-pre_3rdparty_in1k-336px',
         'eva-g-p14_30m-in21k-pre_3rdparty_in1k-560px',
         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-196px',
         'eva-l-p14_mim-in21k-pre_3rdparty_in1k-336px',
         'eva-l-p14_mim-pre_3rdparty_in1k-196px',
         'eva-l-p14_mim-pre_3rdparty_in1k-336px']
    """
    ModelHub._register_mmpretrain_models()
    matches = set(ModelHub._models_dict.keys())

    if pattern is not None:
        # Always match keys with any postfix. The keys are lower-cased, so
        # the pattern is lower-cased too; otherwise a pattern such as
        # 'ResNet' could never match where fnmatch is case-sensitive.
        matches = set(fnmatch.filter(matches, pattern.lower() + '*'))

    for exclude_pattern in exclude_patterns or []:
        excluded = fnmatch.filter(matches, exclude_pattern.lower() + '*')
        matches = matches - set(excluded)

    if task is not None:
        task_matches = []
        for key in matches:
            results = ModelHub._models_dict[key].results
            if results is None:
                # Models without results are only listed for task 'null'.
                if task == 'null':
                    task_matches.append(key)
            elif task in [result.task for result in results]:
                task_matches.append(key)
        matches = task_matches

    return sorted(matches)


def inference_model(model, *args, **kwargs):
    """Inference an image with the inferencer.

    Automatically select inferencer to inference according to the type of
    model. It's a shortcut for a quick start, and for advanced usage, please
    use the corresponding inferencer class.

    Here is the mapping from task to inferencer:

    - Image Classification: :class:`ImageClassificationInferencer`
    - Image Retrieval: :class:`ImageRetrievalInferencer`
    - Image Caption: :class:`ImageCaptionInferencer`
    - Visual Question Answering: :class:`VisualQuestionAnsweringInferencer`
    - Visual Grounding: :class:`VisualGroundingInferencer`
    - Text-To-Image Retrieval: :class:`TextToImageRetrievalInferencer`
    - Image-To-Text Retrieval: :class:`ImageToTextRetrievalInferencer`
    - NLVR: :class:`NLVRInferencer`

    When the model's metainfo lists several supported tasks, the first
    supported task in the order of ``metainfo.results`` is selected and a
    warning is emitted.

    Args:
        model (BaseModel | str | Config): The loaded model, the model
            name or the config of the model.
        *args: Positional arguments to call the inferencer.
        **kwargs: Other keyword arguments to initialize and call the
            corresponding inferencer.

    Returns:
        result (dict): The inference results.

    Raises:
        NotImplementedError: If no inferencer can be determined for the
            model, i.e. no metainfo is available or none of its tasks has a
            matching inferencer.
    """  # noqa: E501
    from mmengine.model import BaseModel

    if isinstance(model, BaseModel):
        metainfo = getattr(model, '_metainfo', None)
    elif isinstance(model, str):
        metainfo = ModelHub.get(model)
    else:
        # E.g. a Config instance: there is no model-index entry to read the
        # supported tasks from, so no inferencer can be auto-selected.
        metainfo = None

    from inspect import signature

    from .image_caption import ImageCaptionInferencer
    from .image_classification import ImageClassificationInferencer
    from .image_retrieval import ImageRetrievalInferencer
    from .multimodal_retrieval import (ImageToTextRetrievalInferencer,
                                       TextToImageRetrievalInferencer)
    from .nlvr import NLVRInferencer
    from .visual_grounding import VisualGroundingInferencer
    from .visual_question_answering import VisualQuestionAnsweringInferencer
    task_mapping = {
        'Image Classification': ImageClassificationInferencer,
        'Image Retrieval': ImageRetrievalInferencer,
        'Image Caption': ImageCaptionInferencer,
        'Visual Question Answering': VisualQuestionAnsweringInferencer,
        'Visual Grounding': VisualGroundingInferencer,
        'Text-To-Image Retrieval': TextToImageRetrievalInferencer,
        'Image-To-Text Retrieval': ImageToTextRetrievalInferencer,
        'NLVR': NLVRInferencer,
    }

    inferencer_type = None

    if metainfo is not None and metainfo.results is not None:
        # ``dict.fromkeys`` removes duplicated tasks while keeping the order
        # of ``metainfo.results``, so the auto-selected inferencer does not
        # depend on set iteration order.
        tasks = dict.fromkeys(result.task for result in metainfo.results)
        candidates = [
            task_mapping[task] for task in tasks if task in task_mapping
        ]
        if len(candidates) > 1:
            inferencer_names = [cls.__name__ for cls in candidates]
            warnings.warn('The model supports multiple tasks, auto select '
                          f'{inferencer_names[0]}, you can also use other '
                          f'inferencer {inferencer_names} directly.')
        if candidates:
            # An empty candidate list falls through to the
            # NotImplementedError below instead of raising IndexError.
            inferencer_type = candidates[0]

    if inferencer_type is None:
        raise NotImplementedError('No available inferencer for the model')

    # Keyword arguments accepted by the inferencer constructor are used to
    # initialize it; the remaining ones are passed to the inference call.
    init_params = signature(inferencer_type).parameters
    init_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if k in init_params}

    inferencer = inferencer_type(model, **init_kwargs)
    return inferencer(*args, **kwargs)[0]
