from opencompass.registry import ICL_EVALUATORS
from opencompass.utils.text_postprocessors import general_postprocess

from .icl_base_evaluator import BaseEvaluator


@ICL_EVALUATORS.register_module()
class EMEvaluator(BaseEvaluator):
    """Containment-based answer evaluator.

    Despite the class name, matching is lenient rather than strict equality.
    The prediction and every reference answer are normalized with
    ``general_postprocess``. A prediction counts as correct when any
    reference answer (normalized or original) occurs as a substring of the
    normalized prediction.
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _is_match(option: str, pred: str) -> bool:
        """Return True if reference answer ``option`` matches ``pred``.

        An empty string is a substring of every string, so an empty option
        (e.g. a reference that normalizes to nothing) would otherwise mark
        every prediction correct. It is only accepted when the prediction
        is empty as well.
        """
        if not option:
            return option == pred
        return option in pred

    def score(self, predictions, references):
        """Score predictions against their reference answers.

        Args:
            predictions: Sequence of model output strings.
            references: Sequence, parallel to ``predictions``. Each item is
                a list of acceptable answer strings. A bare string is
                treated as a single-answer list.

        Returns:
            ``{'score': float, 'details': list[dict]}``. ``score`` is the
            percentage (0-100) of predictions judged correct. Each detail
            holds the normalized ``'pred'``, the deduplicated candidate
            ``'answer'`` list, and a boolean ``'correct'``. Returns
            ``{'error': str}`` instead if the inputs differ in length or
            are empty.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }
        if not predictions:
            # Avoid ZeroDivisionError when computing the percentage below.
            return {'error': 'predictions and references are empty'}

        # Normalize every reference to a list. A bare string would otherwise
        # be iterated character by character, and list + str/tuple raises.
        references = [[ref] if isinstance(ref, str) else list(ref)
                      for ref in references]
        predictions = [
            general_postprocess(prediction) for prediction in predictions
        ]
        processed_answers = [[general_postprocess(answer) for answer in ref]
                             for ref in references]

        cnt = 0
        details = []
        for pred, ans, origin_ans in zip(predictions, processed_answers,
                                         references):
            # Deduplicate while keeping a stable order, so the logged
            # details do not depend on set/hash ordering between runs.
            answers = list(dict.fromkeys(ans + origin_ans))
            correct = any(self._is_match(option, pred) for option in answers)
            if correct:
                cnt += 1
            details.append({'pred': pred, 'answer': answers,
                            'correct': correct})

        score = cnt / len(predictions) * 100
        return {'score': score, 'details': details}
