from typing import (
    List,
    Tuple,
)
from collections import Counter

from src.evaluator.metric.base_metric import BaseMetric
from src.schema import (
    Prediction,
    InstFailureMode,
    DistFailureMode,
)


class DPSMetric(BaseMetric):
    """Accumulate a DPS score per ``Prediction`` and report the running mean.

    For every prediction passed to :meth:`update`:

    1. Each predicted answer is turned into a vector made of the values of its
       ``rubric_question_answer_dict`` (assumed to be 0/1 integers - the
       bitwise ``&`` / ``|`` in :meth:`_compute_iou` rely on that).
    2. Each predicted vector is scored against a one-hot ground-truth vector
       with intersection-over-union; this per-answer score is the "IPS".
    3. The IPS values are combined into one DPS as a weighted sum, where the
       weight of an answer is its squared distance from the mean predicted
       vector, normalised so the weights add up to 1.

    The DPS of every processed prediction is kept in ``self.scores``.
    """

    def __init__(self) -> None:
        """Create a metric with no recorded scores."""
        # NOTE(review): BaseMetric.__init__ is not invoked here; confirm the
        # base class needs no initialisation of its own.
        self.scores: List[float] = []

    def reset(self) -> None:
        """Forget every score recorded so far."""
        self.scores = []

    def _compute_iou(self, answer_vector: List[int], pred_answer_vector: List[int]) -> float:
        """Return intersection-over-union of two 0/1 vectors.

        Args:
            answer_vector: Ground-truth vector.
            pred_answer_vector: Predicted vector, compared position by position.

        Returns:
            ``|answer AND pred| / |answer OR pred|``.

        Raises:
            AssertionError: If neither vector has any bit set, which leaves the
                ratio undefined.
        """
        i = sum(a & b for a, b in zip(answer_vector, pred_answer_vector))
        u = sum(a | b for a, b in zip(answer_vector, pred_answer_vector))
        # Raised explicitly instead of via ``assert`` so the guard survives
        # ``python -O``; otherwise the division below would fail with a bare
        # ZeroDivisionError. The exception type is kept for existing callers.
        if not u > 0:
            raise AssertionError(
                "Union cardinality became zero, which should be impossible under one-hot ground truth."
            )
        return i / u

    def _compute_weight(self, pred_answer_vector: List[int], mean_pred_answer_vector: List[float]) -> float:
        """Return the squared Euclidean distance between a vector and the mean.

        Args:
            pred_answer_vector: One predicted answer vector.
            mean_pred_answer_vector: Element-wise mean of all predicted vectors
                of the same prediction.

        Returns:
            Sum of squared element-wise differences; 0 when the vector equals
            the mean.
        """
        # TODO: check it (original author's note on this weighting scheme).
        return sum((a - b) ** 2 for a, b in zip(pred_answer_vector, mean_pred_answer_vector))

    def _get_inst_failure_mode(
        self,
        pred_answer_vector: List[int],
        majority_answer_vector: List[int],
        ips: float,
    ) -> InstFailureMode:
        """Classify a single predicted answer from its IPS.

        Args:
            pred_answer_vector: The predicted vector that was scored.
            majority_answer_vector: Most frequent predicted vector within the
                same prediction.
            ips: IoU score of ``pred_answer_vector``.

        Returns:
            ``CORRECT`` for an IPS of exactly 1; a ``WRONG_NON_ZERO_*`` mode for
            a partial overlap (``CATCH_ALL`` when every entry is set, otherwise
            ``CATCH_MULT``); a ``WRONG_ZERO_*`` mode for no overlap
            (``MAJORITY_GUESS`` when the vector equals the majority vector,
            otherwise ``OTHERS``).

        Raises:
            ValueError: If ``ips`` lies outside ``[0, 1]`` (or is NaN).
        """
        if ips == 1.0:
            return InstFailureMode.CORRECT
        if 0.0 < ips < 1.0:
            # The sum equals the length only when every entry is 1.
            if sum(pred_answer_vector) == len(pred_answer_vector):
                return InstFailureMode.WRONG_NON_ZERO_CATCH_ALL
            return InstFailureMode.WRONG_NON_ZERO_CATCH_MULT
        if ips == 0:
            if pred_answer_vector == majority_answer_vector:
                return InstFailureMode.WRONG_ZERO_MAJORITY_GUESS
            return InstFailureMode.WRONG_ZERO_OTHERS
        raise ValueError(f"Unexpected IPS value: {ips}")

    def _get_dist_failure_mode(
        self,
        pred_answer_vectors: List[List[int]],
        majority_answer_vector: List[int],
        dps: float,
    ) -> DistFailureMode:
        """Classify a whole prediction from its DPS.

        Args:
            pred_answer_vectors: Every predicted vector of the prediction.
            majority_answer_vector: Most frequent vector among them.
            dps: Weighted sum of the per-answer IPS values.

        Returns:
            ``CORRECT`` when ``dps`` is 1 (within 1e-12), ``WRONG_NON_ZERO``
            when strictly between 0 and 1, otherwise one of the ``WRONG_ZERO_*``
            modes depending on how the predicted vectors relate to each other.

        Raises:
            ValueError: If ``dps`` lies outside the expected ``[0, 1]`` range.
        """
        # DPS is a sum of floating-point products, so a perfect score can be
        # off from 1.0 by rounding error; compare with a small tolerance.
        if abs(dps - 1.0) < 1e-12:
            return DistFailureMode.CORRECT
        if 0.0 < dps < 1.0:
            return DistFailureMode.WRONG_NON_ZERO
        if dps == 0:
            # More than one distinct vector: no single shared pattern to name.
            if len(set(tuple(pav) for pav in pred_answer_vectors)) != 1:
                return DistFailureMode.WRONG_ZERO_OTHERS
            if all(pav == majority_answer_vector for pav in pred_answer_vectors):
                return DistFailureMode.WRONG_ZERO_MAJORITY_GUESS
            if all(sum(pav) == len(pav) for pav in pred_answer_vectors):
                return DistFailureMode.WRONG_ZERO_CATCH_ALL
            return DistFailureMode.WRONG_ZERO_NO_DIFF
        raise ValueError(f"Unexpected DPS value: {dps}")

    def update(self, prediction: Prediction) -> Tuple[Prediction, List[float]]:
        """Score one prediction, annotate it in place and record its DPS.

        Sets ``dps`` and ``inst_failure_mode`` on every element of
        ``prediction.pred_answers``, sets ``dps`` and ``dist_failure_mode`` on
        ``prediction`` itself, and appends the DPS to ``self.scores``.

        Args:
            prediction: Object exposing ``qa.answers`` (ground truth) and
                ``pred_answers`` (model output), aligned by position.

        Returns:
            The same ``prediction`` object and a one-element list holding its
            DPS.

        Raises:
            ValueError: If there are no predicted answers, or if the number of
                ground-truth and predicted answers differs.
            AssertionError: If a ground-truth answer and the predicted answer at
                the same position carry different ``answer_id`` values.
        """
        # Materialise both sides so they can be measured and walked repeatedly.
        answers = list(prediction.qa.answers)
        pred_answers = list(prediction.pred_answers)

        # Without this guard the majority-vector lookup below fails with an
        # opaque IndexError.
        if not pred_answers:
            raise ValueError("Cannot compute DPS for a prediction without predicted answers.")
        # zip() would silently drop the surplus items; the weights are
        # normalised over all predicted answers, so truncation would quietly
        # distort the score.
        if len(answers) != len(pred_answers):
            raise ValueError(
                f"Number of answers ({len(answers)}) does not match "
                f"number of predicted answers ({len(pred_answers)})."
            )

        pred_answer_vectors = [
            list(pred_answer.rubric_question_answer_dict.values())
            for pred_answer in pred_answers
        ]
        # Column-wise mean across all predicted vectors.
        mean_pred_answer_vector = [sum(x) / len(x) for x in zip(*pred_answer_vectors)]
        majority_answer_vector = list(
            Counter(tuple(x) for x in pred_answer_vectors).most_common(1)[0][0]
        )

        weights = [
            self._compute_weight(
                pred_answer_vector=pred_answer_vector,
                mean_pred_answer_vector=mean_pred_answer_vector,
            )
            for pred_answer_vector in pred_answer_vectors
        ]
        sum_raw = sum(weights)
        if sum_raw <= 0.0:
            # Every vector coincides with the mean: fall back to uniform weights.
            norm_weights = [1.0 / len(weights) for _ in weights]
        else:
            norm_weights = [w / sum_raw for w in weights]

        curr_scores = []
        for answer, pred_answer, pred_answer_vector, norm_weight in zip(
            answers, pred_answers, pred_answer_vectors, norm_weights
        ):
            # Explicit raise (not ``assert``) so the alignment check is kept
            # under ``python -O``; the exception type is unchanged.
            if answer.answer_id != pred_answer.answer_id:
                raise AssertionError(
                    f"Answer id mismatch: {answer.answer_id} != {pred_answer.answer_id}"
                )

            rubric_question = answer.rubric_question
            # One-hot ground truth over the rubric questions, in dict key order
            # so it lines up with the predicted vector.
            answer_vector = [
                1 if rubric_question == k else 0
                for k in pred_answer.rubric_question_answer_dict
            ]
            ips = self._compute_iou(
                answer_vector=answer_vector,
                pred_answer_vector=pred_answer_vector,
            )
            weighted_ips = ips * norm_weight

            pred_answer.dps = weighted_ips
            pred_answer.inst_failure_mode = self._get_inst_failure_mode(
                pred_answer_vector=pred_answer_vector,
                majority_answer_vector=majority_answer_vector,
                ips=ips,
            )
            curr_scores.append(weighted_ips)

        dps = sum(curr_scores)
        prediction.dps = dps
        prediction.dist_failure_mode = self._get_dist_failure_mode(
            pred_answer_vectors=pred_answer_vectors,
            majority_answer_vector=majority_answer_vector,
            dps=dps,
        )
        self.scores.append(dps)

        return prediction, [dps]

    def compute(self) -> float:
        """Return the mean of all recorded DPS values, or 0.0 if there are none."""
        if not self.scores:
            return 0.0
        return sum(self.scores) / len(self.scores)
