from typing import (
    List,
    Tuple,
)

from src.evaluator.metric.base_metric import BaseMetric
from src.schema import Prediction


class EMMetric(BaseMetric):
    """Accumulate per-answer exact-match (EM) scores and report their mean.

    Each call to :meth:`update` scores one prediction, writes the scores back
    onto the prediction objects, and appends them to ``self.scores``.
    :meth:`compute` returns the mean over everything accumulated so far.
    """

    def __init__(self) -> None:
        # Flat list of every per-answer score seen since construction or the
        # last reset(); one entry per scored answer, not per prediction.
        self.scores: List[float] = []

    def reset(self) -> None:
        """Discard all accumulated scores."""
        self.scores = []

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Return the arithmetic mean of *values*, or ``0.0`` when it is empty."""
        return sum(values) / len(values) if values else 0.0

    def update(self, prediction: Prediction) -> Tuple[Prediction, List[float]]:
        """Score one prediction and fold its scores into the running total.

        Reference answers (``prediction.qa.answers``) and predicted answers
        (``prediction.pred_answers``) are paired by position. For each pair
        the score is the value stored in the predicted answer's
        ``rubric_question_answer_dict`` under the reference answer's
        ``rubric_question``. That value is written to ``pred_answer.em``, and
        the mean over all pairs (``0.0`` if there are none) is written to
        ``prediction.em``.

        Args:
            prediction: The prediction to score; mutated in place.

        Returns:
            The same ``prediction`` object and the list of per-answer scores
            computed by this call.

        Raises:
            AssertionError: If a paired reference answer and predicted answer
                have different ``answer_id`` values.

        Lookup failures on ``rubric_question_answer_dict`` (e.g. ``KeyError``
        for a plain dict) propagate to the caller unchanged.
        """
        answers = prediction.qa.answers
        pred_answers = prediction.pred_answers

        curr_scores: List[float] = []
        # NOTE(review): zip() stops at the shorter sequence, so surplus
        # answers or predictions are silently left unscored -- confirm that
        # this is intended.
        for answer, pred_answer in zip(answers, pred_answers):
            # Raised explicitly instead of using an `assert` statement so the
            # alignment check still runs under `python -O`. AssertionError is
            # kept as the exception type so existing handlers keep working.
            if answer.answer_id != pred_answer.answer_id:
                raise AssertionError(
                    f"answer_id mismatch: expected {answer.answer_id!r}, "
                    f"got {pred_answer.answer_id!r}"
                )

            # presumably the stored value is numeric (it is summed and
            # averaged below) -- verify against the producer of this dict.
            em = pred_answer.rubric_question_answer_dict[answer.rubric_question]

            pred_answer.em = em
            curr_scores.append(em)

        prediction.em = self._mean(curr_scores)
        self.scores.extend(curr_scores)

        return prediction, curr_scores

    def compute(self) -> float:
        """Return the mean of all accumulated scores, or ``0.0`` if none exist."""
        return self._mean(self.scores)
