from src.evaluator.metric.base_metric import BaseMetric
from src.evaluator.metric.locomo_reference import f1
from src.schema import InferenceResult


class F1Metric(BaseMetric):
    """Accumulate one F1 score per answer set and report their mean.

    Each call to :meth:`update` appends one score per answer set of the given
    inference result; :meth:`compute` averages everything recorded since
    construction or the last :meth:`reset`.
    """

    def __init__(self) -> None:
        # One F1 value per scored answer set, in the order they were seen.
        self.scores = []

    def reset(self) -> None:
        """Discard all accumulated scores."""
        self.scores = []

    def update(self, res: InferenceResult) -> None:
        """Score every answer set of ``res`` and record the results.

        Reference answer sets (``res.qa.answer_sets``) are paired positionally
        with the predicted ones (``res.pred_answer_sets``). A predicted set
        with no elements is scored against the empty string. Nothing is
        recorded unless the whole result validates, so a rejected result
        never leaves partial scores behind.

        Args:
            res: Inference result holding the reference and predicted answer
                sets. Both are assumed to be sized sequences.

        Raises:
            AssertionError: If the two lists differ in length, a predicted
                set has ``elements`` set to ``None``, paired sets have
                different ``answer_set_id`` values, or a reference set does
                not contain exactly one element.
        """
        answer_sets = res.qa.answer_sets
        pred_answer_sets = res.pred_answer_sets

        # zip() stops at the shorter input, so a length mismatch would
        # silently drop answer sets from the average; reject it up front.
        assert len(answer_sets) == len(pred_answer_sets), (
            f"expected {len(answer_sets)} predicted answer sets, "
            f"got {len(pred_answer_sets)}"
        )

        # Collect locally and commit at the end so a failed assertion
        # part-way through does not leave partial scores in self.scores.
        new_scores = []
        for answer_set, pred_answer_set in zip(answer_sets, pred_answer_sets):
            # Checked for every predicted set, not only the first one.
            assert pred_answer_set.elements is not None, "infer_mode must be 'both' or 'gen' to compute F1"
            assert answer_set.answer_set_id == pred_answer_set.answer_set_id
            assert len(answer_set.elements) == 1

            reference = next(iter(answer_set.elements))
            # Only the first predicted element is scored; an empty
            # prediction falls back to "".
            # NOTE(review): if `elements` is an unordered collection, which
            # element comes first is arbitrary when there are several.
            prediction = next(iter(pred_answer_set.elements), "")
            new_scores.append(f1(reference, prediction))

        self.scores.extend(new_scores)

    def compute(self) -> float:
        """Return the mean of the recorded scores, or 0.0 if there are none."""
        if not self.scores:
            return 0.0
        return sum(self.scores) / len(self.scores)
