from typing import (
    List,
    Tuple,
)

from src.evaluator.metric.base_metric import BaseMetric
from src.schema import Prediction


class RecallAtKMetric(BaseMetric):
    """Recall@k of retrieved memory ids against reference memory ids.

    For every answer in a prediction, the score is the fraction of distinct
    reference memory ids that appear among the first ``k`` retrieved memory
    ids. Per-answer scores are accumulated across ``update`` calls and
    averaged by ``compute``.
    """

    def __init__(self, k: int = 5) -> None:
        """Create the metric.

        Args:
            k: Number of leading retrieved ids to consider; must be >= 1.

        Raises:
            ValueError: If ``k`` is less than 1.
        """
        # k == 0 would score every non-empty reference as 0.0, and a negative
        # k would slice from the *end* of the retrieved list; both are invalid.
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")
        self.k = k
        self.scores: List[float] = []

    def reset(self) -> None:
        """Discard all accumulated per-answer scores."""
        self.scores = []

    def _recall(self, ref_memory_ids, pred_ret_memory_ids) -> float:
        """Return recall@k for a single answer."""
        # NOTE(review): a missing retrieval list (None) is scored as perfect
        # recall. Presumably this covers methods that do not retrieve at all;
        # confirm this is the intended policy (original carried a TODO here).
        if pred_ret_memory_ids is None:
            return 1.0

        relevant = set(ref_memory_ids)
        # No reference ids means there is nothing to miss: recall is 1.0.
        if not relevant:
            return 1.0

        retrieved = set(pred_ret_memory_ids[:self.k])
        # Divide by the number of *distinct* references so that duplicate ids
        # in the reference list cannot cap recall below 1.0.
        return len(relevant & retrieved) / len(relevant)

    def update(self, prediction: Prediction) -> Tuple[Prediction, List[float]]:
        """Score every answer of ``prediction`` and accumulate the results.

        Side effects: sets ``recall_at_k`` on each predicted answer and on
        ``prediction`` itself (mean over its answers, or 0.0 when it has
        none), and appends the per-answer scores to ``self.scores``.

        Args:
            prediction: Prediction whose ``qa.answers`` are paired
                positionally with its ``pred_answers``.

        Returns:
            The same ``prediction`` object and the list of per-answer scores.

        Raises:
            ValueError: If the reference and predicted answer lists differ in
                length, or a positional pair has mismatched ``answer_id``s.
        """
        answers = list(prediction.qa.answers)
        pred_answers = list(prediction.pred_answers)

        # zip() would silently drop unmatched trailing answers; fail loudly.
        if len(answers) != len(pred_answers):
            raise ValueError(
                f"answer count mismatch: {len(answers)} reference answers vs "
                f"{len(pred_answers)} predicted answers"
            )

        curr_scores: List[float] = []
        for answer, pred_answer in zip(answers, pred_answers):
            # Explicit check rather than `assert`, which `python -O` strips.
            if answer.answer_id != pred_answer.answer_id:
                raise ValueError(
                    f"answer_id mismatch: reference {answer.answer_id!r} vs "
                    f"predicted {pred_answer.answer_id!r}"
                )

            recall_score = self._recall(answer.ref_memory_ids, pred_answer.ret_memory_ids)
            pred_answer.recall_at_k = recall_score
            curr_scores.append(recall_score)

        prediction.recall_at_k = sum(curr_scores) / len(curr_scores) if curr_scores else 0.0
        self.scores.extend(curr_scores)

        return prediction, curr_scores

    def compute(self) -> float:
        """Return the mean of all accumulated per-answer scores (0.0 if none)."""
        return sum(self.scores) / len(self.scores) if self.scores else 0.0
