from typing import (
    List,
    Tuple,
)

from src.evaluator.metric.base_metric import BaseMetric
from src.schema import Prediction


class IPSMetric(BaseMetric):
    """Accumulate per-answer IPS scores and report their mean.

    The IPS of a single answer is the intersection-over-union between two
    vectors aligned on the keys of the predicted
    ``rubric_question_answer_dict``: a one-hot vector marking the key equal
    to the reference answer's ``rubric_question``, and the vector of the
    dict's values.
    """

    def __init__(self) -> None:
        # Per-answer scores collected by update() since construction or the
        # last reset().
        self.scores: List[float] = []

    def reset(self) -> None:
        """Discard every score accumulated so far."""
        self.scores = []

    def _compute_iou(self, answer_vector: List[int], pred_answer_vector: List[int]) -> float:
        """Return the intersection-over-union of two equally indexed vectors.

        Elements are combined pairwise with bitwise ``&`` (intersection) and
        ``|`` (union) and the results are summed.

        Args:
            answer_vector: Reference indicator vector.
            pred_answer_vector: Predicted indicator vector.

        Returns:
            ``intersection / union``, or ``1.0`` when the union is zero.
        """
        intersection = 0
        union = 0
        for expected, predicted in zip(answer_vector, pred_answer_vector):
            intersection += expected & predicted
            union += expected | predicted

        # TODO: check here -- an empty union (both vectors empty or all zero)
        # is currently scored as a perfect match.
        if union == 0:
            return 1.0
        return intersection / union

    def update(self, prediction: Prediction) -> Tuple[Prediction, List[float]]:
        """Score every answer of ``prediction`` and record the results.

        Sets ``ips`` on each predicted answer, sets ``prediction.ips`` to the
        mean of those scores (``0.0`` when there are none), and appends the
        per-answer scores to ``self.scores``.

        Args:
            prediction: Object exposing ``qa.answers`` and ``pred_answers``,
                which are paired positionally.

        Returns:
            The same ``prediction`` object (mutated in place) and the list of
            per-answer scores computed by this call.

        Raises:
            AssertionError: If a paired reference answer and predicted answer
                carry different ``answer_id`` values.
        """
        curr_scores: List[float] = []
        # NOTE(review): zip() stops at the shorter sequence, so a length
        # mismatch between the two lists is not detected here -- confirm the
        # caller guarantees equal lengths.
        for answer, pred_answer in zip(prediction.qa.answers, prediction.pred_answers):
            # Raised explicitly rather than via `assert` so the alignment
            # check is not stripped when Python runs with -O.
            if answer.answer_id != pred_answer.answer_id:
                raise AssertionError(
                    f"answer_id mismatch: {answer.answer_id!r} != {pred_answer.answer_id!r}"
                )

            rubric_question = answer.rubric_question
            rubric_answers = pred_answer.rubric_question_answer_dict
            # One-hot over the dict's keys: 1 where the key equals the
            # reference rubric question, 0 elsewhere.
            answer_vector = [1 if rubric_question == key else 0 for key in rubric_answers]
            pred_answer_vector = list(rubric_answers.values())
            ips = self._compute_iou(answer_vector, pred_answer_vector)

            pred_answer.ips = ips
            curr_scores.append(ips)

        prediction.ips = sum(curr_scores) / len(curr_scores) if curr_scores else 0.0
        self.scores.extend(curr_scores)

        return prediction, curr_scores

    def compute(self) -> float:
        """Return the mean of all accumulated scores, or ``0.0`` if none."""
        if not self.scores:
            return 0.0
        return sum(self.scores) / len(self.scores)
