from src.evaluator.metric.base_metric import BaseMetric
from src.schema import InferenceResult
from src.utils.set import iou


class IPSMetric(BaseMetric):
    """Accumulate per-answer-set IPS scores and report their mean.

    Each gold answer set is paired, in order, with the predicted answer set at
    the same position, and the pair's ``elements`` are scored with ``iou``
    (presumably intersection-over-union of the two element collections —
    confirm against ``src.utils.set.iou``).
    """

    def __init__(self) -> None:
        # One score per (gold, predicted) answer-set pair seen since the last reset.
        self.ips_list = []

    def reset(self) -> None:
        """Discard all accumulated scores."""
        self.ips_list = []

    def update(self, res: InferenceResult) -> None:
        """Score every answer set in ``res`` and record the results.

        Args:
            res: Inference result whose ``qa.answer_sets`` (gold) and
                ``pred_answer_sets`` (predicted) are compared pairwise, in order.

        Raises:
            ValueError: if the gold and predicted lists differ in length, if
                any predicted set has no generated ``elements`` (i.e. the
                inference mode did not generate them), or if a pair's
                ``answer_set_id`` values differ.
        """
        answer_sets = list(res.qa.answer_sets)
        pred_answer_sets = list(res.pred_answer_sets)

        # zip() would silently drop unmatched trailing sets and skew the average.
        if len(answer_sets) != len(pred_answer_sets):
            raise ValueError(
                f"answer set count mismatch: {len(answer_sets)} gold vs "
                f"{len(pred_answer_sets)} predicted"
            )

        # Validate every pair before recording anything, so a bad result does
        # not leave ips_list partially updated.
        scores = []
        for answer_set, pred_answer_set in zip(answer_sets, pred_answer_sets):
            if pred_answer_set.elements is None:
                raise ValueError("infer_mode must be 'both' or 'gen' to compute IPS")
            if answer_set.answer_set_id != pred_answer_set.answer_set_id:
                raise ValueError(
                    f"answer_set_id mismatch: {answer_set.answer_set_id!r} != "
                    f"{pred_answer_set.answer_set_id!r}"
                )
            scores.append(iou(answer_set.elements, pred_answer_set.elements))

        self.ips_list.extend(scores)

    def compute(self) -> float:
        """Return the mean recorded IPS score, or 0.0 if nothing was recorded."""
        if not self.ips_list:
            return 0.0
        return sum(self.ips_list) / len(self.ips_list)
