from src.evaluator.metric.base_metric import BaseMetric
from src.evaluator.metric.locomo_reference import bert_score
from src.schema import InferenceResult


class EMMetric(BaseMetric):
    """Exact-match (EM) metric accumulated over inference results.

    Every scored answer set contributes 1.0 when its ground-truth
    ``elements`` compare equal to the predicted ``elements`` and 0.0
    otherwise. ``compute`` returns the mean of all contributions collected
    since construction or the last ``reset``.
    """

    def __init__(self) -> None:
        # One entry (1.0 or 0.0) per answer-set pair scored so far.
        self.scores: list[float] = []

    def reset(self) -> None:
        """Discard every accumulated score."""
        self.scores = []

    def update(self, res: InferenceResult) -> None:
        """Score each (ground truth, prediction) answer-set pair of ``res``.

        Pairs are formed positionally from ``res.qa.answer_sets`` and
        ``res.pred_answer_sets``; pairing stops at the shorter of the two.
        A result without any predicted answer sets contributes nothing.

        Args:
            res: Inference result carrying the ground-truth answer sets
                (``res.qa.answer_sets``) and the predicted ones
                (``res.pred_answer_sets``).

        Raises:
            AssertionError: If the first predicted answer set has no
                ``elements`` (nothing was generated to compare), or if a
                paired ground-truth/predicted answer set disagree on
                ``answer_set_id``.
        """
        answer_sets = res.qa.answer_sets
        pred_answer_sets = res.pred_answer_sets

        # Nothing to pair up; indexing [0] below would otherwise raise a
        # bare IndexError for a result the loop would simply skip.
        if not pred_answer_sets:
            return

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type stays the same.
        if pred_answer_sets[0].elements is None:
            raise AssertionError("infer_mode must be 'both' or 'gen' to compute EM")

        for answer_set, pred_answer_set in zip(answer_sets, pred_answer_sets):
            if answer_set.answer_set_id != pred_answer_set.answer_set_id:
                raise AssertionError(
                    "answer_set_id mismatch: "
                    f"{answer_set.answer_set_id!r} != {pred_answer_set.answer_set_id!r}"
                )

            # Strict equality only. (A bert_score-based soft match on the
            # first element of each side was previously sketched here;
            # reintroduce it deliberately if soft matching is wanted.)
            em = 1.0 if answer_set.elements == pred_answer_set.elements else 0.0
            self.scores.append(em)

    def compute(self) -> float:
        """Return the mean of the accumulated scores, or 0.0 if there are none."""
        if not self.scores:
            return 0.0
        return sum(self.scores) / len(self.scores)
