# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import tempfile
from typing import List, Optional

from mmengine.evaluator import BaseMetric
from mmengine.utils import track_iter_progress

from mmpretrain.registry import METRICS
from mmpretrain.utils import require

try:
    from pycocoevalcap.eval import COCOEvalCap
    from pycocotools.coco import COCO
except ImportError:
    COCOEvalCap = None
    COCO = None


@METRICS.register_module()
class COCOCaption(BaseMetric):
    """Coco Caption evaluation wrapper.

    Save the generated captions and transform into coco format.
    Calling COCO API for caption metrics.

    Args:
        ann_file (str): the path for the COCO format caption ground truth
            json file, load for evaluations.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    @require('pycocoevalcap')
    def __init__(self,
                 ann_file: str,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.ann_file = ann_file

    def process(self, data_batch, data_samples):
        """Process one batch of data samples.

        For every sample, a record ``{'caption': ..., 'image_id': ...}`` is
        appended to ``self.results``; these records are used to compute the
        metrics once all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader (unused).
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each sample is read via ``.get('pred_caption')`` and
                ``.get('image_id')``; the image id is converted with ``int``.
        """

        for data_sample in data_samples:
            result = dict()

            result['caption'] = data_sample.get('pred_caption')
            result['image_id'] = int(data_sample.get('image_id'))

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        The records are de-duplicated by ``image_id``, written to a json file
        in a temporary directory, and scored against ``self.ann_file``.

        Args:
            results (list[dict]): The records collected by :meth:`process`,
                each with ``'caption'`` and ``'image_id'`` keys.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method.

        with tempfile.TemporaryDirectory() as temp_dir:

            eval_result_file = save_result(
                result=results,
                result_dir=temp_dir,
                filename='m4-caption_pred',
                remove_duplicate='image_id',
            )

            # Evaluate while still inside the `with`, because the temporary
            # directory (and the prediction file) is deleted on exit.
            coco_val = coco_caption_eval(eval_result_file, self.ann_file)

        return coco_val


def save_result(result, result_dir, filename, remove_duplicate=''):
    """Save predictions as a json file for evaluation.

    Args:
        result (list[dict]): The prediction records to save.
        result_dir (str): Directory in which the json file is written.
        filename (str): Name of the output file, without the ``.json``
            extension.
        remove_duplicate (str): If non-empty, the key used to de-duplicate
            ``result``. Only the first record for each value of
            ``res[remove_duplicate]`` is kept, and the original order is
            preserved. The values under this key must be hashable.
            Defaults to '' (every record is kept).

    Returns:
        str: The path of the written json file.
    """
    if remove_duplicate:
        # A set gives O(1) membership tests instead of O(n) for a list.
        seen = set()
        deduplicated = []
        for res in track_iter_progress(result):
            key = res[remove_duplicate]
            if key not in seen:
                seen.add(key)
                deduplicated.append(res)
        result = deduplicated

    final_result_file_url = os.path.join(result_dir, f'{filename}.json')
    print(f'result file saved to {final_result_file_url}')
    # Use a context manager so the file is flushed and closed before the
    # caller reads it back, instead of relying on garbage collection.
    with open(final_result_file_url, 'w', encoding='utf-8') as f:
        json.dump(result, f)

    return final_result_file_url


def coco_caption_eval(results_file, ann_file):
    """Score a prediction json file against a COCO caption annotation file.

    Args:
        results_file (str): Path to the json file holding the predictions.
        ann_file (str): Path to the COCO-format ground-truth json file.

    Returns:
        dict: The ``eval`` mapping of the ``COCOEvalCap`` object, from metric
        name to score. Each entry is also printed.
    """
    ground_truth = COCO(ann_file)
    predictions = ground_truth.loadRes(results_file)
    evaluator = COCOEvalCap(ground_truth, predictions)

    # Evaluate only on the image ids that appear in the predictions.
    evaluator.params['image_id'] = predictions.getImgIds()

    # The first run may take a while.
    evaluator.evaluate()

    scores = evaluator.eval
    for name, value in scores.items():
        print(f'{name}: {value:.3f}')
    return scores
