import itertools
import logging
import math
import operator
from logging import Logger
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm
from matplotlib import pyplot as plt
from matplotlib.backends.backend_template import FigureCanvas
from matplotlib.figure import Figure
from scipy.stats import entropy
from sklearn.decomposition import PCA
from torch import Tensor

from algorithms.space.callable_function import CallableSpace
from compute_result.output_manager.base import OutputManager, AxisTypes
from compute_result.result_store.base import ResultStore
from compute_result.typing import (
    Run,
    ProblemSpace,
    BudgetRun,
    RunResult,
    StatisticalAnalysisOptions,
    Point,
    ResultRow,
    MetricManipulation,
    AxisAnalysis,
)
from compute_result.utils import (
    measure_statistic_result_over_multiple_runs,
    problems_with_the_same_func,
    problems_with_the_same_dim,
    max_budget_of_runs,
    from_mapped_result_to_steps,
    get_index_ranges,
    normalize_result,
    METRIC_MANIPULATION,
    STATISTICAL_ANALYSIS_MAPPING,
    points_from_run,
    how_much_budget_to_reach_percentile,
    AXIS_ANALYSIS,
    np_array_from_results,
    rolling_average,
    np_array_from_res_split_by_dim,
    DEFAULT_COLOR_MAP,
    full_steps_for_budget_run,
    get_label_color_map,
    create_plot_with_statistics,
)
from handlers.drawers.utils import create_grid_points
from problems.benchmarks import BENCHMARK_MAPPER
from problems.suite import SUITES_MAPPING
from problems.types import Suites, Benchmarks
from problems.utils import BUDGET
from utils.algorithms_data import Algorithms
from utils.logger import create_file_log_path
from utils.python import distance_between_points, distance_between_many_tensors

# Module-level side effect: every figure created after this module is imported
# uses matplotlib's "ggplot" style sheet.
plt.style.use("ggplot")


class ResultManager:
    def __init__(
        self,
        result_storage: ResultStore,
        output_manager: OutputManager,
        logger: Logger = None,
        colormap=None,
    ):
        """Bind the storage backend and the output sink used by every report.

        :param result_storage: backend queried for run results and per-space min/max.
        :param output_manager: sink that receives progressions, images and bars.
        :param logger: optional logger; a falsy value falls back to this module's logger.
        :param colormap: optional color mapping; a falsy value falls back to
            ``DEFAULT_COLOR_MAP``.
        """
        self.result_storage = result_storage
        self.output_manager = output_manager
        self.logger = logging.getLogger(__name__) if not logger else logger
        self.colormap = DEFAULT_COLOR_MAP if not colormap else colormap

    def _distance_from_best_point_for_run(
        self, run, problem_space: ProblemSpace
    ) -> BudgetRun:
        """Return ``(budget, distance)`` for every recorded step of ``run``.

        The distance is measured between the step's point and the minimum point
        stored for ``problem_space`` (third item of ``min_max_from_space``).
        """
        self.logger.info(
            f"Getting data for run {run} in {problem_space} to calculate the distance from best point at each stage"
        )
        steps = self.result_storage.run_result(run, problem_space)
        _low, _high, best_point, *_rest = self.result_storage.min_max_from_space(
            problem_space
        )
        distances = []
        for budget, _value, point, _alg_name in steps:
            distances.append((budget, distance_between_points(best_point, point)))
        return distances

    def _run_normalize_result(self, run: Run, problem_space: ProblemSpace) -> RunResult:
        """Return the steps of ``run`` with each value rescaled by ``normalize_result``.

        The rescaling uses the min/max stored for ``problem_space``. When those two
        are equal there is no range to normalize against, so a warning is logged
        and an empty list is returned.
        """
        steps = self.result_storage.run_result(run, problem_space)
        lowest, highest, *_ = self.result_storage.min_max_from_space(problem_space)
        if lowest == highest:
            self.logger.warning(f"No new point found in {problem_space}")
            return []
        normalized_steps = [
            (budget, normalize_result(value, lowest, highest), point, alg_name)
            for budget, value, point, alg_name in steps
        ]
        return normalized_steps

    def _normalize_last_step(self, run: Run, problem: ProblemSpace) -> ResultRow:
        """Return the final step of ``run`` on ``problem`` with its value rescaled.

        The rescaling uses ``normalize_result`` with the stored min/max of ``problem``.
        """
        lowest, highest, *_ = self.result_storage.min_max_from_space(problem)
        final_step = self.result_storage.problem_final_value_for_run(run, problem)
        budget, value, point, alg_name = final_step
        normalized_value = normalize_result(value, lowest, highest)
        return budget, normalized_value, point, alg_name

    def _print_normalize_run_on_problems(
        self,
        run: Run,
        problems: List[ProblemSpace],
        analysis: StatisticalAnalysisOptions,
    ):
        """Aggregate the normalized ``(budget, value)`` curves of ``run`` over ``problems``.

        Problems whose normalization comes back empty are skipped. The remaining
        curves are combined with ``measure_statistic_result_over_multiple_runs``
        using ``analysis``. Despite its name, this method only returns the
        aggregate; it does not print anything.
        """
        self.logger.info(f"Computing {run}")
        per_problem_curves = []
        for problem_space in problems:
            normalized_run = self._run_normalize_result(run, problem_space)
            if not normalized_run:
                continue
            per_problem_curves.append(
                [(budget, value) for budget, value, *_ in normalized_run]
            )
        self.logger.info("Result are normalized")
        return measure_statistic_result_over_multiple_runs(per_problem_curves, analysis)

    def _full_results_runs_on_problems(
        self, runs: List[Run], problems: List[ProblemSpace]
    ):
        """Normalize every run on every problem and expand each run to a common grid.

        The grid length is the largest ``max_budget_of_runs`` value across all runs.
        Expansion is done by ``full_steps_for_budget_run``. Problems with an empty
        normalization are skipped.
        """

        def normalized_curves(current_run):
            # Per-problem (budget, normalized value) curves for one run.
            self.logger.info(f"Computing {current_run}")
            curves = []
            for problem_space in problems:
                normalized_run = self._run_normalize_result(current_run, problem_space)
                if normalized_run:
                    curves.append([(budget, value) for budget, value, *_ in normalized_run])
            self.logger.info("Result are normalized")
            return curves

        per_run_curves = [normalized_curves(current_run) for current_run in runs]
        largest_budget = max([max_budget_of_runs(curves) for curves in per_run_curves])
        return [
            full_steps_for_budget_run(curves, largest_budget) for curves in per_run_curves
        ]

    def _compare_normalized_runs_on_problems(
        self,
        runs: List[Run],
        problems: List[ProblemSpace],
        axis_type: AxisTypes,
        analysis: StatisticalAnalysisOptions,
        graph_name: str,
        plot_names: List[str],
    ):
        """Log each run's aggregated normalized progression as a plot of one graph.

        Every aggregated curve is padded with its own last value up to the longest
        curve, so all plots span the same x-range. A run whose aggregated curve is
        empty (no problem could be normalized) is logged as-is instead of raising
        ``IndexError``, and an empty ``runs`` list no longer makes ``max()`` raise.

        :param plot_names: per-run plot labels; ``None``/empty means blank labels.
            Pairing uses ``zip``, so it stops at the shorter of ``runs`` and
            ``plot_names``.
        """
        plot_names = plot_names or [""] * len(runs)
        aggregated = [
            (
                run,
                self._print_normalize_run_on_problems(run, problems, analysis),
                plot_name,
            )
            for run, plot_name in zip(runs, plot_names)
        ]
        # default=0: with no runs there is nothing to pad to.
        max_len_results = max(
            (len(results) for _, results, _ in aggregated), default=0
        )

        for run, results, plot_name in aggregated:
            missing = max_len_results - len(results)
            if results and missing > 0:
                # Hold the final aggregated value so shorter curves stay flat.
                results = results + [results[-1]] * missing
            self.output_manager.log_alg_progression(
                run,
                results,
                axis_type,
                graph_name=graph_name,
                plot_name=plot_name,
            )

    def full_graph_compare(
        self,
        runs: List[Run],
        suite: Suites = None,
        graph_name: str = None,
        plot_names: List[str] = None,
    ):
        """Plot the mean and 25th-75th percentile band of every run's normalized curve.

        The curves are taken over the spaces that all ``runs`` share on ``suite``.
        The y-axis is logarithmic, and the figure is handed to
        ``output_manager.print_image``.
        """
        self.logger.info(f"Comparing runs {runs} on {suite}")
        shared_spaces = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, suite
        )
        self.logger.info(f"Comparing runs {runs} on {len(shared_spaces)}")
        # assumes a (runs, problems, steps) layout so axis=1 reduces over problems
        # — TODO confirm against full_steps_for_budget_run.
        stacked = np.array(self._full_results_runs_on_problems(runs, shared_spaces))
        mean_curve = stacked.mean(axis=1)
        lower_band, upper_band = (
            np.percentile(stacked, quantile, axis=1) for quantile in (25, 75)
        )
        steps = np.arange(mean_curve.shape[-1])
        figure, axis = create_plot_with_statistics(
            steps, mean_curve, lower_band, upper_band, plot_names, self.colormap
        )
        axis.set_xlabel("Budget")
        axis.set_ylabel("Value")
        axis.legend()
        axis.grid(True)
        axis.set_yscale("log")
        self.output_manager.print_image(
            runs[0], figure, 0, graph_name, "compare_algorithms"
        )

    def compare_runs(
        self,
        runs: List[Run],
        suite: Suites = None,
        axis_type: AxisTypes = AxisTypes.NORMAL,
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        graph_name: str = None,
        plot_names: List[str] = None,
    ):
        """Plot the aggregated normalized progression of ``runs`` and finish the output.

        The aggregation is over every space the runs share on ``suite``, after
        which ``output_manager.finish()`` is called.
        """
        self.logger.info(f"Comparing runs {runs} on {suite}")
        shared_spaces = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, suite
        )
        space_count = len(shared_spaces)
        self.logger.info(f"Comparing runs {runs} on {space_count}")
        title = f"{graph_name} on {space_count} spaces"
        self._compare_normalized_runs_on_problems(
            runs, shared_spaces, axis_type, analysis, title, plot_names
        )
        self.output_manager.finish()

    def update_from_new_result(self, new_result_storage: ResultStore):
        """Merge ``new_result_storage`` into ``self.result_storage``.

        Every min/max entry is pushed through ``update_min``/``update_max``, and
        every non-empty run result is stored. Empty run results are skipped with a
        warning. If the other storage does not exist, an error is logged and
        nothing is merged.
        """
        if new_result_storage.storage_exists():
            self.logger.info(f"Updating {self.result_storage} from {new_result_storage}")
            for problem in new_result_storage.list_min_max_problems():
                low, high, low_point, high_point = new_result_storage.min_max_from_space(
                    problem
                )
                self.result_storage.update_min(problem, low, low_point)
                self.result_storage.update_max(problem, high, high_point)
            for run, problem in new_result_storage.list_runs():
                steps = new_result_storage.run_result(run, problem)
                if steps:
                    self.result_storage.store_run(run, problem, steps)
                else:
                    self.logger.warning(
                        f"Problem {problem} on run {run} was empty, continue"
                    )
        else:
            self.logger.error(f"Storage {new_result_storage} does not exists")

    def print_run(self, runs: List[Run], suite: Suites = None, graph_name: str = ""):
        """Log the raw ``(budget, value)`` progression of every run on every shared space.

        The shared spaces are those the runs have in common on ``suite``. One
        progression is logged per (run, problem); its graph name is the problem,
        suffixed with ``graph_name`` when given.
        """
        self.logger.info(f"Printing all results of {runs}")
        shared_spaces = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, suite
        )
        for run in runs:
            self.logger.info(f"Starting to print the run {run}")
            for problem in shared_spaces:
                self.logger.info(f"printing {problem}")
                progression = [
                    (budget, value)
                    for budget, value, _point, *_ in self.result_storage.run_result(
                        run, problem
                    )
                ]
                label = str(problem) if not graph_name else f"{problem}_{graph_name}"
                self.output_manager.log_data_progression(
                    run, progression, graph_name=label
                )

    def compare_run_by_func(
        self,
        runs: List[Run],
        suite: Suites,
        axis_type: AxisTypes = AxisTypes.NORMAL,
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        graph_name: str = "",
        plot_names: List[str] = None,
    ):
        """Compare the normalized progression of ``runs`` once per function.

        A function is a ``(suite, func_id)`` pair found among the spaces the runs
        share on ``suite``. One graph is produced per function, and
        ``output_manager.finish()`` is called after each. The graph title holds the
        function key, ``repr`` of a sample environment, the optional
        ``graph_name`` and the number of spaces.
        """
        self.logger.info(f"Comparing runs {runs} on {suite} by function")
        common_spaces = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, suite
        )
        self.logger.info(f"Running on {len(common_spaces)} problems")
        # Each distinct (suite, func_id) once, in first-seen order, so the grouping
        # helper runs once per function rather than once per space.
        function_keys = dict.fromkeys(
            (space_suite, func_id) for space_suite, func_id, *_ in common_spaces
        )
        run_by_function = {
            (space_suite, func_id): problems_with_the_same_func(
                space_suite, func_id, common_spaces
            )
            for space_suite, func_id in function_keys
        }

        for func, problems in run_by_function.items():
            # A fresh name, so the ``suite`` parameter is not rebound.
            func_suite, func_id = func
            # Built only for its repr() in the title (dimension 2, instance 1).
            env = SUITES_MAPPING[func_suite](func_id, 2, 1, None)
            name = f"compare {func} {repr(env)}"
            if graph_name:
                name += f" {graph_name}"
            name += f" on {len(problems)} spaces"

            self._compare_normalized_runs_on_problems(
                runs, problems, axis_type, analysis, name, plot_names
            )
            self.output_manager.finish()

    def compare_run_by_dim(
        self,
        runs: List[Run],
        suite: Suites,
        axis_type: AxisTypes = AxisTypes.NORMAL,
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        graph_name: str = "",
        plot_names: List[str] = None,
    ):
        """Compare the normalized progression of ``runs`` once per dimension.

        The dimensions are those found among the spaces the runs share on
        ``suite``. One graph is produced per dimension, and
        ``output_manager.finish()`` is called after each.
        """
        self.logger.info(f"Comparing runs {runs} on {suite} by dim")
        common_spaces = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, suite
        )
        self.logger.info(f"Running on {len(common_spaces)} problems")
        # Each distinct dimension once, in first-seen order, so the grouping helper
        # runs once per dimension rather than once per space.
        distinct_dims = dict.fromkeys(dim for _, _, dim, *_ in common_spaces)
        problems_by_dim = {
            dim: problems_with_the_same_dim(dim, common_spaces) for dim in distinct_dims
        }
        problem_per_dim = {
            dim: len(problems) for dim, problems in problems_by_dim.items()
        }
        self.logger.info(f"By dim {problem_per_dim}")

        for dim, problems in problems_by_dim.items():
            self.logger.info(f"Creating for {dim} on {len(problems)}")
            name = f"compare dim {dim}"
            if graph_name:
                name += f" {graph_name}"
            name += f" on {len(problems)} spaces"

            self._compare_normalized_runs_on_problems(
                runs, problems, axis_type, analysis, name, plot_names
            )
            self.output_manager.finish()

    def print_problem_where_run_is_better(self, tested_run: Run, other_runs: Run):
        """Not implemented: currently a no-op that returns ``None``.

        NOTE(review): ``other_runs`` is annotated as a single ``Run`` although the
        plural name suggests a list of runs — confirm the intended type before
        implementing.
        """
        pass

    def rename_run_name(self, run: Run, new_run_name: str):
        """Rename ``run`` to ``new_run_name`` in the result storage."""
        storage = self.result_storage
        storage.rename_run(run, new_run_name)

    def print_distance_from_best(self, run: Run, graph_name: str):
        """Log, per problem, how far each point of ``run`` is from the stored minimum point.

        The problems are those that ``result_storage.list_runs()`` pairs with
        ``run``. One progression is logged per problem, named
        ``"<graph_name> <problem>"``.
        """
        self.logger.info(f"Getting all problems from {run}")
        run_problems = [
            problem
            for stored_run, problem in self.result_storage.list_runs()
            if stored_run == run
        ]
        for problem in run_problems:
            self.logger.info(
                f"Extracting best point of the problem {problem} to check distance"
            )
            self.output_manager.log_data_progression(
                run,
                self._distance_from_best_point_for_run(run, problem),
                graph_name=f"{graph_name} {problem}",
            )

    def compare_distance_from_best_by_func(
        self,
        runs: List[Run],
        suite: Suites,
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        graph_name: str = "",
    ):
        """Log each run's aggregated distance-from-best curve once per function.

        A function is a ``(suite, func_id)`` pair among the spaces the runs share on
        ``suite``. The curve is aggregated over the spaces that
        ``problems_with_the_same_func`` groups under that function.
        """
        self.logger.info(
            f"Comparing distance from best in {runs} on {suite} by function"
        )
        common_spaces = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, suite
        )
        self.logger.info(f"Running on {len(common_spaces)} problems")
        # Each distinct (suite, func_id) once, in first-seen order, so the grouping
        # helper runs once per function rather than once per space.
        function_keys = dict.fromkeys(
            (space_suite, func_id) for space_suite, func_id, *_ in common_spaces
        )
        run_by_function = {
            (space_suite, func_id): problems_with_the_same_func(
                space_suite, func_id, common_spaces
            )
            for space_suite, func_id in function_keys
        }

        for func, problems in run_by_function.items():
            name = f"compare distance from best {func}"
            if graph_name:
                name += f" {graph_name}"
            name += f" on {len(problems)} spaces"

            self._compare_distance_from_best_on_runs(runs, problems, analysis, name)

    def compare_distance_from_best(
        self,
        runs: List[Run],
        suite: Suites = None,
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        graph_name: str = None,
    ):
        """Log each run's aggregated distance-from-best curve over the shared spaces.

        The aggregation covers every space the runs share on ``suite``.
        """
        self.logger.info(
            f"Comparing distance between best and current point in runs {runs} on {suite}"
        )
        shared_spaces = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, suite
        )
        title = f"{graph_name} on {len(shared_spaces)} spaces"
        self._compare_distance_from_best_on_runs(runs, shared_spaces, analysis, title)

    def _compare_distance_from_best_on_runs(
        self,
        runs: List[Run],
        common_spaces: List[ProblemSpace],
        analysis: StatisticalAnalysisOptions,
        graph_name: str,
    ):
        """Log the aggregated distance-from-best curve of every run in ``runs``.

        Each run gets one ``_print_distance_from_best`` call, in order.
        """
        for current_run in runs:
            self._print_distance_from_best(
                current_run, common_spaces, analysis, graph_name
            )

    def _print_distance_from_best(
        self,
        run: Run,
        spaces: List[ProblemSpace],
        analysis: StatisticalAnalysisOptions,
        graph_name: str,
    ):
        """Aggregate ``run``'s distance-from-best curves over ``spaces`` and log them.

        Empty per-space curves are skipped. The rest are combined with
        ``measure_statistic_result_over_multiple_runs`` using ``analysis``.
        """
        self.logger.info(f"Computing distance from best point for {run}")
        per_space_distances = []
        for space in spaces:
            distances = self._distance_from_best_point_for_run(run, space)
            if distances:
                per_space_distances.append(distances)
        aggregated = measure_statistic_result_over_multiple_runs(
            per_space_distances, analysis
        )
        self.output_manager.log_alg_progression(run, aggregated, graph_name=graph_name)

    def missing_parts_of_benchmark_run(
        self,
        run: Run,
        benchmark: Benchmarks,
        include_unfinished_run: bool = False,
        parts: List[Tuple[int, int]] = None,
        func_numbers: List[int] = None,
        func_dims: List[int] = None,
    ):
        """List the benchmark spaces that ``run`` still has to be executed on.

        :param run: the run to check.
        :param benchmark: benchmark whose spaces, ``BENCHMARK_MAPPER[benchmark](1)``,
            are enumerated. The enumeration index identifies each space.
        :param include_unfinished_run: also report spaces whose stored result is
            empty or whose final recorded budget is below ``BUDGET - 1000``.
        :param parts: optional ``(part_num, num_of_parts)`` pairs. Each selects the
            ``part_num``-th slice when the benchmark is cut into slices of
            ``ceil(len / num_of_parts)`` spaces. Overlapping parts are merged.
        :param func_numbers: optional whitelist of ``func_id`` values.
        :param func_dims: optional whitelist of dimensions.
        :return: ``(index, run, problem)`` tuples in ascending index order, where
            ``problem`` is ``(suite, func_id, dimension, func_instance)``.
        """
        benchmark_run = list(enumerate(BENCHMARK_MAPPER[benchmark](1)))
        if parts:
            # Merge the slices keyed by benchmark index. Duplicates collapse and the
            # result stays in benchmark order; a set() here made the order
            # arbitrary, and that leaked into the returned list.
            selected = {}
            for part_num, num_of_parts in parts:
                part_size = math.ceil(len(benchmark_run) / num_of_parts)
                for index, space in benchmark_run[
                    part_num * part_size : (part_num + 1) * part_size
                ]:
                    selected[index] = space
            total_problems_run = [(index, selected[index]) for index in sorted(selected)]
        else:
            total_problems_run = benchmark_run

        if func_numbers:
            total_problems_run = [
                (index, space)
                for index, space in total_problems_run
                if space.func_id in func_numbers
            ]

        if func_dims:
            total_problems_run = [
                (index, space)
                for index, space in total_problems_run
                if space.dimension in func_dims
            ]

        missing = []
        for index, space in total_problems_run:
            problem = (
                space.suite,
                space.func_id,
                space.dimension,
                space.func_instance,
            )
            if not self.result_storage.is_run_exists(run, problem):
                missing.append((index, run, problem))
            elif include_unfinished_run and (
                not self.result_storage.run_result(run, problem)
                # A final budget more than 1000 short of BUDGET counts as unfinished.
                or self.result_storage.problem_final_value_for_run(run, problem)[0]
                < BUDGET - 1000
            ):
                missing.append((index, run, problem))
        return missing

    def command_line_for_missing_benchmark_run(
        self,
        run: Run,
        benchmark: Benchmarks,
        parts: List[Tuple[int, int]] = None,
        function_numbers: List[int] = None,
        function_dimensions: List[int] = None,
        stack_indexes: bool = True,
        include_unfinished_run: bool = False,
        cores: int = 4,
        budget: int = 50000,
    ):
        """Build the shell command that runs the spaces ``run`` is still missing.

        :param stack_indexes: if true, the missing indexes go through
            ``get_index_ranges`` (presumably merging contiguous indexes into
            parts). Otherwise each missing index becomes its own
            ``-p <index> <benchmark size>`` part.
        :param include_unfinished_run: forwarded to ``missing_parts_of_benchmark_run``.
        :param cores: value passed to ``-c`` (the previous hard-coded value was 4).
        :param budget: value passed to ``--budget`` (the previous hard-coded value
            was 50000).
        :return: ``["Finished"]`` when nothing is missing; otherwise a one-element
            list holding the command string.
        """
        benchmark_size = len(BENCHMARK_MAPPER[benchmark](1))
        missing_parts_of_run = self.missing_parts_of_benchmark_run(
            run,
            benchmark,
            include_unfinished_run,
            parts,
            function_numbers,
            function_dimensions,
        )

        if not missing_parts_of_run:
            return ["Finished"]
        missing_indexes = [index for index, *_ in missing_parts_of_run]
        if stack_indexes:
            missing_ranges = get_index_ranges(missing_indexes, benchmark_size)
        else:
            missing_ranges = [(index, benchmark_size) for index in missing_indexes]

        # Stable sort, largest second element first.
        missing_ranges = sorted(
            missing_ranges, key=operator.itemgetter(1), reverse=True
        )

        parts_to_run_command = " ".join(
            f"-p {chunk_index} {chunk_size}"
            for chunk_index, chunk_size in missing_ranges
        )
        alg, run_name = run
        return [
            f"python run_on_all_devices.py {alg.value} -c {cores}"
            f" -b {benchmark.value} -n {run_name} {parts_to_run_command} --budget {budget}"
        ]

    def _alg_progression(self, filtered_run_result, alg):
        """Summarize every consecutive segment of rows that belongs to ``alg``.

        :param filtered_run_result: iterable of ``(key, rows)`` pairs, such as
            ``itertools.groupby`` output. Rows are indexed as
            ``(budget, value, point, ...)``, as fed by ``_each_alg_checkpoint``.
        :param alg: key of the segments to keep.
        :return: one ``(key, budget spent, value decrease, distance between the
            first and last point)`` tuple per matching non-empty segment.
        """
        summaries = []
        for key, group in filtered_run_result:
            if key != alg:
                continue
            rows = list(group)
            if not rows:
                continue
            first_row = rows[0]
            if not first_row:
                continue
            last_row = rows[-1]
            if not last_row:
                continue
            summaries.append(
                (
                    key,
                    last_row[0] - first_row[0],
                    first_row[1] - last_row[1],
                    distance_between_points(first_row[2], last_row[2]),
                )
            )
        return summaries

    def _each_alg_checkpoint(self, run: Run, problem: ProblemSpace, alg):
        """Return the per-segment progression summaries of ``alg`` in ``run`` on ``problem``.

        The normalized steps of ``run`` are split into consecutive groups by
        their 4th field (the algorithm name), and the groups are passed to
        ``_alg_progression``.
        """
        grouped_steps = itertools.groupby(
            self._run_normalize_result(run, problem), key=operator.itemgetter(3)
        )
        return self._alg_progression(grouped_steps, alg)

    def average_alg_progression(self, run: Run):
        """Log, per sub-algorithm found in ``run``, its mean value decrease per segment.

        Only the spaces of ``run`` on ``Suites.COCO`` are considered. For every
        algorithm name that is a member name of ``Algorithms``, the i-th
        progression segment of each problem (from ``_each_alg_checkpoint``) is
        averaged across problems. The sequence of mean value decreases is then
        logged as a progression.
        """
        problems_for_run = (
            self.result_storage.common_spaces_on_suite_from_multiple_runs(
                [run], Suites.COCO
            )
        )
        algorithm_names = self.result_storage.all_algorithms_in_run(
            run, problems_for_run
        )  # TODO - not working in sql (api changed)

        for algorithm_name in algorithm_names:
            # Skip names that are not member names of the Algorithms enum.
            if algorithm_name not in Algorithms.__members__.keys():
                continue

            # Tuple i holds each problem's i-th segment summary for this algorithm;
            # zip_longest puts None where a problem has fewer segments.
            alg_run = list(
                itertools.zip_longest(
                    *[
                        self._each_alg_checkpoint(run, problem, algorithm_name)
                        for problem in problems_for_run
                    ]
                )
            )
            # Per segment index: (mean value decrease d[2], mean budget spent d[1])
            # over the problems that have that segment (None entries are dropped).
            mean_alg = [
                (
                    sum(relevant_results) / len(relevant_results),
                    sum(relevant_budget) / len(relevant_budget),
                )
                for checkpoints in alg_run
                if (relevant_results := [d[2] for d in checkpoints if d])
                if (relevant_budget := [d[1] for d in checkpoints if d])
            ]

            real_alg_name, run_name = run
            # NOTE(review): assumes the Algorithms enum values are the lowercased
            # member names — TODO confirm. Only the mean value decrease is logged.
            self.output_manager.log_alg_progression(
                (Algorithms(algorithm_name.lower()), run_name),
                [a[0] for a in mean_alg],
                graph_name=f"new alg compare {real_alg_name}",
            )

    def draw_counter_map_of_run(
        self,
        runs: List[Run],
        problems: List[ProblemSpace],
        plot_names: List[str],
        dims: List[int],
        x_lower_bounds: int = -5,
        x_upper_bounds: int = 5,
        y_lower_bounds: int = -5,
        y_upper_bounds: int = 5,
        step_size: int = None,
        num_of_points: int = None,
        legend_size: int = 40,
        plots_in_row: int = None,
        graph_name: str = "",
        grid_resolution: int = 1000,
    ):
        """Draw a filled contour of each problem and overlay the points each run visited.

        :param runs: runs to overlay; each gets one color from the "brg" colormap.
        :param problems: ``(suite, func_id, dim, instance)`` spaces, one subplot each.
        :param plot_names: legend label per run; only the first subplot adds labels.
        :param dims: indexes of the two run-point coordinates plotted as x and y.
        :param x_lower_bounds, x_upper_bounds, y_lower_bounds, y_upper_bounds:
            grid extent passed to ``create_grid_points``.
        :param step_size: if set, keep every ``step_size``-th point of a run.
        :param num_of_points: if set, keep at most this many points (after striding).
        :param legend_size: legend font size.
        :param plots_in_row: subplots per row; defaults to all problems in one row.
        :param graph_name: graph name forwarded to ``print_image``.
        :param grid_resolution: grid points per axis (previously fixed at 1000).
        """
        if not problems:
            self.logger.warning("No problems given, nothing to draw")
            return
        # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in 3.9.
        color_map = plt.get_cmap("brg", len(runs))
        colors = color_map(np.linspace(0, 1, len(runs)))
        plots_in_row = plots_in_row or len(problems)
        single_problem = len(problems) == 1

        num_of_rows = math.ceil(len(problems) / plots_in_row)
        # squeeze=False always returns a 2-D array of Axes. Without it a single
        # subplot came back as a bare Axes, which has no .flatten().
        fig, axes = plt.subplots(
            num_of_rows,
            plots_in_row,
            figsize=(plots_in_row * 7, num_of_rows * 5),
            sharex=True,
            squeeze=False,
        )
        axes = axes.flatten()

        for i, problem in enumerate(problems):
            ax = axes[i]
            suite, func_number, dim, instance = problem
            environment = SUITES_MAPPING[suite](func_number, dim, instance, None)
            points = create_grid_points(
                x_lower_bounds,
                x_upper_bounds,
                y_lower_bounds,
                y_upper_bounds,
                grid_resolution,
                dim,
                list(dims),
            )
            grid_values = environment(points.detach(), debug_mode=True)
            # NOTE(review): the contour axes come from grid columns 0 and 1, while
            # run points use ``dims``. Confirm that create_grid_points lays out the
            # grid so these agree when dims != [0, 1].
            x_grid = (
                points[:, 0].reshape(grid_resolution, grid_resolution).detach().cpu()
            )
            y_grid = (
                points[:, 1].reshape(grid_resolution, grid_resolution).detach().cpu()
            )
            ax.contourf(x_grid, y_grid, grid_values.reshape(x_grid.shape).cpu())

            for run, plot_name, color in zip(runs, plot_names, colors):
                self.logger.info(f"Contour for {run} with {color}")
                points_of_run = np.array(
                    points_from_run(self.result_storage.run_result(run, problem))
                )
                self.logger.info(f"Found {len(points_of_run)} in {run}")
                if step_size:
                    points_of_run = points_of_run[::step_size]
                if num_of_points:
                    points_of_run = points_of_run[:num_of_points]
                self.logger.info(f"Plotting for {len(points_of_run)} in {run}")
                if len(points_of_run) == 0:
                    # An empty array has no column axis to index into.
                    self.logger.warning(f"No points to plot for {run} on {problem}")
                    continue
                ax.plot(
                    points_of_run[:, dims[0]],
                    points_of_run[:, dims[1]],
                    ".",
                    label=plot_name if i == 0 else None,
                    linestyle="-",
                    color=color,
                    linewidth=4 if single_problem else 2,
                    markersize=20 if single_problem else 10,
                )
            if len(axes) > 1:
                # Tag each subplot with a letter: a, b, c, ...
                ax.set_xlabel(chr(ord("a") + i), fontsize=6)
        fig.legend(fontsize=legend_size)
        fig.tight_layout()
        problem = problems[0]
        plot_name = (
            f"convergence {problem[0]}-{problem[1]}-{problem[2]}-{problem[3]} on {dims} final"
            if single_problem
            else f"convergence compare on {len(problems)} final"
        )
        self.output_manager.print_image(
            runs[0],
            fig,
            0,
            graph_name,
            plot_name,
        )

    def is_point_distance(
        self,
        run1: Run,
        run2: Run,
        func_id: int = None,
        func_dim: int = None,
        max_distance: float = 0.5,
    ):
        """Warn about every shared problem where the final points of the runs are far apart.

        "Far apart" means the final points of ``run1`` and ``run2`` are more than
        ``max_distance`` apart.

        :param func_id: if given, only check problems whose second field equals it.
        :param func_dim: if given, only check problems whose third field equals it.
        :param max_distance: distance above which a warning is logged; the default
            keeps the previous hard-coded 0.5.
        """
        common_problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            [run1, run2]
        )
        # ``is not None`` so that 0 acts as a real filter value rather than "no filter".
        if func_id is not None:
            common_problems = [
                problem for problem in common_problems if problem[1] == func_id
            ]
        if func_dim is not None:
            common_problems = [
                problem for problem in common_problems if problem[2] == func_dim
            ]
        for problem in common_problems:
            final_point1 = self.result_storage.problem_final_value_for_run(
                run1, problem
            )[2]
            final_point2 = self.result_storage.problem_final_value_for_run(
                run2, problem
            )[2]
            points_distance = distance_between_points(final_point2, final_point1)
            if points_distance > max_distance:
                self.logger.warning(
                    f"{run1} is far from {run2} on {problem}. distance: {points_distance}, "
                    f"points: {final_point1}, {final_point2}"
                )

    def distance_between_runs(
        self,
        run1: Run,
        run2: Run,
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        graph_name: str = "",
    ):
        """Aggregate the per-step distance between ``run1`` and ``run2`` and log it.

        The aggregation is over all problems both runs share, and the result is
        logged as a progression of ``run1``. Problems whose distance curve is empty
        are skipped before aggregation, as the other per-problem aggregations in
        this class do.
        """
        self.logger.info(f"Distance between 2 runs {run1} and {run2}")
        common_problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            [run1, run2]
        )
        distance_curves = [
            curve
            for problem in common_problems
            if (curve := self.distance_between_run_problem(problem, run1, run2))
        ]
        distance_analysis = measure_statistic_result_over_multiple_runs(
            distance_curves, analysis
        )
        self.output_manager.log_alg_progression(
            run1,
            distance_analysis,
            graph_name=f"Distance between {run1} and {run2} {graph_name}",
        )

    def distance_between_run_problem(self, problem, run1, run2) -> BudgetRun:
        """Return ``(step, distance)`` between the points of ``run1`` and ``run2`` on ``problem``.

        Both runs are expanded with ``from_mapped_result_to_steps`` up to the larger
        budget either one used. If the expansions still differ in length, each
        shorter one is extended with its OWN last point, i.e. the point it stopped
        at. The previous ``zip_longest(fillvalue=...)`` filled it with the longer
        run's last point, which compared that run against itself. Returns ``[]``
        (with a warning) when either expansion is empty.
        """
        run_result1 = self.result_storage.run_result(run1, problem)
        run_result2 = self.result_storage.run_result(run2, problem)
        budget_run1 = [(budget, point) for budget, _value, point, _alg in run_result1]
        budget_run2 = [(budget, point) for budget, _value, point, _alg in run_result2]
        max_used_budget = max_budget_of_runs([budget_run1, budget_run2])
        steps1 = list(from_mapped_result_to_steps(budget_run1, max_used_budget))
        steps2 = list(from_mapped_result_to_steps(budget_run2, max_used_budget))
        if not steps1 or not steps2:
            self.logger.warning(
                f"Cannot compare {run1} and {run2} on {problem}: a run has no steps"
            )
            return []
        length = max(len(steps1), len(steps2))
        steps1 += [steps1[-1]] * (length - len(steps1))
        steps2 += [steps2[-1]] * (length - len(steps2))
        return [
            (budget, distance_between_points(point1, point2))
            for budget, (point1, point2) in enumerate(zip(steps1, steps2))
        ]

    def remove_run(self, run: Run):
        """Remove ``run`` from every space its common-space lookup returns.

        The lookup is ``common_spaces_on_suite_from_multiple_runs([run])``.
        """
        self.logger.info(f"Removing run {run}")
        storage = self.result_storage
        for problem in storage.common_spaces_on_suite_from_multiple_runs([run]):
            self.logger.info(f"Removing run {run} {problem}")
            storage.remove_run(run, problem)

    def list_runs(self) -> List[Run]:
        """Return every distinct run in the storage, in first-seen order.

        ``result_storage.list_runs()`` yields ``(run, problem)`` pairs. Each run is
        kept once, in first-seen order, so the result is deterministic; a ``set``
        made the order arbitrary.
        """
        self.logger.info("Listing all runs")
        stored_pairs = self.result_storage.list_runs()
        return list(dict.fromkeys(run for run, _problem in stored_pairs))

    def problems_far_from_best(
        self, run: Run, problems: List[ProblemSpace], eps: float
    ):
        """Return ``(problem, gap)`` for every problem where ``run`` ends more than ``eps`` from best.

        ``gap`` is ``(final value - stored min) / (stored max - stored min)``.
        A problem whose stored min equals its stored max has no range to scale by.
        Its gap is taken as 0 instead of raising ``ZeroDivisionError``; this
        matches ``_run_normalize_result``, which also special-cases that situation.
        A gap of exactly 0 is now compared against ``eps`` like any other value,
        rather than being dropped for being falsy.
        """
        far = []
        for problem in problems:
            min_value, max_value, *_ = self.result_storage.min_max_from_space(problem)
            final_value = self.result_storage.problem_final_value_for_run(run, problem)[1]
            if min_value == max_value:
                distance_from_best = 0.0
            else:
                distance_from_best = (final_value - min_value) / (max_value - min_value)
            if distance_from_best > eps:
                far.append((problem, distance_from_best))
        return far

    def _run_success(
        self, runs: List[Run], problems: List[ProblemSpace], min_distance: float = 1e-3
    ):
        """Return, per run, the share of ``problems`` not flagged as far from best.

        A problem counts as flagged when ``problems_far_from_best`` reports it with
        ``eps=min_distance``.
        """
        success_rates = []
        for run in runs:
            far_count = sum(
                1
                for _problem, _gap in self.problems_far_from_best(
                    run, problems, min_distance
                )
            )
            success_rates.append(1 - far_count / len(problems))
        return success_rates

    def print_bar_problems_solved(
        self, runs: List[Run], min_distance: float = 1e-3, plot_names: List[str] = None
    ):
        """Plot, as bars, the fraction of common problems each run solved.

        A problem counts as solved when the run's last normalized value is
        below ``min_distance``.

        Args:
            runs: Runs to compare.
            min_distance: Normalized-distance threshold for "solved".
            plot_names: Bar labels, one per run; defaults to ``str(run)``.
        """
        if plot_names is None:
            plot_names = [str(run) for run in runs]
        norm_results, _, _, _, common_problems = self._normalized_results(
            runs, only_last=True
        )
        num_problems = len(common_problems)
        # Axis 0 is runs, axis 1 is problems (last step only).
        probabilities = (norm_results < min_distance).sum(axis=1) / num_problems
        runs_success = dict(zip(plot_names, probabilities))
        # Binomial standard error of a success rate, sqrt(p(1-p)/n). A formula
        # like 1/sqrt(n*p) is a relative error and diverges when p == 0.
        std = np.sqrt(probabilities * (1 - probabilities) / num_problems)

        self.output_manager.print_bars(
            runs_success,
            std,
            f"Problems bar on distance {min_distance} with {len(common_problems)} problems",
        )

    def where_am_i_far_from_best(self, run: Run, eps: float = 1e-3):
        """Summarize, per function id, the problems ``run`` did not solve.

        Returns:
            Dict mapping the function id (``problem[1]``) to
            ``(count, mean_distance, [(problem, distance), ...])``.
        """
        problems = self.result_storage.common_spaces_on_suite_from_multiple_runs([run])
        unsolved_problems = self.problems_far_from_best(run, problems, eps)

        def func_id(entry):
            return entry[0][1]

        summary = {}
        for func_number, group in itertools.groupby(
            sorted(unsolved_problems, key=func_id), func_id
        ):
            problems_in_func = list(group)
            total_distance = sum(distance for _, distance in problems_in_func)
            summary[func_number] = (
                len(problems_in_func),
                total_distance / len(problems_in_func),
                problems_in_func,
            )
        return summary

    def where_am_i_worse(
        self, run1: Run, run2: Run, func_number: int = None, eps: float = 1e-5
    ):
        """Find problems on which ``run2`` ends worse than ``run1``, grouped by function.

        A problem qualifies when run1's normalized last value plus ``eps`` is
        still below run2's; lower normalized values are closer to the best.

        Args:
            run1: Reference run.
            run2: Run checked for regressions.
            func_number: Keep only this function id (``problem[1]``); ``None`` keeps all.
            eps: Margin the gap must exceed.

        Returns:
            Dict keyed by function id with
            ``(problems, total_gap, problem_with_largest_gap, largest_gap,
            problem_with_farthest_final_points, farthest_point_distance)``.
            Gaps are ``run2 - run1`` and therefore positive.
        """
        common_problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            [run1, run2]
        )
        problems_where_2_is_worse = []
        for problem in common_problems:
            if func_number is not None and func_number != problem[1]:
                continue
            last_step1 = self._normalize_last_step(run1, problem)
            if not last_step1:
                continue
            last_step2 = self._normalize_last_step(run2, problem)
            if not last_step2:
                continue
            if last_step1[1] + eps < last_step2[1]:
                problems_where_2_is_worse.append(
                    (
                        problem,
                        # Positive gap so max/argmax select the *largest* regression.
                        last_step2[1] - last_step1[1],
                        last_step1[2],
                        last_step2[2],
                    )
                )

        def func_id(entry):
            return entry[0][1]

        problem_by_id_where_2_is_worse = {}
        for key, group in itertools.groupby(
            sorted(problems_where_2_is_worse, key=func_id), key=func_id
        ):
            entries = list(group)
            problems = [entry[0] for entry in entries]
            gaps = [entry[1] for entry in entries]
            point_distances = [
                distance_between_points(entry[2], entry[3]) for entry in entries
            ]
            problem_by_id_where_2_is_worse[key] = (
                problems,
                sum(gaps),
                problems[np.argmax(gaps)],
                max(gaps),
                problems[np.argmax(point_distances)],
                max(point_distances),
            )
        return problem_by_id_where_2_is_worse

    def budget_use_for_runs(self, runs: List[Run]):
        """Plot the mean budget each run consumed over the runs' common problems.

        The budget is element ``[0]`` of ``problem_final_value_for_run``; bars
        are labeled ``"<alg>_<run_name>"``. Runs with no common problems are
        left out.
        """
        common_problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs
        )
        budget_use = {}
        for alg, run_name in runs:
            budgets = [
                self.result_storage.problem_final_value_for_run(
                    (alg, run_name), problem
                )[0]
                for problem in common_problems
            ]
            if not budgets:
                continue
            budget_use[f"{alg}_{run_name}"] = sum(budgets) / len(budgets)
        self.output_manager.print_bars(budget_use)

    def plot_difference_between_alg_best_points(
        self, alg1: Run, alg2: Run, problem: ProblemSpace
    ):
        """Plot the objective along the line through the final points of two runs.

        The line is centered on ``alg1``'s final point. Each unit step is 1/500
        of the vector to ``alg2``'s final point, so step 500 lands on ``alg2``'s
        point.
        """
        final_point_1 = torch.tensor(
            self.result_storage.problem_final_value_for_run(alg1, problem)[2]
        )
        final_point_2 = torch.tensor(
            self.result_storage.problem_final_value_for_run(alg2, problem)[2]
        )
        step = (final_point_2 - final_point_1) / 500
        suite, func_number, dim, instance = problem
        title = (
            f"alg {alg1[0].value}-{alg1[1]} vs alg {alg2[0].value}-{alg2[1]} "
            f"on {suite.value}-{func_number}-{dim}-{instance}"
        )
        environment = SUITES_MAPPING[suite](func_number, dim, instance)
        self.plot_vector_line(
            environment,
            final_point_1.tolist(),
            step,
            title,
            title,
        )

    def plot_vector_line(
        self,
        env: CallableSpace,
        mid_point: Point,
        direction: Tensor,
        plot_name: str,
        graph_name: str,
        num_of_positive_steps: int = 1000,
        num_of_negative_steps: int = 1000,
    ):
        """Evaluate ``env`` along a line and log the values as a progression.

        The evaluated points are ``mid_point + t * direction`` for integer
        ``t`` in ``[-num_of_negative_steps, num_of_positive_steps)``; ``t`` is
        the logged x value.

        Args:
            env: Objective; called once on the whole batch with ``debug_mode=True``.
            mid_point: Point at ``t = 0``.
            direction: Vector added per unit of ``t``.
            plot_name: Name of the logged series.
            graph_name: Graph the series is logged to.
            num_of_positive_steps: Number of steps with ``t >= 0`` (``t = 0`` included).
            num_of_negative_steps: Number of steps with ``t < 0``.
        """
        origin = torch.tensor(mid_point, device=env.device)
        # Callers (e.g. plot_difference_between_alg_best_points) build the
        # direction on the CPU; align it with the env's device before mixing.
        step = direction.to(env.device)
        ts = torch.arange(
            -num_of_negative_steps, num_of_positive_steps, device=env.device
        )
        # One broadcast instead of thousands of per-step tensors.
        points = origin + ts[:, None] * step
        values = env(points, debug_mode=True)
        results = list(zip(ts.tolist(), values.tolist()))
        self.output_manager.log_normalize_progression(results, plot_name, graph_name)

    def final_point_of_run(
        self, run: Run, problem: ProblemSpace
    ) -> Tuple[Point, float]:
        """Return ``(point, value)`` from the final stored record of ``run`` on ``problem``.

        The storage record keeps the point at index 2 and the value at index 1.
        """
        record = self.result_storage.problem_final_value_for_run(run, problem)
        point, value = record[2], record[1]
        return point, value

    def early_stopping_for_run(self, run: Run, problem: ProblemSpace):
        """Log the number of loss entries stored for each step of ``run`` on ``problem``.

        ``get_loss_data`` yields one sequence per step; its length is logged
        against the step index on the "early stopping" graph. Presumably fewer
        entries mean training stopped earlier (TODO confirm).
        """
        loss_data = self.result_storage.get_loss_data(run, problem)
        epochs_per_step = []
        for step_index, loss_epochs in enumerate(loss_data):
            epochs_per_step.append((step_index, len(loss_epochs)))
        self.output_manager.log_normalize_progression(
            epochs_per_step, f"{run} {problem}", "early stopping"
        )

    def contour_pca_space(
        self, run: Run, env: CallableSpace, num_of_sample_points: int = 1000
    ):
        """Draw a filled contour of ``env`` over a 2-D PCA projection of sampled points.

        ``num_of_sample_points ** 2`` points are sampled and evaluated. The
        coordinates plus the value form one feature matrix that is projected to
        two PCA components. Both components and the values are reshaped into a
        square and passed to ``contourf``.

        NOTE(review): ``contourf`` expects X/Y laid out on a structured grid. If
        ``sample_from_space`` samples at random, the reshaped PCA coordinates
        are not a mesh and ``tricontourf`` would be appropriate — confirm.
        """
        side = num_of_sample_points
        samples = env.sample_from_space(side**2, device=env.device)
        values = env(samples, debug_mode=True)
        # The objective value is appended as an extra feature column before PCA.
        features = torch.cat((samples, values.unsqueeze(1)), dim=1)
        projected = PCA(n_components=2).fit_transform(features.cpu().numpy())

        fig = plt.figure(figsize=[12.8, 9.6])
        ax = fig.add_axes([0, 0, 1, 1])
        ax.contourf(
            projected[:, 0].reshape(side, side),
            projected[:, 1].reshape(side, side),
            values.reshape(side, side).cpu(),
            cmap="Greys",
        )
        self.output_manager.print_image(run, fig, graph_name=f"PCA {env}", index=0)

    def print_metric(self, metric_name: str, graph_name: str):
        """Log the stored metric ``metric_name`` as a progression, then finish output.

        The metric table is converted to nested row lists before logging.
        """
        metric_table = self.result_storage.get_metric(metric_name)
        rows = metric_table.to_numpy().tolist()
        self.output_manager.log_normalize_progression(rows, metric_name, graph_name)
        self.output_manager.finish()

    def compare_metric_by_budget(
        self,
        metric_name: str,
        graph_name: str,
        manipulation: MetricManipulation = MetricManipulation.NONE,
    ):
        """Plot every stored metric matching ``metric_name`` after a manipulation.

        For each matching metric, the x values come from the table's second
        column (presumably the budget — TODO confirm). The y values are the
        output of ``METRIC_MANIPULATION[manipulation]`` applied to the table.
        """
        for metric in self.result_storage.get_metrics_types(metric_name):
            self.logger.info(f"Logging {metric}")
            table = self.result_storage.get_metric(metric)
            manipulated = METRIC_MANIPULATION[manipulation](table)
            x_values = table.iloc[:, 1].tolist()
            self.output_manager.log_normalize_progression(
                list(zip(x_values, manipulated)),
                metric,
                f"{graph_name}_{manipulation}",
            )
        self.output_manager.finish()

    def compare_best_by_dim(
        self,
        runs: List[Run],
        suite: Suites = None,
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        graph_name: str = None,
        plot_names: List[str] = None,
    ):
        """Plot, per dimension, the absolute gap between two runs' aggregated final values.

        For each problem dimension, the normalized final values of the runs'
        common problems are aggregated with ``analysis``. The logged "distance"
        series is ``|runs[0] - runs[1]|`` per dimension.

        Args:
            runs: At least two runs; only ``runs[0]`` and ``runs[1]`` are compared.
            suite: Optional suite filter for the common problems.
            analysis: Aggregation applied to each dimension's values.
            graph_name: Graph the "distance" series is logged to.
            plot_names: Currently unused; kept for interface compatibility.
        """
        common_problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, suite
        )
        dim_of = operator.itemgetter(2)
        # groupby only merges *adjacent* items: without sorting, later groups of
        # a dimension would overwrite earlier ones in the dict and drop problems.
        grouped_problems = {
            dim: list(problems)
            for dim, problems in itertools.groupby(
                sorted(common_problems, key=dim_of), key=dim_of
            )
        }

        statistical_analyzer = STATISTICAL_ANALYSIS_MAPPING.get(analysis)
        # Only the first two runs are compared, so only they are evaluated.
        values_for_runs = [
            [
                (
                    dim,
                    statistical_analyzer(
                        [
                            self.result_storage.normalize_final_value_for_run(
                                run, problem
                            )[1]
                            for problem in dim_problems
                        ]
                    ),
                )
                for dim, dim_problems in grouped_problems.items()
            ]
            for run in runs[:2]
        ]
        run1 = sorted(values_for_runs[0], key=lambda x: x[0])
        run2 = sorted(values_for_runs[1], key=lambda x: x[0])
        distance_between_run1_and_run2 = [
            (dim, abs(val1 - val2)) for (dim, val1), (_, val2) in zip(run1, run2)
        ]
        self.output_manager.log_normalize_progression(
            distance_between_run1_and_run2, graph_name=graph_name, plot_name="distance"
        )
        self.output_manager.finish()

    def _distance_between_points_for_run(
        self, run: Run, problem: ProblemSpace
    ) -> Tensor:
        """Return the distances between consecutive points visited by ``run`` on ``problem``."""
        visited = points_from_run(self.result_storage.run_result(run, problem))
        previous_points = torch.tensor(visited[:-1])
        next_points = torch.tensor(visited[1:])
        return distance_between_many_tensors(previous_points, next_points)

    def _analyze_distance_between_points_for_run(
        self,
        run: Run,
        problems: List[ProblemSpace],
        analysis: StatisticalAnalysisOptions,
    ) -> List[float]:
        """Aggregate ``run``'s step lengths across ``problems`` per step index.

        Each problem's step-length sequence is right-padded with zeros to the
        longest one. The sequences are stacked, and the analyzer receives the
        transposed matrix (one row per step index) as nested lists.

        NOTE(review): the zero padding participates in the aggregate for
        problems whose runs were shorter — confirm this is intended.
        """
        step_lengths = [
            self._distance_between_points_for_run(run, problem) for problem in problems
        ]
        longest = max(len(lengths) for lengths in step_lengths)
        padded = torch.stack(
            [
                torch.nn.functional.pad(lengths, (0, longest - len(lengths)))
                for lengths in step_lengths
            ]
        )
        analyzer = STATISTICAL_ANALYSIS_MAPPING.get(analysis)
        return analyzer(padded.T.tolist())

    def plot_steps_len(
        self,
        runs: List[Run],
        analysis: StatisticalAnalysisOptions,
        func_dim: int = None,
        func_id: int = None,
        graph_name: str = None,
        plot_names: List[str] = None,
    ):
        """Plot, per run, the aggregated length of each optimization step.

        Step lengths are computed over the runs' common problems, optionally
        filtered to one function id and/or one dimension. They are aggregated
        per step index with ``analysis`` and logged as one series per run.

        Args:
            runs: Runs to plot.
            analysis: Aggregation across problems at each step index.
            func_dim: Optional dimension filter (falsy means no filter).
            func_id: Optional function-id filter (falsy means no filter).
            graph_name: Graph the series are logged to.
            plot_names: Series labels, one per run; defaults to ``str(run)``.
        """
        if plot_names is None:
            plot_names = [str(run) for run in runs]
        problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs,
            func_ids=[func_id] if func_id else None,
            func_dims=[func_dim] if func_dim else None,
        )
        self.logger.info(f"{len(problems)} common problems between {len(runs)} runs")
        analyzed_steps = [
            self._analyze_distance_between_points_for_run(run, problems, analysis)
            for run in runs
        ]
        self.logger.info("Analyzed steps")
        for run, plot_name, step_distance in zip(runs, plot_names, analyzed_steps):
            self.logger.info(
                f"Plotting {run} named {plot_name} with {len(step_distance)} steps"
            )
            self.output_manager.log_alg_progression(
                run, step_distance, graph_name=graph_name, plot_name=plot_name
            )
        self.output_manager.finish()

    def _calc_run_problem_terrain(
        self,
        run: Run,
        problem: ProblemSpace,
        num_of_points_between_steps: int = 4,
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        device: int = None,
    ):
        """Measure how uneven the objective is along each step ``run`` took on ``problem``.

        Every segment between consecutive visited points is sampled at
        ``num_of_points_between_steps`` evenly spaced positions, endpoints
        included. The objective is built with a budget of ``math.inf``. For each
        segment, the ``analysis`` statistic of the sampled values minus their
        minimum is returned.
        """
        visited = torch.tensor(
            points_from_run(self.result_storage.run_result(run, problem)), device=device
        )
        starts = visited[:-1]
        segments = visited[1:] - starts
        fractions = torch.linspace(0, 1, num_of_points_between_steps, device=device)
        # Broadcast to (samples_per_segment, num_segments, dim).
        samples = starts[None, :, :] + fractions[:, None, None] * segments[None, :, :]

        env = SUITES_MAPPING[problem[0]](problem[1], problem[2], problem[3], math.inf)
        # Assuming env returns (samples_per_segment, num_segments), the transpose
        # gives one row per segment — TODO confirm.
        per_segment = env(samples).T
        analyzer = STATISTICAL_ANALYSIS_MAPPING.get(analysis)
        lowest = per_segment.min(dim=-1)[0]
        return analyzer(per_segment) - lowest

    def _calc_run_terrain(
        self,
        run: Run,
        problems: List[ProblemSpace],
        analysis: StatisticalAnalysisOptions,
        terrain_analysis: StatisticalAnalysisOptions,
        device: int = None,
    ):
        """Aggregate ``run``'s per-step terrain difficulty across ``problems``.

        Each problem's difficulty sequence (see ``_calc_run_problem_terrain``,
        computed with ``terrain_analysis``) is right-padded with zeros to the
        longest one. The sequences are stacked, and ``analysis`` is applied to
        the transposed matrix.
        """
        per_problem = []
        for problem in problems:
            per_problem.append(
                self._calc_run_problem_terrain(
                    run, problem, analysis=terrain_analysis, device=device
                )
            )
        longest = max(len(difficulty) for difficulty in per_problem)
        stacked = torch.stack(
            [
                torch.nn.functional.pad(difficulty, (0, longest - len(difficulty)))
                for difficulty in per_problem
            ]
        )
        analyzer = STATISTICAL_ANALYSIS_MAPPING.get(analysis)
        return analyzer(stacked.T)

    def _plot_steps_terrains(
        self,
        runs: List[Run],
        problems: List[ProblemSpace],
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        terrain_analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.QUANTILE_90,
        graph_name: str = None,
        plot_names: List[str] = None,
        device: int = None,
    ):
        """Plot, per run, the smoothed terrain difficulty met at each step.

        For each run, ``_calc_run_terrain`` aggregates the per-step terrain
        statistic (``terrain_analysis``) across ``problems`` with ``analysis``.
        The result is smoothed with ``rolling_average`` and logged on the
        ``"Steps terrain <graph_name>"`` graph.

        Args:
            runs: Runs to plot.
            problems: Problems to aggregate over.
            analysis: Aggregation across problems.
            terrain_analysis: Statistic over the samples of each step segment.
            graph_name: Suffix of the graph name.
            plot_names: Series labels, one per run; defaults to ``str(run)``.
            device: Torch device used for the computation.
        """
        if plot_names is None:
            plot_names = [str(run) for run in runs]
        self.logger.info(f"{len(problems)} common problems between {len(runs)} runs")
        run_terrain_difficulties = []
        for run in runs:
            self.logger.info(f"Plotting {run} on {len(problems)}")
            run_terrain_difficulties.append(
                self._calc_run_terrain(
                    run, problems, analysis, terrain_analysis, device
                )
            )
        for run, plot_name, difficulty in zip(
            runs, plot_names, run_terrain_difficulties
        ):
            self.output_manager.log_alg_progression(
                run,
                rolling_average(difficulty).cpu(),
                graph_name=f"Steps terrain {graph_name}",
                plot_name=plot_name,
            )
        self.output_manager.finish()

    def plot_steps_terrains(
        self,
        runs: List[Run],
        func_id: int = None,
        func_dim: int = None,
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        terrain_analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.QUANTILE_90,
        graph_name: str = None,
        plot_names: List[str] = None,
        device: int = None,
    ):
        """Plot step-terrain difficulty for ``runs`` over their common problems.

        The common problems may be narrowed to one function id and/or one
        dimension (a falsy value means no filter). Plotting is delegated to
        ``_plot_steps_terrains``.
        """
        problem_filter = {
            "func_ids": [func_id] if func_id else None,
            "func_dims": [func_dim] if func_dim else None,
        }
        common_problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, **problem_filter
        )
        self._plot_steps_terrains(
            runs,
            common_problems,
            analysis=analysis,
            terrain_analysis=terrain_analysis,
            graph_name=graph_name,
            plot_names=plot_names,
            device=device,
        )

    def plot_steps_terrains_by_func(
        self,
        runs: List[Run],
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        terrain_analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.QUANTILE_90,
        graph_name: str = None,
        plot_names: List[str] = None,
        device: int = None,
    ):
        """Plot step-terrain difficulty separately per function id (not implemented).

        Raises:
            NotImplementedError: Always. The body used to be a silent ``pass``,
                so callers got no plot and no indication of why.
        """
        raise NotImplementedError(
            "plot_steps_terrains_by_func is not implemented; use plot_steps_terrains "
            "with func_id to plot a single function."
        )

    def shrink_vs_move_plot(
        self,
        runs: List[Run],
        base_path: Path,
        bins: int,
        analysis: StatisticalAnalysisOptions,
        graph_name: str,
        plot_names: List[str] = None,
        device: int = None,
    ):
        """Plot how often each run's trust region shrank vs. moved over time.

        For each log file of a run, the lines containing ``"new center is"`` are
        split into ``bins`` consecutive buckets. Per bucket, the code counts the
        lines mentioning "shrinking trust region" / "moving trust region"
        (case-insensitive) and divides by the file's total number of such lines.
        The per-file curves are aggregated with ``analysis`` and logged as a
        "Shrink" and a "Moving" series per run.

        Args:
            runs: ``(alg, run_name)`` pairs whose logs are read.
            base_path: Root passed to ``create_file_log_path``.
            bins: Number of time buckets per log file.
            analysis: Aggregation across log files.
            graph_name: Graph the series are logged to.
            plot_names: Series label suffixes; defaults to ``"<alg.value>-<run_name>"``.
            device: Torch device for the stacked statistics.
        """
        if plot_names is None:
            plot_names = [f"{alg.value}-{run_name}" for alg, run_name in runs]
        self.logger.info(
            f"Shrink vs move plot for {runs} with {bins} bins analyze {analysis}"
        )
        analyzer = STATISTICAL_ANALYSIS_MAPPING.get(analysis)
        for (alg, run_name), plot_name in zip(runs, plot_names):
            self.logger.info(f"Calculation for {alg}-{run_name}")
            logs_run_path = create_file_log_path(base_path, alg.value, run_name)
            total_shrinking_stats = []
            total_moving_stats = []
            for log_file in logs_run_path.iterdir():
                # Sub-directories would make read_text raise IsADirectoryError.
                if not log_file.is_file():
                    continue
                relevant_logs = [
                    log
                    for log in log_file.read_text().splitlines()
                    if "new center is" in log
                ]
                if not relevant_logs:
                    continue
                total_shrinking_stats.append(
                    self._bucketed_keyword_fractions(
                        relevant_logs, bins, "shrinking trust region"
                    )
                )
                total_moving_stats.append(
                    self._bucketed_keyword_fractions(
                        relevant_logs, bins, "moving trust region"
                    )
                )
            if not total_shrinking_stats:
                # Nothing to aggregate; the analyzer cannot handle an empty matrix.
                self.logger.warning(
                    f"No trust-region logs found for {alg}-{run_name}, skipping"
                )
                continue
            shrinking_matrix = torch.tensor(total_shrinking_stats, device=device)
            moving_matrix = torch.tensor(total_moving_stats, device=device)
            shrinking_mean = analyzer(shrinking_matrix.T)
            moving_mean = analyzer(moving_matrix.T)
            self.output_manager.log_alg_progression(
                (alg, run_name),
                shrinking_mean.tolist(),
                graph_name=graph_name,
                plot_name=f"Shrink {plot_name}",
            )
            self.output_manager.log_alg_progression(
                (alg, run_name),
                moving_mean.tolist(),
                graph_name=graph_name,
                plot_name=f"Moving {plot_name}",
            )
        self.output_manager.finish()

    @staticmethod
    def _bucketed_keyword_fractions(
        lines: List[str], bins: int, keyword: str
    ) -> List[float]:
        """Split ``lines`` into ``bins`` consecutive buckets of ``ceil(len/bins)`` lines.

        For each bucket, return the number of lines containing ``keyword``
        (matched against the lowercased line) divided by ``len(lines)``.
        Buckets past the end of ``lines`` contribute 0.
        """
        total = len(lines)
        bucket_size = math.ceil(total / bins)
        return [
            sum(
                1
                for line in lines[i * bucket_size : (i + 1) * bucket_size]
                if keyword in line.lower()
            )
            / total
            for i in range(bins)
        ]

    def compare_dim_solving_bar(
        self,
        runs: List[Run],
        min_distance: float = 1e-3,
        bar_width: float = 0.2,
        graph_name: str = None,
        plot_names: List[str] = None,
    ):
        """Draw grouped bars of each run's success rate per problem dimension.

        A problem is solved when the run's last normalized value is below
        ``min_distance``. Error bars are the binomial standard error
        ``sqrt(p(1-p)/n)``, where ``n`` is the number of problems in that
        dimension.

        Args:
            runs: Runs to compare (one bar per run in each dimension group).
            min_distance: Normalized-distance threshold for "solved".
            bar_width: Width of each bar; groups are centered on their ticks.
            graph_name: Graph the image is saved under.
            plot_names: Legend labels, one per run; defaults to ``str(run)``.
        """
        if plot_names is None:
            plot_names = [str(run) for run in runs]
        norm_results, _, _, _, _, dim_split_index = self._normalized_results_by_dim(
            runs, only_last=True
        )
        # dim -> per-run success rate (axis 0 of norm_results is runs).
        probabilities = {
            dim: (norm_results[:, problem_ids] < min_distance).sum(axis=1)
            / len(problem_ids)
            for dim, problem_ids in dim_split_index.items()
        }
        problems_per_dim = np.array([len(ids) for ids in dim_split_index.values()])
        x = np.arange(len(probabilities))

        colormap = get_label_color_map(plot_names, self.colormap)
        fig, ax = plt.subplots(figsize=(10, 4))
        for i in range(len(runs)):
            plot_name = plot_names[i]
            run_probs = np.array([solved[i] for solved in probabilities.values()])
            # Finite even when a run solves nothing (p == 0), unlike 1/(n*p).
            std = np.sqrt(run_probs * (1 - run_probs) / problems_per_dim)
            ax.bar(
                x + i * bar_width,
                run_probs,
                width=bar_width,
                label=plot_name,
                yerr=std,
                error_kw={"elinewidth": 10, "capthick": 10},
                color=colormap[plot_name],
            )

        ax.set_xlabel("Problem Dimension", fontsize=14)
        ax.set_ylabel("Success Rate", fontsize=14)
        ax.tick_params(which="major", labelsize=14)
        ax.set_xticks(x + bar_width * (len(runs) - 1) / 2)
        ax.set_xticklabels(list(probabilities.keys()))
        ax.set_ylim(0.5, 1)
        # Matplotlib rejects ncol=0, which len(runs) // 2 yields for a single run.
        ax.legend(loc="lower right", ncol=max(1, len(runs) // 2))

        fig.tight_layout()
        self.output_manager.print_image(runs[0], fig, 0, graph_name, "solved_by_dim")

    def runs_distribution(
        self,
        runs: List[Run],
        bins: int = 20,
        graph_name: str = None,
        plot_names: List[str] = None,
    ):
        """Plot, per run, a histogram of its normalized final values over the common problems."""
        self.logger.info(f"Runs distribution for {runs} with {bins} bins")
        problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(runs)
        self.logger.info(f"Found {len(problems)} common problems")

        per_run_errors = []
        for run in runs:
            per_run_errors.append(
                [self._normalize_last_step(run, problem)[1] for problem in problems]
            )
        self._plot_hist_for_runs(
            runs,
            per_run_errors,
            bins,
            "Distribution of Error from Best Value",
            "Error from best value",
            graph_name,
            plot_names,
        )

    def _distributions_of_runs(self, runs: List[Run], min_diff: float):
        """Summarize each run's normalized final values over the runs' common problems.

        Returns:
            ``(total_values, normal_distributions)``. ``total_values[i]`` is the
            list of ``_normalize_last_step`` values of ``runs[i]``.
            ``normal_distributions[i]`` is ``(mean, std, median, solved_fraction)``,
            where a problem counts as solved when its value is below ``min_diff``.
        """
        self.logger.info(f"Creating total_values for {runs}")
        problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(runs)
        self.logger.info(f"Found {len(problems)} common problems")

        normal_distributions = []
        total_values = []
        for run in runs:
            self.logger.info(f"Calculating for {run}")
            normalized_values = [
                self._normalize_last_step(run, problem)[1] for problem in problems
            ]
            total_values.append(normalized_values)
            values = np.array(normalized_values)
            # These values are already normalized distances from the best value,
            # so threshold them directly (as print_bar_problems_solved does).
            # Re-normalizing them with each space's raw min/max mixed two scales.
            solved_problems = (values < min_diff).sum() / len(problems)
            normal_distributions.append(
                (
                    values.mean(),
                    values.std(),
                    np.median(values),
                    solved_problems,
                )
            )
        return total_values, normal_distributions

    def _kl_divergence_heatmap(
        self,
        runs: List[Run],
        values: List[List[float]],
        bins: int,
        plot_names: List[str],
        graph_name: str,
    ):
        """Plot a heatmap of pairwise Jensen-Shannon divergences between distributions.

        All entries of ``values`` are histogrammed into ``bins`` bins over one
        shared set of edges. The histograms are normalized, smoothed with
        ``1 / (bins * 5)`` per bin and renormalized. Entry ``(i, j)`` is
        ``0.5 * (KL(p_i || m) + KL(p_j || m))`` with ``m = (p_i + p_j) / 2``
        (natural log), i.e. the Jensen-Shannon divergence. The diagonal is 0.

        Args:
            runs: Runs the distributions belong to; ``runs[0]`` owns the saved image.
            values: One sample list per run.
            bins: Number of histogram bins.
            plot_names: Axis labels, one per run; defaults to ``str(run)``.
            graph_name: Graph the image is saved under.
        """
        if plot_names is None:
            plot_names = [str(run) for run in runs]
        samples = [np.asarray(dist, dtype=float).ravel() for dist in values]
        # Shared edges: per-distribution edges give the same bin index different
        # value ranges, which makes the histograms incomparable.
        bin_edges = np.histogram_bin_edges(np.concatenate(samples), bins=bins)
        histogram = np.array(
            [np.histogram(sample, bins=bin_edges)[0] for sample in samples],
            dtype=float,
        )
        histogram = histogram / histogram.sum(axis=1)[:, None]
        # Additive smoothing keeps every bin non-zero, then renormalize.
        histogram = histogram + 1 / (bins * 5)
        histogram = histogram / histogram.sum(axis=1)[:, None]

        n = len(histogram)
        divergence = np.zeros((n, n))
        # JS divergence is symmetric: compute each pair once and mirror it.
        for i, j in itertools.combinations(range(n), 2):
            p, q = histogram[i], histogram[j]
            m = 0.5 * (p + q)
            divergence[i, j] = divergence[j, i] = 0.5 * (
                entropy(p, m) + entropy(q, m)
            )

        labels = list(plot_names)
        fig = Figure(figsize=(8, 6))
        FigureCanvas(fig)  # attach a canvas to the stand-alone Figure
        ax = fig.add_subplot(111)
        sns.heatmap(
            divergence,
            annot=True,
            fmt=".4f",
            cmap="viridis",
            cbar_kws={"label": "Jensen-Shannon Divergence"},
            ax=ax,
            xticklabels=labels,
            yticklabels=labels,
        )
        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
        ax.set_yticklabels(labels, rotation=0, fontsize=8)
        fig.subplots_adjust(left=0.2, right=0.95, top=0.95, bottom=0.3)

        self.output_manager.print_image(runs[0], fig, 0, graph_name, "kl_divergence")

    def distributions_description_of_runs(
        self,
        runs: List[Run],
        min_diff: float = 1e-3,
        bins: int = 20,
        plot_names: List[str] = None,
        graph_name: str = None,
    ):
        """Describe each run's final-value distribution and plot their pairwise divergences.

        Args:
            runs: Runs to describe.
            min_diff: Normalized threshold for counting a problem as solved.
            bins: Histogram bins for the divergence heatmap.
            plot_names: Labels, one per run; defaults to ``str(run)``.
            graph_name: Graph the heatmap is saved under.

        Returns:
            One summary line per run with mean, std, median and solved fraction.
        """
        if plot_names is None:
            plot_names = [str(run) for run in runs]
        values, distributions_descriptions = self._distributions_of_runs(runs, min_diff)
        self._kl_divergence_heatmap(runs, values, bins, plot_names, graph_name)
        return [
            f"{plot_name} mean: {mean}, std: {std}, median: {median}, problems_solved: {solved_problems}"
            for (mean, std, median, solved_problems), plot_name in zip(
                distributions_descriptions, plot_names
            )
        ]

    def _budget_progression_distributions(
        self,
        runs: List[Run],
        problems: List[ProblemSpace],
        finish_percentile: float = 0.01,
    ):
        """Return per run the budget needed to reach ``finish_percentile`` on each problem.

        The full stored results of every run/problem pair are loaded and passed
        to ``how_much_budget_to_reach_percentile``.
        """
        results = []
        for run in runs:
            results.append(
                [self.result_storage.run_result(run, problem) for problem in problems]
            )
        return how_much_budget_to_reach_percentile(results, finish_percentile)

    def _plot_hist_for_runs(
        self,
        runs: List[Run],
        distributions: List[List[float]],
        bins: int,
        title: str,
        x_label: str,
        graph_name: str,
        plot_names: List[str],
    ):
        """Plot one histogram per run via ``_plot_hist_for_run_values``.

        Args:
            runs: Runs, aligned with ``distributions``.
            distributions: One sample list per run.
            bins: Number of histogram bins.
            title: Forwarded title.
            x_label: X-axis label.
            graph_name: Graph the images are saved under.
            plot_names: Labels, one per run; ``None`` falls back to ``str(run)``.
        """
        if plot_names is None:
            plot_names = [str(run) for run in runs]
        for run, distribution, plot_name in zip(runs, distributions, plot_names):
            self._plot_hist_for_run_values(
                run, distribution, bins, plot_name, title, x_label, graph_name
            )

    def _plot_hist_for_run_values(
        self,
        run: Run,
        values: List[float],
        bins: int,
        plot_name: str,
        title: str,
        x_label: str,
        graph_name: str,
    ):
        """Plot a histogram of ``values`` for a single run and hand it to the output manager.

        ``title`` is accepted but not drawn on the axes.
        """
        self.logger.info(f"Calculating for {run} - {plot_name}")
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.hist(values, bins=bins, label=plot_name)
        ax.set_xlabel(x_label)
        ax.set_ylabel("Value")
        ax.legend(loc="upper right")
        self.output_manager.print_image(
            run, fig, 0, graph_name=graph_name, plot_name=plot_name
        )

    def budget_progression_distributions(
        self,
        runs: List[Run],
        bins: int = 20,
        finish_percentile: float = 0.01,
        graph_name: str = None,
        plot_names: List[str] = None,
    ):
        """Plot and describe how much budget each run needs to reach ``finish_percentile``.

        A histogram per run and a pairwise divergence heatmap are plotted.

        Args:
            runs: Runs to compare over their common problems.
            bins: Histogram bins.
            finish_percentile: Target passed to ``how_much_budget_to_reach_percentile``.
            graph_name: Graph the images are saved under.
            plot_names: Labels, one per run; defaults to ``str(run)``.

        Returns:
            One summary line (mean, std, median) per run.
        """
        if plot_names is None:
            plot_names = [str(run) for run in runs]
        self.logger.info(f"Epoch progression distributions for {runs} with {bins} bins")
        problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(runs)
        self.logger.info(f"Found {len(problems)} common problems")

        finish_time_distributions = self._budget_progression_distributions(
            runs, problems, finish_percentile
        )

        self._plot_hist_for_runs(
            runs,
            finish_time_distributions,
            bins,
            f"Epochs to reach {finish_percentile * 100}% of best value",
            "Epochs",
            graph_name,
            plot_names,
        )
        self._kl_divergence_heatmap(
            runs, finish_time_distributions, bins, plot_names, graph_name
        )

        return [
            f"{plot_name} mean: {distribution.mean()}, std: {distribution.std()}, median: {np.median(distribution)}"
            for distribution, plot_name in zip(finish_time_distributions, plot_names)
        ]

    def _results_for_runs(
        self,
        runs: List[Run],
        suite: Suites = None,
        func_ids: List[int] = None,
        func_dims: List[int] = None,
        only_last: bool = False,
    ):
        """Load results of every run on the runs' (optionally filtered) common problems.

        With ``only_last`` each entry is a one-element list holding the final
        record; otherwise it is the full stored run result.

        Returns:
            ``(results, common_problems)``, where ``results[i][j]`` belongs to
            ``runs[i]`` on ``common_problems[j]``.
        """
        common_problems = self.result_storage.common_spaces_on_suite_from_multiple_runs(
            runs, suite, func_ids, func_dims
        )

        def fetch(run, problem):
            if only_last:
                return [self.result_storage.problem_final_value_for_run(run, problem)]
            return self.result_storage.run_result(run, problem)

        results = []
        for run in runs:
            self.logger.info(f"Retrieving results for {run}")
            results.append(
                [fetch(run, problem) for problem in tqdm.tqdm(common_problems)]
            )
        return results, common_problems

    def _normalized_results(
        self,
        runs: List[Run],
        suite: Suites = None,
        func_ids: List[int] = None,
        func_dims: List[int] = None,
        only_last: bool = False,
    ):
        """Load results and min-max normalize them with each problem's stored bounds.

        Returns:
            ``(norm_results, results, min_values, max_values, common_problems)``.
            ``norm_results`` has runs on axis 0 and problems on axis 1. With
            ``only_last`` only the final entry of the trailing axis is kept.
        """
        results, common_problems = self._results_for_runs(
            runs, suite, func_ids, func_dims, only_last
        )
        # One storage lookup per problem instead of one per bound.
        bounds = [
            self.result_storage.min_max_from_space(problem)
            for problem in common_problems
        ]
        min_values = np.array([bound[0] for bound in bounds])
        max_values = np.array([bound[1] for bound in bounds])
        self.logger.info("Normalizing results")
        padded_results = np_array_from_results(results)
        # Broadcast the per-problem bounds along axis 1 instead of tiling them
        # into arrays as large as the results.
        lower = min_values.reshape(1, -1, 1)
        upper = max_values.reshape(1, -1, 1)
        norm_results = (padded_results - lower) / (upper - lower)
        return (
            norm_results[..., -1] if only_last else norm_results,
            results,
            min_values,
            max_values,
            common_problems,
        )

    def _normalized_results_by_dim(
        self,
        runs: List[Run],
        func_ids: List[int] = None,
        func_dims: List[int] = None,
        only_last: bool = False,
    ):
        """Like ``_normalized_results`` but also groups problem indices by dimension.

        Args:
            runs: Runs whose results are loaded and normalized.
            func_ids: Function-id filter forwarded to ``_normalized_results``.
            func_dims: Dimension filter forwarded to ``_normalized_results``.
                When non-empty, it also fixes the keys (and their order) of
                the returned split. Otherwise the keys are the sorted distinct
                dimensions of the common problems.
            only_last: Forwarded to ``_normalized_results``.

        Returns:
            The five values of ``_normalized_results`` followed by a dict that
            maps each dimension to the ascending positions in
            ``common_problems`` of the problems with that dimension.
        """
        norm_results, results, min_values, max_values, common_problems = (
            self._normalized_results(
                runs, func_ids=func_ids, func_dims=func_dims, only_last=only_last
            )
        )
        if func_dims:
            dims = func_dims
        else:
            # The third field of each problem tuple is its dimension.
            dims = sorted({dim for _, _, dim, _ in common_problems})

        # Pre-create one bucket per requested dimension, then fill them in a
        # single pass so indices stay in ascending order.
        dim_index_split = {dim: [] for dim in dims}
        for position, problem in enumerate(common_problems):
            bucket = dim_index_split.get(problem[2])
            if bucket is not None:
                bucket.append(position)

        return (
            norm_results,
            results,
            min_values,
            max_values,
            common_problems,
            dim_index_split,
        )

    def compare_runs_on_axis(
        self,
        runs: List[Run],
        axis: List[Tuple[AxisAnalysis, AxisAnalysis]],
        plot_names: List[str],
        graph_name: str,
        plots_in_rows: int = None,
        **kwargs,
    ):
        """Scatter-plot every run against pairs of analysis axes.

        For each ``(ax1, ax2)`` pair in ``axis`` the matching ``AXIS_ANALYSIS``
        entries are evaluated on the normalized results of ``runs``. One
        labelled point per run is drawn, with x from ``ax1`` and y from ``ax2``.

        Args:
            runs: Runs to compare. ``runs[0]`` is the value passed to
                ``output_manager.print_image``.
            axis: Pairs of analyses to draw on the x and y axis respectively.
            plot_names: One label per run, in the same order as ``runs``.
            graph_name: Name forwarded to ``output_manager.print_image``.
            plots_in_rows: When set, every pair is drawn into one shared
                figure with this many subplots per row and a single image is
                printed. Otherwise each pair gets its own figure and image.
            **kwargs: Extra arguments forwarded to each analysis' ``analyze``
                and ``name`` calls.
        """
        if not axis:
            # Nothing to draw; plt.subplots(0, ...) would raise in grid mode.
            self.logger.warning(f"No axis pairs given for {graph_name}, skipping")
            return

        self.logger.info(f"Comparing {runs} on {graph_name}")
        (
            norm_results,
            results,
            min_values,
            max_values,
            _common_problems,
        ) = self._normalized_results(runs)
        self.logger.info("Retrieved results")
        # Every analyze/name call receives the same data, so build it once.
        analysis_kwargs = dict(
            results=results,
            norm_results=norm_results,
            min_values=min_values,
            max_values=max_values,
            **kwargs,
        )

        if plots_in_rows:
            num_of_rows = math.ceil(len(axis) / plots_in_rows)
            # squeeze=False always returns a 2-D array of Axes. Without it a
            # 1x1 grid yields a bare Axes object, which has no flatten().
            fig, axs = plt.subplots(
                num_of_rows,
                plots_in_rows,
                figsize=(plots_in_rows * 10, num_of_rows * 5),
                squeeze=False,
            )
            axs = axs.flatten()
            # The last row can have more cells than remaining axis pairs;
            # hide them instead of leaving empty framed plots.
            for unused_ax in axs[len(axis):]:
                unused_ax.set_visible(False)

        for j, (ax1, ax2) in enumerate(axis):
            axis1_data = AXIS_ANALYSIS[ax1]
            first_axis_value = axis1_data.analyze(**analysis_kwargs)
            self.logger.info(f"Analyzed first axis {ax1}")
            axis2_data = AXIS_ANALYSIS[ax2]
            second_axis_value = axis2_data.analyze(**analysis_kwargs)
            self.logger.info(f"Analyzed second axis {ax2}")

            df = pd.DataFrame(
                {"x": first_axis_value, "y": second_axis_value, "label": plot_names}
            )
            sns.set(style="darkgrid")
            if not plots_in_rows:
                fig, ax = plt.subplots(figsize=(10, 4))
            else:
                ax = axs[j]
            palette = sns.color_palette("hsv", len(df))
            sns.scatterplot(
                x="x",
                y="y",
                data=df,
                hue="label",
                palette=palette,
                s=200,
                ax=ax,
                legend=False,
            )
            # Write each run's label next to its point.
            for x_value, y_value, label in zip(df["x"], df["y"], df["label"]):
                ax.text(x_value, y_value, label, fontsize=25, va="center", ha="left")

            ax.set_xlabel(axis1_data.name(**analysis_kwargs), fontsize=24)
            ax.set_ylabel(axis2_data.name(**analysis_kwargs), fontsize=24)

            # Pad both limits by 10% of the data range on each side.
            x_min, x_max = min(first_axis_value), max(first_axis_value)
            y_min, y_max = min(second_axis_value), max(second_axis_value)
            x_padding = (x_max - x_min) * 0.1
            y_padding = (y_max - y_min) * 0.1
            ax.set_xlim(x_min - x_padding, x_max + x_padding)
            ax.set_ylim(y_min - y_padding, y_max + y_padding)

            if not plots_in_rows:
                fig.tight_layout()
                self.logger.info("Plotted, Printing image...")
                self.output_manager.print_image(
                    runs[0],
                    fig,
                    0,
                    graph_name,
                    f"compare_runs_on_axis_{ax1.value}_{ax2.value}",
                )
        if plots_in_rows:
            self.logger.info("Plotted, Printing image...")
            fig.tight_layout()
            self.output_manager.print_image(
                runs[0], fig, 0, graph_name, "compare_multiple_axis"
            )

    def compare_steps_per_budget(
        self, runs: List[Run], graph_name: str, plot_names: List[str]
    ):
        """Plot a per-run histogram of recorded steps per unit of budget.

        For every problem of every run the ratio is the number of recorded
        entries divided by the first field of the last entry
        (``problem_run[-1][0]``, presumably the budget consumed - confirm
        against the result schema).

        Args:
            runs: Runs to compare.
            graph_name: Name forwarded to ``_plot_hist_for_runs``.
            plot_names: One label per run, forwarded to ``_plot_hist_for_runs``.
        """
        results, _ = self._results_for_runs(runs)

        steps_per_budget = []
        for run_results in results:
            ratios = []
            for problem_run in run_results:
                final_budget = problem_run[-1][0]
                ratios.append(len(problem_run) / final_budget)
            steps_per_budget.append(ratios)

        self._plot_hist_for_runs(
            runs,
            steps_per_budget,
            20,
            "Steps per budget Distribution",
            "Steps per budget",
            graph_name,
            plot_names,
        )

    def compare_steps_per_budget_overtime(
        self, runs: List[Run], graph_name: str, plot_names: List[str]
    ):
        """Plot, per dimension, value change per distance moved over the steps.

        For every pair of consecutive steps, the change of the normalized value
        (previous minus next, so positive means improvement) is divided by the
        Euclidean distance between the two recorded points. The distance is
        clipped to at least 1e-6. The ratio is averaged over all problems of a
        dimension and drawn as one log-scale line per run, with one subplot
        per dimension.

        Args:
            runs: Runs to compare. ``runs[0]`` is passed to
                ``output_manager.print_image``.
            graph_name: Name forwarded to ``output_manager.print_image``.
            plot_names: One legend label per run, in ``runs`` order.
        """
        (
            norm_results,
            results,
            _min_values,
            _max_values,
            _common_problems,
            dim_split_index,
        ) = self._normalized_results_by_dim(runs)
        # Assumes norm_results is (runs, problems, steps) - TODO confirm.
        steps_size = norm_results[..., :-1] - norm_results[..., 1:]
        # Pick each dimension's problems by position. The index lists are
        # ascending, so the order matches a filtered enumerate, but this
        # avoids an O(problems) list-membership test per element.
        results_by_dim = {
            dim: [[res[i] for i in indices] for res in results]
            for dim, indices in dim_split_index.items()
        }
        # NOTE(review): `2` presumably selects the point field of each step
        # entry - confirm against np_array_from_res_split_by_dim.
        points_by_dim = np_array_from_res_split_by_dim(results_by_dim, 2)

        change_magnitude_by_dim = {}
        for dim, indices in dim_split_index.items():
            points = points_by_dim[dim]
            # Euclidean distance between consecutive points along the step axis.
            delta = np.linalg.norm(points[..., 1:, :] - points[..., :-1, :], axis=-1)
            # Clip the distance to avoid dividing by zero when a step does not move.
            ratio = steps_size[:, indices] / delta.clip(min=1e-6)
            change_magnitude_by_dim[dim] = ratio.mean(axis=1)

        num_dims = len(change_magnitude_by_dim)
        # squeeze=False always yields a (num_dims, 1) array, even for one dim.
        fig, axes = plt.subplots(num_dims, figsize=(10, num_dims * 5), squeeze=False)

        for ax, (dim, data) in zip(axes[:, 0], change_magnitude_by_dim.items()):
            for run_idx, run_curve in enumerate(data):
                ax.plot(run_curve, label=plot_names[run_idx])
            ax.set_title(f"Dimension {dim}")
            # NOTE(review): the x axis is the step index, not the budget itself.
            ax.set_xlabel("Budget")
            ax.set_ylabel("Performance (log scale)")
            ax.set_yscale("log")
            ax.legend()
        fig.tight_layout()
        self.output_manager.print_image(
            runs[0], fig, 0, graph_name, "steps_per_budget_overtime"
        )

    def compare_by_dim_in_canvas(
        self,
        runs: List[Run],
        analysis: StatisticalAnalysisOptions = StatisticalAnalysisOptions.MEAN,
        dims: List[int] = None,
        graph_name: str = None,
        plot_names: List[str] = None,
    ):
        """Draw one subplot per dimension comparing a statistic across runs.

        The analyzer registered for ``analysis`` in
        ``STATISTICAL_ANALYSIS_MAPPING`` is applied to the normalized results
        (with their last two axes swapped). Each run then gets one line per
        dimension subplot, on a log-scale y axis.

        Args:
            runs: Runs to compare. ``runs[0]`` is passed to
                ``output_manager.print_image``.
            analysis: Which statistic to compute.
            dims: Optional dimension filter forwarded as ``func_dims``.
            graph_name: Name forwarded to ``output_manager.print_image``.
            plot_names: One legend label per run. Defaults to ``str(run)``.

        Raises:
            ValueError: If ``analysis`` has no registered analyzer.
        """
        (
            norm_results,
            _results,
            _min_values,
            _max_values,
            _common_problems,
            dim_split_index,
        ) = self._normalized_results_by_dim(runs, func_dims=dims)
        self.logger.info("Retrieved results")
        if plot_names is None:
            # Without labels the colour map and legend below cannot be built.
            plot_names = [str(run) for run in runs]

        analyzer = STATISTICAL_ANALYSIS_MAPPING.get(analysis)
        if analyzer is None:
            raise ValueError(f"Unsupported statistical analysis: {analysis}")
        # NOTE(review): assumes norm_results is (runs, problems, steps), so the
        # analyzer receives (runs, steps, problems); its output is indexed by
        # problem position on axis 1 below - confirm the reduction axis.
        analyzed_results = analyzer(norm_results.transpose(0, 2, 1))
        self.logger.info("Analyzed results")

        dims = list(dim_split_index)
        n = len(dims)
        colormap = get_label_color_map(plot_names, self.colormap)
        # squeeze=False keeps a (1, n) array so a single dimension still
        # supports indexing.
        fig, axes = plt.subplots(
            ncols=n, figsize=(n * 5, 5), sharex=True, squeeze=False
        )
        for i, (dim, ax) in enumerate(zip(dims, axes[0])):
            results_in_dim = analyzed_results[:, dim_split_index[dim]]
            for sub_list, plot_name in zip(results_in_dim, plot_names):
                ax.plot(
                    sub_list,
                    color=colormap[plot_name],
                    # Label only the first subplot so fig.legend lists each run once.
                    label=plot_name if i == 0 else None,
                )
            ax.set_xlabel(f"Dimension {dim}")
            ax.set_yscale("log")
        fig.legend(fontsize=20)
        fig.tight_layout()
        self.output_manager.print_image(runs[0], fig, 0, graph_name, "solved_by_dim")

    def plot_run_convergence_by_dim(
        self, runs: List[Run], graph_name: str, plot_names: list[str]
    ):
        """Plot each run's final normalized value as a function of dimension.

        For every dimension, the final normalized values of that dimension's
        problems are summarized per run with the mean and the 25th/75th
        percentiles. These are drawn by ``create_plot_with_statistics`` with
        the dimensions on the x axis and a log-scale y axis.

        Args:
            runs: Runs to compare. ``runs[0]`` is passed to
                ``output_manager.print_image``.
            graph_name: Name forwarded to ``output_manager.print_image``.
            plot_names: One legend label per run, in ``runs`` order.
        """
        (
            norm_results,
            _results,
            _min_values,
            _max_values,
            _common_problems,
            dim_split_index,
        ) = self._normalized_results_by_dim(runs, only_last=True)
        self.logger.info("Retrieved results")
        # With only_last=True the last axis is dropped, so norm_results is
        # presumably (runs, problems); each slice is (runs, problems_in_dim).
        values_per_dim = [norm_results[:, indices] for indices in dim_split_index.values()]
        x = np.array(list(dim_split_index.keys()))
        # Stack along axis 1: rows are runs, columns are dimensions.
        mean_results = np.stack([values.mean(axis=1) for values in values_per_dim], axis=1)
        percentile_25 = np.stack(
            [np.percentile(values, 25, axis=1) for values in values_per_dim], axis=1
        )
        percentile_75 = np.stack(
            [np.percentile(values, 75, axis=1) for values in values_per_dim], axis=1
        )

        fig, ax = create_plot_with_statistics(
            x, mean_results, percentile_25, percentile_75, plot_names, self.colormap
        )
        # The x values are the problem dimensions, not budget.
        ax.set_xlabel("Dimension")
        ax.set_ylabel("Value")
        ax.grid(True)
        ax.set_yscale("log")

        # A single figure-level legend (an additional ax.legend() duplicated it).
        fig.legend(fontsize=20)
        fig.tight_layout()
        self.output_manager.print_image(
            runs[0], fig, 0, graph_name, "convergence_by_dim"
        )
