__all__ = ["Solution", "TestResult", "ProblemResult", "RunResult"]

import json

from src.datatypes.enums import Language
from src.tools.dataloaders import get_test_cases_for_single_problem


class Solution:
    @classmethod
    def from_response(cls, problem_id, response, language):
        if response.get("test_cases") is not None:
            test_cases = response["test_cases"]
        elif (
            response.get("metadata") is not None
            and response["metadata"].get("test_cases") is not None
        ):
            test_cases = response["metadata"]["test_cases"]
        else:
            test_cases = get_test_cases_for_single_problem(problem_id)

        return cls(
            problem_id,
            response["metadata"]["question"],
            response["solution"],
            language,
            test_cases,
            response["metadata"]["difficulty"],
        )

    @classmethod
    def from_cache_file(cls, path):
        with open(path) as f:
            solution = json.load(f)

        return cls(
            solution.get("question_id"),
            solution.get("question"),
            solution.get("solution"),
            Language.from_code(solution.get("language"))
            if solution.get("language") is not None
            else solution.get("language"),
            solution.get("test_cases"),
            solution.get("difficulty"),
        )

    @classmethod
    def from_executor_output(cls, output):
        return cls(
            output["question_id"],
            output["question"],
            output["solution"],
            output["language"],
            output["test_cases"],
            output["difficulty"],
        )

    @classmethod
    def no_solution(cls, problem_id):
        return cls(problem_id, None, None, None, None, None)

    def __init__(
        self, question_id, question, solution, language, test_cases, difficulty
    ):
        self.question_id = question_id
        self.question = question
        self.solution = solution
        self.language = language
        self.test_cases = test_cases
        self.difficulty = difficulty

        if self.test_cases:
            if "correct" in self.test_cases[0]:
                self.correct = all([i["correct"] for i in self.test_cases])
                self.already_processed = True
            else:
                self.correct = None
                self.already_processed = False
        else:
            self.correct = False
            self.already_processed = True

    def __str__(self):
        return f"Question ID: {self.question_id}\nQuestion: {self.question}\nSolution: {self.solution}\nLanguage: {self.language}\nTest Cases: {self.test_cases}\nDifficulty: {self.difficulty}"

    def __repr__(self):
        return self.__str__()

    def to_json(self):
        return {
            "question_id": self.question_id,
            "question": self.question,
            "solution": self.solution,
            "language": self.language.code if self.language else None,
            "test_cases": self.test_cases,
            "difficulty": self.difficulty,
            "correct": self.correct,
        }


class TestResult:
    def __init__(self, input, expected_output, actual_output):
        self.input = input
        self.expected_output = expected_output
        self.actual_output = actual_output
        self.correct = self.check_equal(expected_output, actual_output)

    def check_equal(self, expected, actual):
        if isinstance(expected, float) and isinstance(actual, float):
            return abs(expected - actual) < 1e-6
        if isinstance(expected, list) and isinstance(actual, list):
            return all([self.check_equal(v1, v2) for v1, v2 in zip(expected, actual)])
        if isinstance(expected, str):
            expected = expected.strip()
        if isinstance(actual, str):
            actual = actual.strip()
        return str(expected) == str(actual)

    def to_json(self):
        return {
            "input": self.input,
            "expected": self.expected_output,
            "actual": self.actual_output,
            "correct": self.correct,
        }

    def __str__(self):
        return f"Input: {self.input}\nExpected Output: {self.expected_output}\nActual Output: {self.actual_output}\nCorrect: {self.correct}"

    def __repr__(self):
        return self.__str__()


class ProblemResult:
    def __init__(self, solution, test_results, correct=None):
        self.solution = solution
        self.test_results = test_results
        if correct is not None:
            self.correct = correct
        elif test_results == "ERROR":
            self.correct = False
        else:
            self.correct = all([test_result.correct for test_result in test_results])

    def __str__(self):
        if self.test_results == "ERROR":
            return f"Problem ID: {self.solution.question_id}\nError: Could not run solution"
        correct_tests = sum(
            [1 for test_result in self.test_results if test_result.correct]
        )
        return f"Problem ID: {self.solution.question_id}\nSolution: {self.solution.solution}\nTest Results:\n\tCorrect: {correct_tests}\n\tTotal: {len(self.test_results)}\n\tAccuracy: {correct_tests/len(self.test_results)}\nOverall Correct: {self.correct}"

    def to_solution(self):
        self.solution.test_cases = [i.to_json() for i in self.test_results]
        self.solution.correct = self.correct
        return self.solution

    def __repr__(self):
        return self.__str__()


class RunResult:
    @classmethod
    def from_executor_output(cls, results_list):
        problem_results = {}
        for problem_result in results_list:
            test_results = []
            for test_result in problem_result["test_cases"]:
                test_results.append(
                    TestResult(
                        test_result["input"], test_result["output"], test_result["pred"]
                    )
                )
        problem_results[problem_result["question_id"]] = ProblemResult(
            Solution.from_executor_output(problem_result), test_results
        )
        return cls(problem_results)

    def __init__(self, problem_results):
        self.problem_results = problem_results
        self.num_correct = sum(
            [problem_result.correct for problem_result in problem_results.values()]
        )
        self.accuracy = float(self.num_correct) / max(1, len(problem_results))

    def __str__(self):
        return f"Number of Problems: {len(self.problem_results)}\nNumber Correct: {self.num_correct}\nAccuracy: {self.accuracy}"

    def __repr__(self):
        return self.__str__()
