from src.evaluator.metric.base_metric import BaseMetric
from src.schema import InferenceResult


class RecallAtKMetric(BaseMetric):
    """Mean recall@k over the answer sets of the inference results seen so far.

    For each answer set, the score is the fraction of its distinct reference
    memory ids that appear among the first ``k`` predicted reference memory
    ids. ``compute`` returns the average of these per-answer-set scores.
    """

    def __init__(self, k: int = 5) -> None:
        """Create the metric.

        Args:
            k: Cutoff applied to each predicted ``ref_memory_ids`` sequence;
                only its first ``k`` entries are considered.
        """
        self.k = k
        # One recall value per answer set scored since the last reset.
        self.scores = []

    def reset(self) -> None:
        """Discard all accumulated scores."""
        self.scores = []

    def update(self, res: InferenceResult) -> None:
        """Score every answer set of ``res`` and record the results.

        Args:
            res: Inference result. Its ``qa.answer_sets`` (references) and
                ``pred_answer_sets`` (predictions) are paired by position.

        Raises:
            ValueError: If a reference answer set and the prediction paired
                with it carry different ``answer_set_id`` values.
        """
        # NOTE: zip stops at the shorter sequence, so an unpaired trailing
        # answer set on either side contributes no score.
        for answer_set, pred_answer_set in zip(res.qa.answer_sets, res.pred_answer_sets):
            # Explicit raise rather than `assert`: asserts are removed under
            # `python -O`, which would let misaligned pairs be scored silently.
            if answer_set.answer_set_id != pred_answer_set.answer_set_id:
                raise ValueError(
                    "answer set id mismatch: "
                    f"expected {answer_set.answer_set_id!r}, "
                    f"got {pred_answer_set.answer_set_id!r}"
                )

            top_k_ids = set(pred_answer_set.ref_memory_ids[: self.k])
            # De-duplicate the references so that a repeated id is counted once
            # in both the numerator and the denominator; otherwise a reference
            # list with duplicates could never reach a recall of 1.0.
            expected_ids = set(answer_set.ref_memory_ids)

            if not expected_ids:
                # Nothing to retrieve is scored as perfect recall.
                recall_score = 1.0
            else:
                recall_score = len(expected_ids & top_k_ids) / len(expected_ids)

            self.scores.append(recall_score)

    def compute(self) -> float:
        """Return the mean of the accumulated scores, or 0.0 if there are none."""
        if not self.scores:
            return 0.0
        return sum(self.scores) / len(self.scores)
