# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


@METRICS.register_module()
class ANLS(BaseMetric):
    """ANLS metric.

    Compute the Average Normalized Levenshtein Similarity (ANLS).

    For each sample, the predicted answer and every ground-truth answer are
    normalized: leading/trailing whitespace is stripped, text is lower-cased
    and internal whitespace is collapsed to single spaces. The normalized
    Levenshtein similarity between the prediction and one ground-truth answer
    is ``1 - dist / max(len(gt), len(pred))``, computed on the normalized
    strings (1.0 when both are empty). The sample score is the best
    similarity over all ground-truth answers, set to 0 when it is below
    ``threshold``. ANLS is the mean of the sample scores.

    Args:
        threshold (float): ANLS threshold used for determining if the answer
            has been correctly selected but not properly recognized,
            or on the contrary, the output is a wrong text selected from the
            options and given as an answer. Sample scores below this value
            are set to 0. Defaults to 0.5.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            ('ANLS') will be used instead. Defaults to None.
    """
    default_prefix = 'ANLS'

    def __init__(self,
                 threshold: float = 0.5,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.threshold = threshold

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        The processed results are stored in ``self.results``, which will be
        used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader (unused).
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each sample is read via ``.get('pred_answer')`` and
                ``.get('gt_answer')``; missing keys are stored as ``None``.
        """
        for sample in data_samples:
            self.results.append({
                'pred_answer': sample.get('pred_answer'),
                'gt_answer': sample.get('gt_answer'),
            })

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List[dict]): The processed results of each batch, each
                holding a ``'pred_answer'`` and a ``'gt_answer'`` entry.

        Returns:
            Dict: ``{'ANLS': score}``, the mean per-sample score. The score is
            0.0 when ``results`` is empty.
        """
        if not results:
            # Nothing was evaluated; avoid dividing by zero.
            return {'ANLS': 0.}

        total_score = sum(
            self._sample_score(result['pred_answer'], result['gt_answer'])
            for result in results)
        return {'ANLS': total_score / len(results)}

    @staticmethod
    def _normalize(text) -> str:
        """Strip, lower-case and collapse internal whitespace of ``text``."""
        if text is None:
            return ''
        return ' '.join(text.strip().lower().split())

    def _sample_score(self, pred_answer, gt_answers) -> float:
        """Return the thresholded best similarity of one prediction."""
        pred = self._normalize(pred_answer)
        if gt_answers is None:
            gt_answers = []
        elif isinstance(gt_answers, str):
            # A single answer given as a bare string must not be iterated
            # character by character.
            gt_answers = [gt_answers]

        # With no ground-truth answers the best similarity stays 0.
        best = 0.
        for gt_answer in gt_answers:
            gt = self._normalize(gt_answer)
            # Use the lengths of the same normalized strings the distance is
            # computed on, so the similarity always lies in [0, 1].
            length = max(len(gt), len(pred))
            if length == 0:
                similarity = 1.
            else:
                similarity = 1. - levenshtein_distance(gt, pred) / length
            best = max(best, similarity)

        return best if best >= self.threshold else 0.


def levenshtein_distance(s1, s2):
    """Return the Levenshtein (edit) distance between two sequences.

    Uses the two-row dynamic-programming formulation. It iterates over the
    longer sequence and keeps one row per step, sized by the shorter one.

    Args:
        s1 (str): First sequence.
        s2 (str): Second sequence.

    Returns:
        int: Minimum number of single-element insertions, deletions and
        substitutions needed to turn one sequence into the other.
    """
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)

    # Row 0: distance from the empty prefix of ``longer`` to each prefix of
    # ``shorter``.
    prev_row = list(range(len(shorter) + 1))
    for row_idx, long_ch in enumerate(longer, start=1):
        cur_row = [row_idx]
        for col_idx, short_ch in enumerate(shorter):
            if short_ch == long_ch:
                cost = prev_row[col_idx]
            else:
                # Diagonal (substitute), above and left (insert / delete).
                cost = 1 + min(prev_row[col_idx], prev_row[col_idx + 1],
                               cur_row[-1])
            cur_row.append(cost)
        prev_row = cur_row
    return prev_row[-1]
