from typing import (
    Dict,
    List,
    Set,
    Tuple,
    Union,
)

from src.evaluator.metric.base_metric import BaseMetric
from src.schema import (
    AnswerSet,
    GPSReport,
    InferenceResult,
    PredictedAnswerSet,
)
from src.utils.set import (
    FULL_SET_ELEMENT,
    iou,
    union,
    union_of_sets,
)


class GPSMetric(BaseMetric):
    """Accumulate per-question GPS scores and report their averages.

    For every result passed to :meth:`update`, users are grouped by the answer
    set they gave. There is one group per element of the union of the gold
    answer sets: the users whose answer set is exactly ``{element}``. There is
    also one group keyed by ``FULL_SET_ELEMENT``: the users whose answer set
    equals that whole union. The gold grouping and the predicted grouping are
    compared group-by-group with ``iou`` (from ``src.utils.set``).

    Two running lists are kept:

    * ``gps_list`` -- per question, the mean ``iou`` over all groups.
    * ``gps_u_list`` -- per question, the ``iou`` of the ``FULL_SET_ELEMENT``
      group only.
    """

    def __init__(self) -> None:
        # Per-question scores appended by ``update`` and averaged by ``compute``.
        self.gps_list: List[float] = []
        self.gps_u_list: List[float] = []

    def reset(self) -> None:
        """Discard all scores accumulated so far."""
        self.gps_list = []
        self.gps_u_list = []

    def update(self, res: InferenceResult) -> None:
        """Score one inference result and append its scores.

        Args:
            res: Result carrying the gold answer sets (``res.qa.answer_sets``)
                and the predicted answer sets (``res.pred_answer_sets``).
        """
        answer_sets = res.qa.answer_sets
        pred_answer_sets = res.pred_answer_sets

        # Groups are defined by the gold answers only. A predicted answer set
        # containing an element outside this union matches no group.
        full_answer_set = union_of_sets([answer_set.elements for answer_set in answer_sets])

        answer_user_group = self._get_answer_user_group(
            answer_sets=answer_sets,
            full_answer_set=full_answer_set,
        )
        pred_answer_user_group = self._get_answer_user_group(
            answer_sets=pred_answer_sets,
            full_answer_set=full_answer_set,
        )

        gps_dict = {
            element: iou(answer_user_group[element], pred_answer_user_group[element])
            for element in union(full_answer_set, {FULL_SET_ELEMENT})
        }

        # ``gps_dict`` always has the FULL_SET_ELEMENT key (assuming ``union``
        # returns the set union), so the mean below never divides by zero.
        gps_values = list(gps_dict.values())
        self.gps_list.append(sum(gps_values) / len(gps_values))
        self.gps_u_list.append(gps_dict[FULL_SET_ELEMENT])

    def compute(self) -> GPSReport:
        """Return the averages of the recorded scores.

        Returns:
            A ``GPSReport`` whose ``gps`` is the mean of ``gps_list`` and whose
            ``gps_u`` is the mean of ``gps_u_list``. Both are ``0.0`` when no
            result has been recorded yet.
        """
        if not self.gps_list or not self.gps_u_list:
            # Return the same type as the populated case, so callers always
            # receive a ``GPSReport`` rather than sometimes a bare tuple.
            return GPSReport(gps=0.0, gps_u=0.0)

        avg_gps = sum(self.gps_list) / len(self.gps_list)
        avg_gps_u = sum(self.gps_u_list) / len(self.gps_u_list)
        return GPSReport(gps=avg_gps, gps_u=avg_gps_u)

    def _get_answer_user_group(
        self,
        answer_sets: List[Union[AnswerSet, PredictedAnswerSet]],
        full_answer_set: Set[str],
    ) -> Dict[str, Set[str]]:
        """Group user ids by the answer set they gave.

        Args:
            answer_sets: Gold or predicted answer sets. Each item provides
                ``elements`` and ``user_id``.
            full_answer_set: Union of all gold answer elements for the question.

        Returns:
            A mapping from every element of ``full_answer_set``, plus
            ``FULL_SET_ELEMENT``, to a set of user ids.

            * A user goes under ``element`` when their ``elements`` equals
              ``{element}``.
            * A user goes under ``FULL_SET_ELEMENT`` when their ``elements``
              equals ``full_answer_set``.
            * Any other answer set is not placed in any group.

            If ``full_answer_set`` has a single element, a matching user lands
            in both that element's group and the ``FULL_SET_ELEMENT`` group.
        """
        answer_group = {element: set() for element in union(full_answer_set, {FULL_SET_ELEMENT})}

        for answer_set in answer_sets:
            elements = answer_set.elements
            user_id = answer_set.user_id

            for curr_element in full_answer_set:
                if elements == {curr_element}:
                    answer_group[curr_element].add(user_id)
                    # Distinct singletons are never equal, so at most one
                    # element can match; stop scanning.
                    break

            if elements == full_answer_set:
                answer_group[FULL_SET_ELEMENT].add(user_id)

        return answer_group
