import json
import os
import re
from pathlib import Path
from typing import Dict, Any, List, Optional, Callable, Tuple


class Evaluation:
    """One verdict on a solution, paired with the identity of whoever gave it."""

    def __init__(self, value: str, author: str):
        # The verdict text itself.
        self.value = value
        # Who produced the verdict (e.g. a model name or "oracle").
        self.author = author

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly ``{"value": ..., "author": ...}`` mapping."""
        return dict(value=self.value, author=self.author)

    @staticmethod
    def from_dict(evaluation_dict: Dict[str, Any]) -> "Evaluation":
        """Rebuild an evaluation from a mapping shaped like :meth:`to_dict` output.

        Extra keys are ignored.

        Raises:
            KeyError: if ``"value"`` or ``"author"`` is missing (checked in that order).
        """
        value = evaluation_dict["value"]
        author = evaluation_dict["author"]
        return Evaluation(value, author)


class BongardSolution:
    """A model's answer to one Bongard problem, with its evaluations and metadata.

    Args:
        problem_id: Identifier of the solved problem.
        answer: The model's answer.
        explanation: Optional free-text explanation of the answer.
        evaluations: Evaluations of this solution. When omitted, a new empty
            list is created for this instance.
        metadata: Arbitrary extra fields. When omitted, a new empty dict is
            created for this instance.
    """

    def __init__(
        self,
        problem_id: int,
        answer: str,
        explanation: str = "",
        evaluations: Optional[List["Evaluation"]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.problem_id = problem_id
        self.answer = answer
        self.explanation = explanation
        self.metadata = {} if metadata is None else metadata
        # A literal ``[]`` default is created once at definition time and
        # would be shared (and appended to) by every instance that omits it.
        self.evaluations = [] if evaluations is None else evaluations


class BongardResolveAttempt:
    """A model's attempt at solving a set of Bongard problems.

    Solutions are keyed by ``problem_id``. Each solution may carry evaluations
    from several authors; an evaluation whose author equals this attempt's
    model name is treated as a self-evaluation.
    """

    # JSON keys that map to BongardSolution fields; every other key of a saved
    # solution object is treated as metadata on load.
    _CORE_KEYS = frozenset({"problem_id", "answer", "evaluations", "explanation"})

    def __init__(self, model_name: str = ""):
        self.__model_name: str = model_name
        self.__solutions: Dict[int, BongardSolution] = {}

    @staticmethod
    def from_file(input_file: str) -> "BongardResolveAttempt":
        """Create an attempt populated from a JSON file written by :meth:`save`."""
        attempt = BongardResolveAttempt()
        attempt.load(input_file)
        return attempt

    def add_solution(
        self,
        problem_id: int,
        answer: str,
        explanation: str = "",
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Store the solution for ``problem_id``, replacing any previous one
        (including its evaluations)."""
        # Pass a fresh list explicitly so every solution owns its evaluations
        # instead of depending on the constructor's default value.
        self.__solutions[problem_id] = BongardSolution(
            problem_id,
            answer,
            explanation=explanation,
            evaluations=[],
            metadata=metadata,
        )

    def evaluate_solution(
        self,
        problem_id: int,
        evaluation: str,
        author: str = "",
        reevaluate: bool = False,
    ):
        """Record ``evaluation`` by ``author`` on the solution for ``problem_id``.

        With ``reevaluate=True``, existing evaluations by ``author`` are
        overwritten in place (all of them, if there are several). Otherwise, or
        when ``author`` has not evaluated this solution yet, a new evaluation
        is appended.

        Raises:
            KeyError: if there is no solution for ``problem_id``.
        """
        evals = self.__solutions[problem_id].evaluations
        existing = [ev for ev in evals if ev.author == author]
        if reevaluate and existing:
            for ev in existing:
                ev.value = evaluation
        else:
            evals.append(Evaluation(evaluation, author))

    def get_solutions(self) -> Dict[int, BongardSolution]:
        """Return a shallow copy of the ``problem_id -> solution`` mapping."""
        return self.__solutions.copy()

    def has_solution(self, problem_id: int) -> bool:
        """Return whether a solution for ``problem_id`` has been stored."""
        return problem_id in self.__solutions

    def count_self_evaluated_solutions(
        self, evaluation_value_predicate: Callable[[str], bool]
    ) -> int:
        """Count solutions whose self-evaluation satisfies the predicate.

        Raises:
            ValueError: if any solution has no self-evaluation, or has more
                than one.
        """
        return sum(
            1
            for value in self.get_self_evaluations().values()
            if evaluation_value_predicate(value)
        )

    def get_self_evaluations(self) -> Dict[int, str]:
        """Return ``problem_id -> value`` for evaluations authored by this model.

        Raises:
            ValueError: if any solution has no self-evaluation, or has more
                than one.
        """
        return self.get_evaluations(self.__model_name)

    def get_evaluations(
        self, author: str, ignore_missing: bool = False
    ) -> Dict[int, Optional[str]]:
        """Return ``problem_id -> value`` for the single evaluation by ``author``.

        Args:
            author: Author whose evaluations are collected.
            ignore_missing: If true, a solution without an evaluation by
                ``author`` maps to ``None`` instead of raising.

        Raises:
            ValueError: if ``author`` evaluated a solution more than once, or
                (unless ``ignore_missing``) not at all.
        """
        evaluations: Dict[int, Optional[str]] = {}
        for problem_id, solution in self.__solutions.items():
            values = [ev.value for ev in solution.evaluations if ev.author == author]
            if len(values) > 1:
                raise ValueError(
                    f"Problem {problem_id} was evaluated by {author} {len(values)} times"
                )
            if values:
                evaluations[problem_id] = values[0]
            elif ignore_missing:
                evaluations[problem_id] = None
            else:
                raise ValueError(f"Problem {problem_id} was not evaluated by {author}")
        return evaluations

    def save(self, output_file: str):
        """Write the attempt to ``output_file`` as indented JSON.

        Solutions are written in ascending ``problem_id`` order. Each
        solution's metadata entries are merged into its JSON object with
        ``|``, so a metadata key that collides with a core field replaces it.
        Missing parent directories are created.
        """
        ordered = [self.__solutions[pid] for pid in sorted(self.__solutions)]
        attempt_dict = {
            "model_name": self.__model_name,
            "solutions": [
                {
                    "problem_id": solution.problem_id,
                    "answer": solution.answer,
                    "explanation": solution.explanation,
                    "evaluations": [
                        evaluation.to_dict() for evaluation in solution.evaluations
                    ],
                }
                | solution.metadata
                for solution in ordered
            ],
        }

        path = Path(output_file)
        path.parent.mkdir(parents=True, exist_ok=True)
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(path, "w", encoding="utf-8") as json_file:
            json.dump(attempt_dict, json_file, indent=4)

    def load(self, input_file: str):
        """Replace this attempt's model name and solutions with ``input_file``'s.

        ``explanation`` and ``evaluations`` are optional per solution; every
        key other than the core solution fields becomes solution metadata.
        """
        with open(input_file, "r", encoding="utf-8") as json_file:
            attempt_dict = json.load(json_file)

        self.__model_name = attempt_dict["model_name"]
        self.__solutions = {
            raw["problem_id"]: BongardSolution(
                raw["problem_id"],
                raw["answer"],
                explanation=raw.get("explanation", ""),
                evaluations=[
                    Evaluation.from_dict(evaluation)
                    for evaluation in raw.get("evaluations", [])
                ],
                metadata={
                    k: v for k, v in raw.items() if k not in self._CORE_KEYS
                },
            )
            for raw in attempt_dict["solutions"]
        }


class ClassificationSolution:
    """A model's classification of one query file, plus graded side-fields.

    ``evaluations`` defaults to a new empty list per instance. The
    ``evaluation`` property exposes the single verdict by the ``"oracle"``
    author.
    """

    def __init__(
        self,
        problem_id: int,
        answer: str,
        query_file: str,
        explanation: str = "",
        explanation_evaluation: str = "",
        concept: str = "",
        concept_evaluation: str = "",
        evaluations: Optional[List[Evaluation]] = None,
    ):
        self.problem_id = problem_id
        self.answer = answer
        self.query_file = query_file
        self.explanation = explanation
        self.explanation_evaluation = explanation_evaluation
        self.concept = concept
        self.concept_evaluation = concept_evaluation
        self.evaluations = [] if evaluations is None else evaluations

    @property
    def evaluation(self) -> str:
        """The value of the one evaluation authored by ``"oracle"``.

        Raises:
            ValueError: if there is no oracle evaluation, or more than one.
        """
        oracle_values = [ev.value for ev in self.evaluations if ev.author == "oracle"]
        count = len(oracle_values)
        if count == 1:
            return oracle_values[0]
        if count == 0:
            raise ValueError(
                f"Query file {self.query_file} was not evaluated by 'oracle'"
            )
        raise ValueError(
            f"Query file {self.query_file} was evaluated by 'oracle' {count} times"
        )


class ClassificationAttempt:
    """A model's attempt at classifying query files of Bongard problems.

    Solutions are keyed by ``query_file``. ``evaluate_solution`` defaults to
    the ``"oracle"`` author.
    """

    # Matches the first "<digits>/left" or "<digits>/right" segment of a query
    # file path; group 1 is the problem id, group 2 the side.
    _QUERY_FILE_PATTERN = re.compile(r"(\d+)/(left|right)")

    def __init__(self, model_name: str = ""):
        self.__model_name: str = model_name
        self.__solutions: Dict[str, ClassificationSolution] = {}

    @staticmethod
    def from_file(input_file: str) -> "ClassificationAttempt":
        """Create an attempt populated from a JSON file written by :meth:`save`."""
        attempt = ClassificationAttempt()
        attempt.load(input_file)
        return attempt

    def add_solution(
        self,
        problem_id: int,
        answer: str,
        query_file: str,
        explanation: str = "",
        concept: str = "",
    ):
        """Store the solution for ``query_file``, replacing any previous one."""
        self.__solutions[query_file] = ClassificationSolution(
            problem_id,
            answer,
            query_file,
            explanation,
            concept=concept,
        )

    def evaluate_solution(
        self,
        query_file: str,
        evaluation: str,
        author: str = "oracle",
        reevaluate: bool = False,
    ):
        """Record ``evaluation`` by ``author`` on the solution for ``query_file``.

        With ``reevaluate=True``, existing evaluations by ``author`` are
        overwritten in place (all of them, if there are several). Otherwise, or
        when ``author`` has not evaluated this solution yet, a new evaluation
        is appended.

        Raises:
            KeyError: if there is no solution for ``query_file``.
        """
        evals = self.__solutions[query_file].evaluations
        existing = [ev for ev in evals if ev.author == author]
        if reevaluate and existing:
            for ev in existing:
                ev.value = evaluation
        else:
            evals.append(Evaluation(evaluation, author))

    def evaluate_concept(self, query_file: str, concept_evaluation: str):
        """Set the concept evaluation of the solution for ``query_file``."""
        self.__solutions[query_file].concept_evaluation = concept_evaluation

    def evaluate_explanation(self, query_file: str, explanation: str):
        """Set the explanation evaluation of the solution for ``query_file``.

        Despite its name, ``explanation`` is the *evaluation* of the
        explanation; it is stored in ``explanation_evaluation``.
        """
        self.__solutions[query_file].explanation_evaluation = explanation

    def get_solutions(self) -> Dict[str, ClassificationSolution]:
        """Return a shallow copy of the ``query_file -> solution`` mapping."""
        return self.__solutions.copy()

    def has_solution(self, query_file: str) -> bool:
        """Return whether a solution for ``query_file`` has been stored."""
        return query_file in self.__solutions

    def get_evaluations(self) -> Dict[Tuple[int, str], str]:
        """Return ``(problem_id, side) -> oracle evaluation`` for every solution.

        ``problem_id`` and ``side`` are parsed from each query file path. If
        several query files map to the same key, the last one wins.

        Raises:
            ValueError: if a query file path contains no ``<digits>/left`` or
                ``<digits>/right`` segment, or if a solution lacks exactly one
                oracle evaluation.
        """
        evaluations: Dict[Tuple[int, str], str] = {}
        for query_file, solution in self.__solutions.items():
            match = self._QUERY_FILE_PATTERN.search(query_file)
            if match is None:
                raise ValueError(
                    f"Cannot extract problem id and side from query file {query_file!r}"
                )
            problem_id, side = match.groups()
            evaluations[(int(problem_id), side)] = solution.evaluation
        return evaluations

    def save(self, output_file: str):
        """Write the attempt to ``output_file`` as indented JSON.

        Solutions are written in insertion order. Missing parent directories
        are created.
        """
        attempt_dict = {
            "model_name": self.__model_name,
            "solutions": [
                {
                    "problem_id": solution.problem_id,
                    "answer": solution.answer,
                    "query_file": solution.query_file,
                    "explanation": solution.explanation,
                    "explanation_evaluation": solution.explanation_evaluation,
                    "concept": solution.concept,
                    "concept_evaluation": solution.concept_evaluation,
                    "evaluations": [
                        evaluation.to_dict() for evaluation in solution.evaluations
                    ],
                }
                for solution in self.__solutions.values()
            ],
        }

        path = Path(output_file)
        path.parent.mkdir(parents=True, exist_ok=True)
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(path, "w", encoding="utf-8") as json_file:
            json.dump(attempt_dict, json_file, indent=4)

    def load(self, input_file: str):
        """Replace this attempt's model name and solutions with ``input_file``'s.

        Only ``problem_id``, ``answer`` and ``query_file`` are required per
        solution; the text fields default to ``""`` and ``evaluations`` to an
        empty list.
        """
        with open(input_file, "r", encoding="utf-8") as json_file:
            attempt_dict = json.load(json_file)

        self.__model_name = attempt_dict["model_name"]
        self.__solutions = {
            raw["query_file"]: ClassificationSolution(
                raw["problem_id"],
                raw["answer"],
                raw["query_file"],
                explanation=raw.get("explanation", ""),
                explanation_evaluation=raw.get("explanation_evaluation", ""),
                concept=raw.get("concept", ""),
                concept_evaluation=raw.get("concept_evaluation", ""),
                evaluations=[
                    Evaluation.from_dict(evaluation)
                    for evaluation in raw.get("evaluations", [])
                ],
            )
            for raw in attempt_dict["solutions"]
        }
