from typing import (
    List,
    Dict,
    Optional,
)

from enum import Enum
from pydantic import BaseModel

from src.schema import QA


class InstFailureMode(str, Enum):
    """String-valued labels stored in ``PredictedAnswer.inst_failure_mode``.

    The class mixes in ``str``, so every member is itself a string equal to
    its value (e.g. ``InstFailureMode.CORRECT == "correct"``). The same enum
    is the key type of ``EvaluationReport.inst_failure_mode_counts``.

    NOTE(review): the criteria that assign each label are not defined in
    this file; "inst" presumably stands for "instance" -- confirm against
    the evaluation code.
    """

    # The only label that does not carry a "wrong_" prefix.
    CORRECT = "correct"
    # "wrong_non_zero_*" labels.
    WRONG_NON_ZERO_CATCH_ALL = "wrong_non_zero_catch_all"
    WRONG_NON_ZERO_CATCH_MULT = "wrong_non_zero_catch_mult"
    # "wrong_zero_*" labels.
    WRONG_ZERO_MAJORITY_GUESS = "wrong_zero_majority_guess"
    WRONG_ZERO_OTHERS = "wrong_zero_others"


class DistFailureMode(str, Enum):
    """String-valued labels stored in ``Prediction.dist_failure_mode``.

    The class mixes in ``str``, so every member is itself a string equal to
    its value (e.g. ``DistFailureMode.CORRECT == "correct"``). The same enum
    is the key type of ``EvaluationReport.dist_failure_mode_counts``.

    NOTE(review): the criteria that assign each label are not defined in
    this file; "dist" presumably stands for "distribution" -- confirm
    against the evaluation code.
    """

    # The only label that does not carry a "wrong_" prefix.
    CORRECT = "correct"
    # Single "wrong_non_zero" label (no sub-categories, unlike InstFailureMode).
    WRONG_NON_ZERO = "wrong_non_zero"
    # "wrong_zero_*" labels.
    WRONG_ZERO_MAJORITY_GUESS = "wrong_zero_majority_guess"
    WRONG_ZERO_CATCH_ALL = "wrong_zero_catch_all"
    WRONG_ZERO_NO_DIFF = "wrong_zero_no_diff"
    WRONG_ZERO_OTHERS = "wrong_zero_others"


class PredictedAnswer(BaseModel):
    """Pydantic record for a single predicted answer.

    ``answer_id`` is the only required field; every other field is optional
    and defaults to ``None``.

    The trailing ``# ret`` / ``# gen`` / ``# eval`` tags group the fields.
    NOTE(review): they presumably name the pipeline stage that fills the
    field in (retrieval, generation, evaluation) -- confirm with the code
    that populates this model.

    NOTE(review): the metric abbreviations (``em``, ``ips``, ``dps``) are not
    expanded anywhere in this file; ``em`` is presumably "exact match".
    """

    answer_id: str  # required identifier of this answer
    # Three document lists (ids, contents, scores); presumably index-aligned
    # with each other -- TODO confirm, nothing here enforces equal lengths.
    ret_document_ids: Optional[List[str]] = None  # ret
    ret_document_contents: Optional[List[str]] = None  # ret
    ret_document_scores: Optional[List[float]] = None  # ret
    # Same three-list layout for memories.
    ret_memory_ids: Optional[List[str]] = None  # ret
    ret_memory_contents: Optional[List[str]] = None  # ret
    ret_memory_scores: Optional[List[float]] = None  # ret
    gen_sub_queries: Optional[List[str]] = None  # ret
    # Mapping from a string key to a list of strings.
    gen_pers_graph: Optional[Dict[str, List[str]]] = None  # ret
    raw_answer: Optional[str] = None  # gen
    # Mapping from a string key to an integer value.
    rubric_question_answer_dict: Optional[Dict[str, int]] = None  # eval
    recall_at_k: Optional[float] = None  # eval
    em: Optional[float] = None  # eval
    ips: Optional[float] = None  # eval
    dps: Optional[float] = None  # eval
    inst_failure_mode: Optional[InstFailureMode] = None  # eval
    # Untyped dictionary for any additional data.
    metadata: Optional[Dict] = None


class Prediction(BaseModel):
    """Pydantic record pairing one ``QA`` item with its predicted answers.

    ``qa`` and ``pred_answers`` are required; the metric fields,
    ``dist_failure_mode`` and ``metadata`` are optional and default to
    ``None``.

    NOTE(review): the metric fields share their names with those on
    ``PredictedAnswer`` and are presumably aggregates over ``pred_answers``
    -- confirm with the evaluation code.
    """

    qa: QA  # imported from src.schema
    pred_answers: List[PredictedAnswer]
    recall_at_k: Optional[float] = None
    em: Optional[float] = None
    ips: Optional[float] = None
    dps: Optional[float] = None
    dist_failure_mode: Optional[DistFailureMode] = None
    # Untyped dictionary for any additional data.
    metadata: Optional[Dict] = None


class EvaluationReport(BaseModel):
    """Pydantic record holding a list of ``Prediction`` objects plus summary fields.

    ``predictions`` is required; every other field is optional and defaults
    to ``None``.

    NOTE(review): the metric fields and the two ``*_failure_mode_counts``
    dictionaries are presumably computed over ``predictions`` -- confirm
    with the code that builds the report.
    """

    predictions: List[Prediction]
    recall_at_k: Optional[float] = None
    em: Optional[float] = None
    ips: Optional[float] = None
    dps: Optional[float] = None
    # Integer count per failure-mode label, keyed by the enum member.
    inst_failure_mode_counts: Optional[Dict[InstFailureMode, int]] = None
    dist_failure_mode_counts: Optional[Dict[DistFailureMode, int]] = None
    # Untyped dictionary for any additional data.
    metadata: Optional[Dict] = None


class EvaluationReportGroup(BaseModel):
    """Pydantic record bundling four ``EvaluationReport`` objects.

    All four report fields are required; only ``metadata`` is optional.

    NOTE(review): the field names suggest one overall report plus one per
    domain (medical, legal, casual) -- confirm how the split is made.
    """

    total_eval_report: EvaluationReport
    medical_eval_report: EvaluationReport
    legal_eval_report: EvaluationReport
    casual_eval_report: EvaluationReport
    # Untyped dictionary for any additional data.
    metadata: Optional[Dict] = None
