import abc
import json
from abc import ABC
from typing import List, Optional, Tuple, Union

import pandas as pd
from pandas import DataFrame

from compute_result.typing import Run, ProblemSpace, RunResult, Point, ResultRow
from problems.types import Suites

# Column names for tabular views of result rows; ResultStore._filter_multiple_budget
# builds its DataFrame with the columns in exactly this order.
BUDGET_COL, VALUE_COL, POINT_COL, ALGORITHM_NAME_COL = (
    "budget",
    "value",
    "point",
    "algorithm_name",
)


class ResultStore(ABC):
    """Abstract storage backend for run results and per-problem min/max bounds.

    Concrete subclasses implement persistence through the abstract methods; this
    base class supplies query helpers built on top of them. Judging by the
    unpacking in ``common_spaces_on_suite_from_multiple_runs``, a ``Run`` is an
    ``(algorithm, run_name)`` pair and a ``ProblemSpace`` is a
    ``(suite, func_id, dim, func_instance)`` tuple.
    """

    @abc.abstractmethod
    def storage_exists(self):
        """Subclass hook: report whether the backing storage exists (presumably a bool)."""
        raise NotImplementedError()

    @abc.abstractmethod
    def run_result(self, run: Run, problem_space: ProblemSpace) -> RunResult:
        """Return the ``RunResult`` recorded for *run* on *problem_space*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def is_run_exists(self, run: Run, problem_space: ProblemSpace):
        """Subclass hook: tell whether *run* has results for *problem_space* (presumably a bool)."""
        raise NotImplementedError()

    @abc.abstractmethod
    def min_max_from_space(
        self, problem_space: ProblemSpace
    ) -> Tuple[float, float, Point, Point]:
        """Return ``(min_value, max_value, min_point, max_point)`` for *problem_space*."""
        raise NotImplementedError()

    def common_spaces_on_suite_from_multiple_runs(
        self,
        runs: List[Run],
        suite: Optional[Suites] = None,
        func_ids: Optional[List[int]] = None,
        func_dims: Optional[List[int]] = None,
    ) -> List[ProblemSpace]:
        """Return the problem spaces that every run in *runs* has results for.

        Args:
            runs: Runs to intersect, each an ``(algorithm, run_name)`` pair.
            suite: When truthy, only problems whose suite equals it are considered.
            func_ids: When not ``None``, keep only problems whose function id is listed.
            func_dims: When not ``None``, keep only problems whose dimension is listed.

        Returns:
            ``(suite, func_id, dim, func_instance)`` tuples shared by all *runs*, in
            no particular order. An empty list when *runs* is empty.
        """
        if not runs:
            # Nothing to intersect: answer directly instead of failing on an empty mapping.
            return []

        # Group the stored problems by run in a single pass over list_runs().
        problems_per_run = {run: set() for run in runs}
        for (algorithm, run_name), (
            run_suite,
            func_id,
            dim,
            func_instance,
        ) in self.list_runs():
            run_key = (algorithm, run_name)
            if run_key not in problems_per_run:
                continue
            if suite and run_suite != suite:
                continue
            problems_per_run[run_key].add((run_suite, func_id, dim, func_instance))

        common_problems = set.intersection(*problems_per_run.values())

        # Sets are built once so membership checks are O(1) per problem.
        wanted_ids = None if func_ids is None else set(func_ids)
        wanted_dims = None if func_dims is None else set(func_dims)
        return [
            problem
            for problem in common_problems
            if (wanted_ids is None or problem[1] in wanted_ids)
            and (wanted_dims is None or problem[2] in wanted_dims)
        ]

    @abc.abstractmethod
    def update_min(self, problem_space: ProblemSpace, min_value: float, min_point: Point):
        """Record *min_value* reached at *min_point* as the minimum of *problem_space*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def update_max(self, problem_space: ProblemSpace, max_value: float, max_point: Point):
        """Record *max_value* reached at *max_point* as the maximum of *problem_space*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def remove_run(self, run: Run, problem_space: ProblemSpace):
        """Delete the results of *run* on *problem_space*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def store_step(
        self,
        run: Run,
        problem_space: ProblemSpace,
        budget: int,
        value: float,
        point: Point,
        algorithm_name: str,
    ):
        """Persist a single ``(budget, value, point, algorithm_name)`` row of *run*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def store_run(self, run: Run, problem_space: ProblemSpace, run_result: RunResult):
        """Persist the whole *run_result* of *run* on *problem_space*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def list_min_max_problems(self) -> List[ProblemSpace]:
        """Return the problem spaces that have min/max bounds stored."""
        raise NotImplementedError()

    @abc.abstractmethod
    def list_runs(self) -> List[Tuple[Run, ProblemSpace]]:
        """Return every stored ``(run, problem_space)`` pair."""
        raise NotImplementedError()

    def list_problems(self) -> List[ProblemSpace]:
        """Return every known problem space.

        Problems from :meth:`list_min_max_problems` come first, in their original
        order. They are followed by problems that only appear in :meth:`list_runs`,
        de-duplicated and kept in first-seen order so the result is deterministic.
        """
        results_problems = [problem for _, problem in self.list_runs()]
        min_max_problems = self.list_min_max_problems()
        already_listed = set(min_max_problems)
        run_only_problems = [
            problem
            for problem in dict.fromkeys(results_problems)
            if problem not in already_listed
        ]
        return min_max_problems + run_only_problems

    @abc.abstractmethod
    def rename_run(self, run: Run, new_run_name: str):
        """Rename *run* to *new_run_name*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def all_algorithms_in_run(self, run: Run, problems: List[ProblemSpace]) -> List[str]:
        """Return the algorithm names recorded for *run* across *problems*."""
        raise NotImplementedError()

    @staticmethod
    def _filter_multiple_budget(
        run_results: RunResult, return_df: bool = False
    ) -> Union[RunResult, pd.DataFrame]:
        """Thin *run_results* to one row per distinct value and per distinct budget.

        Rows are ordered by budget descending, with the larger value first on ties.
        Filtering then keeps the first row of each value, and afterwards the first
        row of each budget. The best row (smallest value, largest budget among ties)
        is re-added if that filtering dropped it. The output is sorted by value
        ascending, then budget descending.

        Args:
            run_results: Rows of ``(budget, value, point, algorithm_name)``.
            return_df: When true, return a DataFrame whose point column is
                JSON-encoded. Otherwise return a list of numpy records.

        Returns:
            The filtered rows. An empty list (or an empty DataFrame) when
            *run_results* is empty.
        """
        result_df = pd.DataFrame(
            run_results, columns=[BUDGET_COL, VALUE_COL, POINT_COL, ALGORITHM_NAME_COL]
        )
        # A single multi-key sort is stable and states the intended order outright.
        # Chained single-key sorts (default quicksort) do not guarantee that the
        # value order survives among equal budgets.
        unique_result_df = (
            result_df.sort_values(by=[BUDGET_COL, VALUE_COL], ascending=[False, False])
            .drop_duplicates(subset=[VALUE_COL])
            .drop_duplicates(subset=[BUDGET_COL])
        )
        if not result_df.empty:
            # Make sure the best row survives the per-budget de-duplication.
            # iloc[[0]] keeps a one-row DataFrame with the original dtypes.
            best_row = result_df.sort_values(
                by=[VALUE_COL, BUDGET_COL], ascending=[True, False]
            ).iloc[[0]]
            # Drop the re-added row again if it was already kept above.
            unique_result_df = pd.concat(
                [unique_result_df, best_row], ignore_index=True, sort=False
            ).drop_duplicates(subset=[BUDGET_COL, VALUE_COL])
        unique_result_df = unique_result_df.sort_values(
            [VALUE_COL, BUDGET_COL], ascending=[True, False]
        )

        if not return_df:
            # NOTE(review): to_records() includes the (arbitrary, post-concat) index
            # as the first field. Records are therefore (index, budget, value, point,
            # algorithm_name), not the 4-field rows RunResult suggests. Kept for
            # compatibility; confirm callers before switching to index=False.
            return list(unique_result_df.to_records())
        unique_result_df[POINT_COL] = unique_result_df[POINT_COL].apply(json.dumps)
        return unique_result_df

    def problem_final_value_for_run(self, run: Run, problem: ProblemSpace) -> ResultRow:
        """Return the last row of :meth:`run_result` for *run* on *problem*."""
        return self.run_result(run, problem)[-1]

    def normalize_final_value_for_run(self, run: Run, problem: ProblemSpace) -> ResultRow:
        """Return the final row of *run* on *problem* with its value min-max normalised.

        Element 1 (the value) becomes ``(value - min) / (max - min)``, using the
        bounds from :meth:`min_max_from_space`. Elements 0, 2 and 3 pass through
        unchanged. When ``max == min`` there is no spread to normalise against, so
        the value becomes 0.0 instead of dividing by zero.
        """
        final_row = self.problem_final_value_for_run(run, problem)
        min_value, max_value, _, _ = self.min_max_from_space(problem)
        value_range = max_value - min_value
        normalized_value = (
            (final_row[1] - min_value) / value_range if value_range != 0 else 0.0
        )
        return (
            final_row[0],
            normalized_value,
            final_row[2],
            final_row[3],
        )

    @abc.abstractmethod
    def store_metric(self, metric_name: str, data: DataFrame):
        """Persist *data* under *metric_name*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_metric(self, metric_name: str) -> DataFrame:
        """Return the DataFrame stored under *metric_name*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_metrics_types(self, metric_initial: str) -> List[str]:
        """Return metric names related to *metric_initial* (presumably a name prefix; confirm in subclasses)."""
        raise NotImplementedError()
