from src.evaluation.model import get_resolved_problems
from typing import Dict, List, Literal, Optional, Set


class Solver:
    """Plain record describing a solver: its ``name``, ``label`` and ``color``."""

    def __init__(self, name, label, color):
        # Stored verbatim; no validation or normalization is performed.
        self.name, self.label, self.color = name, label, color


# Default evaluator identifiers forwarded to ``get_resolved_problems``; each is
# a model name carrying the ``_STRICT_LOGIC_PROMPT`` suffix.
EVALUATION_AUTHORS = [
    f"{model}_STRICT_LOGIC_PROMPT"
    for model in (
        "gpt-4o",
        "gemini-1.5-pro",
        "claude-3-5-sonnet-20240620",
        "gpt-4-turbo",
    )
]


# Known solvers as (name, label, color) triples, in a fixed order.
SOLVERS = [
    Solver(solver_name, solver_label, solver_color)
    for solver_name, solver_label, solver_color in (
        ("gpt-4-turbo", "GPT-4 Turbo", "dimgrey"),
        ("gpt-4o", "GPT-4o", "black"),
        ("gemini-1.5-pro", "Gemini 1.5 Pro", "blue"),
        ("claude-3-5-sonnet-20240620", "Claude 3.5 Sonnet", "orange"),
        ("OpenGVLab/InternVL2-8B", "InternVL2 8B", "green"),
        ("llava-hf/llava-v1.6-mistral-7b-hf", "LlaVa 1.6 Mistral 7B", "red"),
        ("microsoft/Phi-3.5-vision-instruct", "Phi 3.5V", "purple"),
        ("mistralai/Pixtral-12B-2409", "Pixtral 12B", "brown"),
    )
]

# Prompting techniques, grouped by family. Each name is interpolated into the
# results directory name ``{dataset_alias}_generation_prompting-{technique}``.
CONTRASTIVE_TECHNIQUES = [
    "contrastive",
    "contrastive-direct",
    "iterative-contrastive",
]

DESCRIPTIVE_TECHNIQUES = [
    "descriptive",
    "descriptive-direct",
    "iterative",
]

DIRECT_TECHNIQUES = [
    "direct",
]

# Every technique, in group order. Derived from the groups above so that
# adding a technique to a group cannot leave this list out of sync.
TECHNIQUES = CONTRASTIVE_TECHNIQUES + DESCRIPTIVE_TECHNIQUES + DIRECT_TECHNIQUES


def get_all_solver_names():
    """Return a new list with the ``name`` of each entry of ``SOLVERS``, in order."""
    return [entry.name for entry in SOLVERS]


class BongardResults:
    """Per-technique sets of solved problems for one or more solvers.

    Internally this is a mapping from technique name to the set returned by
    ``get_resolved_problems`` for that technique. Instances are combined with
    :meth:`union`, which builds a new instance and never mutates its operands.
    """

    def __init__(
        self,
        technique_to_solved_problems: Optional[Dict[str, Set[int]]] = None,
    ):
        """Wrap a technique -> solved-problems mapping.

        Args:
            technique_to_solved_problems: Mapping from technique name to the set
                of solved problems. ``None`` (the default) means an empty
                mapping; a fresh dict is created per instance so argument-less
                instances never share state (the previous ``{}`` default was a
                single dict shared by every such instance).
        """
        if technique_to_solved_problems is None:
            technique_to_solved_problems = {}
        self.__technique_to_solved_problems = technique_to_solved_problems

    @staticmethod
    def for_solver(
        main_results_path: str,
        solver_name: str,
        dataset_alias: str,
        evaluation_authors: Optional[List[str]] = EVALUATION_AUTHORS,
        techniques: Optional[List[str]] = TECHNIQUES,
        evaluation_type: Literal["voting", "all models", "any model"] = "voting",
    ) -> "BongardResults":
        """Load one solver's solved problems for each requested technique.

        For every technique, the answers file
        ``{main_results_path}/{dataset_alias}_generation_prompting-{technique}/{solver_name}_answers.json``
        is passed to ``get_resolved_problems`` together with
        ``evaluation_authors`` and ``evaluation_type``.

        Args:
            main_results_path: Root directory of the results tree.
            solver_name: Solver whose answers file is read.
            dataset_alias: Dataset prefix of the per-technique directory name.
            evaluation_authors: Forwarded unchanged to ``get_resolved_problems``.
            techniques: Techniques to load; ``None`` means all of ``TECHNIQUES``.
            evaluation_type: Forwarded unchanged to ``get_resolved_problems``.

        Returns:
            A ``BongardResults`` with exactly one entry per requested technique.
        """
        # The parameter is annotated Optional; iterating None would crash.
        if techniques is None:
            techniques = TECHNIQUES

        technique_to_solved_problems: Dict[str, Set[int]] = {}
        for technique in techniques:
            path = f"{main_results_path}/{dataset_alias}_generation_prompting-{technique}/{solver_name}_answers.json"
            technique_to_solved_problems[technique] = get_resolved_problems(
                path,
                evaluation_authors,
                evaluation_type,
            )

        return BongardResults(technique_to_solved_problems)

    @staticmethod
    def for_(
        main_results_path: str,
        solver_names: List[str],
        dataset_alias: str,
        evaluation_authors: Optional[List[str]] = EVALUATION_AUTHORS,
        techniques: Optional[List[str]] = TECHNIQUES,
        evaluation_type: Literal["voting", "all models", "any model"] = "voting",
    ) -> "BongardResults":
        """Union of :meth:`for_solver` results over several solvers.

        A problem counts as solved for a technique when any listed solver's
        result for that technique contains it.

        Args:
            solver_names: Solvers to combine; may be empty.
            techniques: Techniques to load; ``None`` means all of ``TECHNIQUES``.
            Remaining arguments are forwarded to :meth:`for_solver`.

        Returns:
            A ``BongardResults`` with an entry for every requested technique,
            even when ``solver_names`` is empty (all sets are then empty).
        """
        if techniques is None:
            techniques = TECHNIQUES

        # Seed with an empty set per technique so an empty ``solver_names``
        # still yields a result that answers every technique query instead of
        # raising KeyError / TypeError in the getters.
        bongard_results = BongardResults(
            {technique: set() for technique in techniques}
        )

        for solver_name in solver_names:
            bongard_results = bongard_results.union(
                BongardResults.for_solver(
                    main_results_path,
                    solver_name,
                    dataset_alias,
                    evaluation_authors,
                    techniques,
                    evaluation_type,
                )
            )

        return bongard_results

    def union(self, other: "BongardResults") -> "BongardResults":
        """Return a new result whose per-technique sets are unions of both inputs.

        A technique present in only one operand maps to a copy of that
        operand's set. Neither operand is modified.
        """
        mine = self.__technique_to_solved_problems
        theirs = other.__technique_to_solved_problems
        return BongardResults(
            {
                technique: mine.get(technique, set()) | theirs.get(technique, set())
                # Keys views support ``|`` directly and already yield a set.
                for technique in mine.keys() | theirs.keys()
            }
        )

    def get_solved_problems(self, technique: str) -> Set[int]:
        """Return the stored set for ``technique``; raises KeyError if absent."""
        return self.__technique_to_solved_problems[technique]

    def _solved_by_any_of(self, techniques: List[str]) -> Set[int]:
        """Return a new set uniting the stored sets of ``techniques``.

        Raises KeyError if any technique is absent. Starting from ``set()``
        (rather than ``set.union(*...)``) makes an empty selection return an
        empty set instead of raising TypeError.
        """
        return set().union(
            *[self.__technique_to_solved_problems[technique] for technique in techniques]
        )

    def get_all_solved_by_any_technique(self) -> Set[int]:
        """Return the problems solved by at least one stored technique."""
        # ``set().union`` tolerates an empty mapping (e.g. ``BongardResults()``).
        return set().union(*self.__technique_to_solved_problems.values())

    def get_all_solved_by_contrastive_techniques(self) -> Set[int]:
        """Return the problems solved by any of ``CONTRASTIVE_TECHNIQUES``."""
        return self._solved_by_any_of(CONTRASTIVE_TECHNIQUES)

    def get_all_solved_by_descriptive_techniques(self) -> Set[int]:
        """Return the problems solved by any of ``DESCRIPTIVE_TECHNIQUES``."""
        return self._solved_by_any_of(DESCRIPTIVE_TECHNIQUES)

    def get_all_solved_by_direct_techniques(self) -> Set[int]:
        """Return the problems solved by any of ``DIRECT_TECHNIQUES``."""
        return self._solved_by_any_of(DIRECT_TECHNIQUES)
