# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, List, Optional, Union

import mmengine.dist as dist
import torch
import torch.nn as nn
from mmengine.runner import Runner
from torch.utils.data import DataLoader

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
from .base import BaseRetriever


@MODELS.register_module()
class ImageToImageRetriever(BaseRetriever):
    """Image To Image Retriever for supervised retrieval task.

    Args:
        image_encoder (Union[dict, List[dict]]): Encoder for extracting
            features. An ``nn.Module`` instance is used as-is; anything else
            is built with ``MODELS.build``.
        prototype (Union[DataLoader, dict, str, torch.Tensor]): Database to be
            retrieved. The following four types are supported.

            - DataLoader: The original dataloader serves as the prototype.
            - dict: The configuration to construct Dataloader.
            - str: The path of the saved vector.
            - torch.Tensor: The saved tensor whose dimension should be dim.

        head (dict, optional): The head module to calculate loss from
            processed features. See :mod:`mmpretrain.models.heads`. Notice
            that if the head is not set, `loss` method cannot be used.
            Defaults to None.
        pretrained (str, optional): Currently not used by this class.
            Defaults to None.
        similarity_fn (Union[str, Callable]): The way that the similarity
            is calculated. If `similarity_fn` is callable, it is used directly
            as the measure function. If it is a string, the appropriate
            method will be used (only "cosine_similarity" is supported).
            The larger the calculated value, the greater the similarity.
            Defaults to "cosine_similarity".
        train_cfg (dict, optional): The training setting. The acceptable
            fields are:

            - augments (List[dict]): The batch augmentation methods to use.
              More details can be found in
              :mod:`mmpretrain.model.utils.augment`.

            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing input
            data. If None or no specified type, it will use
            "ClsDataPreprocessor" as type. See :class:`ClsDataPreprocessor` for
            more details. The passed dict is not modified. Defaults to None.
        topk (int): Return the topk of the retrieval result. `-1` means
            return all. Defaults to -1.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 image_encoder: Union[dict, List[dict]],
                 prototype: Union[DataLoader, dict, str, torch.Tensor],
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 similarity_fn: Union[str, Callable] = 'cosine_similarity',
                 train_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 topk: int = -1,
                 init_cfg: Optional[dict] = None):

        # Work on a shallow copy so the caller's (possibly shared) config
        # dict is not mutated by the defaults filled in below.
        data_preprocessor = dict(data_preprocessor or {})
        # The build process is in MMEngine, so we need to add scope here.
        data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor')

        if train_cfg is not None and 'augments' in train_cfg:
            # Set batch augmentations by `train_cfg`
            data_preprocessor['batch_augments'] = train_cfg

        super(ImageToImageRetriever, self).__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        if not isinstance(image_encoder, nn.Module):
            image_encoder = MODELS.build(image_encoder)
        if head is not None and not isinstance(head, nn.Module):
            head = MODELS.build(head)

        self.image_encoder = image_encoder
        self.head = head

        # Raw user setting; resolved lazily by the `similarity_fn` property.
        self.similarity = similarity_fn

        assert isinstance(prototype, (str, torch.Tensor, dict, DataLoader)), (
            'The `prototype` in  `ImageToImageRetriever` must be a path, '
            'a torch.Tensor, a dataloader or a dataloader dict format config.')
        self.prototype = prototype
        # Prototype vectors are computed on first use (see `predict`).
        self.prototype_inited = False
        self.topk = topk

    @property
    def similarity_fn(self):
        """Returns a function that calculates the similarity.

        Returns:
            Callable: ``fn(a, b)`` returning a similarity matrix between the
            rows of ``a`` and the rows of ``b``.

        Raises:
            RuntimeError: If ``self.similarity`` is neither callable nor a
                supported name.
        """
        # A user-supplied callable is used directly.
        if callable(self.similarity):
            return self.similarity

        if self.similarity == 'cosine_similarity':
            # a is a tensor with shape (N, C)
            # b is a tensor with shape (M, C)
            # "cosine_similarity" will get the matrix of similarity
            # with shape (N, M).
            # The higher the score is, the more similar is
            return lambda a, b: torch.cosine_similarity(
                a.unsqueeze(1), b.unsqueeze(0), dim=-1)

        # Report the stored setting; formatting `self.similarity_fn` here
        # would re-enter this property and recurse forever.
        raise RuntimeError(f'Invalid function "{self.similarity}".')

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[DataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor without any
          post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`DataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method doesn't handle neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor, tuple): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return the output of the image encoder.
            - If ``mode="predict"``, return a list of
              :obj:`mmpretrain.structures.DataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three modes above.
        """
        if mode == 'tensor':
            return self.extract_feat(inputs)
        elif mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            return self.predict(inputs, data_samples)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_feat(self, inputs):
        """Extract features from the input tensor with shape (N, C, ...).

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.
        Returns:
            Tensor | tuple: The output of encoder, returned unchanged.
        """
        return self.image_encoder(inputs)

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Requires ``head`` to have been set at construction time.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[DataSample]): The annotation data of
                every samples.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        feats = self.extract_feat(inputs)
        return self.head.loss(feats, data_samples)

    def matching(self, inputs: torch.Tensor):
        """Compare the inputs with the prototype vectors.

        Args:
            inputs (torch.Tensor): Query features, (N, C) for the default
                cosine similarity.
        Returns:
            dict: ``score`` is the raw similarity matrix (N, M for cosine
            similarity against M prototypes), ``pred_label`` holds the
            prototype indices sorted by descending similarity and
            ``pred_score`` the correspondingly sorted similarities.
        """
        sim = self.similarity_fn(inputs, self.prototype_vecs)
        sorted_sim, indices = torch.sort(sim, descending=True, dim=-1)
        return dict(score=sim, pred_label=indices, pred_score=sorted_sim)

    def predict(self,
                inputs: tuple,
                data_samples: Optional[List[DataSample]] = None,
                **kwargs) -> List[DataSample]:
        """Retrieve the most similar prototypes for a batch of inputs.

        The prototype vectors are prepared lazily on the first call.

        Args:
            inputs (torch.Tensor): The raw input batch; features are
                extracted from it with :meth:`extract_feat`. If the encoder
                returns a tuple, its last element is used.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. Defaults to None.
            **kwargs: Unused.
        Returns:
            List[DataSample]: the raw data_samples (or new ones when
                ``data_samples`` is None) with the predicted results
        """
        if not self.prototype_inited:
            self.prepare_prototype()

        feats = self.extract_feat(inputs)
        if isinstance(feats, tuple):
            feats = feats[-1]

        # Matching of similarity
        result = self.matching(feats)
        return self._get_predictions(result, data_samples)

    def _get_predictions(self, result, data_samples):
        """Post-process the output of retriever.

        The full similarity row is stored as ``pred_score``; only the
        ranked indices stored as ``pred_label`` are cut to ``self.topk``
        (unless it is -1).
        """
        pred_scores = result['score']
        pred_labels = result['pred_label']
        if self.topk != -1:
            topk = min(self.topk, pred_scores.size()[-1])
            pred_labels = pred_labels[:, :topk]

        if data_samples is not None:
            for data_sample, score, label in zip(data_samples, pred_scores,
                                                 pred_labels):
                data_sample.set_pred_score(score).set_pred_label(label)
        else:
            data_samples = []
            for score, label in zip(pred_scores, pred_labels):
                data_samples.append(
                    DataSample().set_pred_score(score).set_pred_label(label))
        return data_samples

    def _get_prototype_vecs_from_dataloader(self, data_loader):
        """Extract prototype vectors by running the model over a dataloader.

        Each sample's feature is written to the row given by its
        ``sample_idx``; rows are then summed across processes with
        ``dist.all_reduce``. A distributed sampler that pads the dataset may
        give the same sample to several processes, so the number of
        processes that filled each row is tracked and used to average the
        summed rows rather than letting repeated samples be multiplied.

        The model is switched to eval mode for extraction and its previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        num = len(data_loader.dataset)

        prototype_vecs = None
        # 0/1 marker of the rows this process has written.
        filled = torch.zeros(num)
        try:
            for data_batch in track_on_main_process(data_loader,
                                                    'Prepare prototype'):
                data = self.data_preprocessor(data_batch, False)
                feat = self(**data)
                if isinstance(feat, tuple):
                    feat = feat[-1]

                if prototype_vecs is None:
                    dim = feat.shape[-1]
                    prototype_vecs = torch.zeros(num, dim)
                for i, data_sample in enumerate(data_batch['data_samples']):
                    sample_idx = data_sample.get('sample_idx')
                    prototype_vecs[sample_idx] = feat[i]
                    filled[sample_idx] = 1
        finally:
            self.train(was_training)

        assert prototype_vecs is not None
        dist.all_reduce(prototype_vecs)
        dist.all_reduce(filled)
        # Average rows contributed by several processes; the clamp keeps
        # rows nobody filled (count 0) from being divided by zero.
        prototype_vecs /= filled.clamp(min=1).unsqueeze(-1)
        return prototype_vecs

    def _get_prototype_vecs_from_path(self, proto_path):
        """Load prototype vectors from a file and share them with all ranks.

        Only the main process reads the file; the result is broadcast.
        """
        data = [None]
        if dist.is_main_process():
            # Load onto the CPU: `dump_prototype` may have saved a CUDA
            # tensor whose device index need not exist here. The caller moves
            # the vectors to the model device.
            # NOTE(review): torch.load unpickles the file; only load
            # prototype files from trusted sources.
            data[0] = torch.load(proto_path, map_location='cpu')
        dist.broadcast_object_list(data, src=0)
        prototype_vecs = data[0]
        assert prototype_vecs is not None
        return prototype_vecs

    @torch.no_grad()
    def prepare_prototype(self):
        """Build the prototype vectors and cache them on the model device.

        The source depends on the type of ``self.prototype``:

        - torch.Tensor: used directly as the prototype vectors.
        - str: a path to saved vectors, loaded on the main process and
            broadcast to the others.
        - DataLoader or config dict: features are extracted from every
            sample of the (built) dataloader.

        The result is registered as the non-persistent buffer
        ``prototype_vecs`` and ``prototype_inited`` is set to True.

        Raises:
            TypeError: If ``self.prototype`` has an unsupported type.
        """
        device = next(self.image_encoder.parameters()).device
        if isinstance(self.prototype, torch.Tensor):
            prototype_vecs = self.prototype
        elif isinstance(self.prototype, str):
            prototype_vecs = self._get_prototype_vecs_from_path(self.prototype)
        elif isinstance(self.prototype, (dict, DataLoader)):
            loader = Runner.build_dataloader(self.prototype)
            prototype_vecs = self._get_prototype_vecs_from_dataloader(loader)
        else:
            raise TypeError(
                'Unsupported prototype type '
                f'{type(self.prototype).__name__}; expected a path, a '
                'torch.Tensor, a dataloader or a dataloader config dict.')

        self.register_buffer(
            'prototype_vecs', prototype_vecs.to(device), persistent=False)
        self.prototype_inited = True

    def dump_prototype(self, path):
        """Save the features extracted from the prototype to specific path.

        The prototype vectors are prepared first if needed.

        Args:
            path (str): Path to save feature.
        """
        if not self.prototype_inited:
            self.prepare_prototype()
        torch.save(self.prototype_vecs, path)
