import math

from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from compute_result.result_store.base import ResultStore
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable
from handlers.utils import problem_space_from_alg
from utils.algorithms_data import Algorithms


class SaveMinMaxHandler(AlgorithmCallbackHandler):
    """Track the lowest and highest objective values observed for a problem space.

    On algorithm start the running extremes are seeded from
    ``result_store.min_max_from_space``. Every callback then evaluates the
    algorithm's current point and updates the extremes. On algorithm end they are
    written back with ``result_store.update_min`` / ``update_max``.
    """

    def __init__(self, result_store: ResultStore):
        # Initialise base-handler state, as the sibling handlers in this module do.
        super().__init__()
        self.result_store = result_store
        # Start with "nothing seen yet" so the first evaluated value wins both comparisons.
        self.min_point_value = math.inf
        self.max_point_value = -math.inf
        self.min_point = []
        self.max_point = []

    def current_point(self, alg: ConvergenceDrawable, *args, **kwargs):
        """Return ``alg.curr_point_to_draw`` after calling ``.detach()`` on it.

        NOTE(review): presumably a torch tensor; ``detach`` drops autograd tracking.
        """
        return alg.curr_point_to_draw.detach()

    def update_best_worst(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Evaluate the current point and update the running min/max if it is a new extreme."""
        point = self.current_point(alg, *args, **kwargs)

        # NOTE(review): debug_mode=True presumably evaluates without charging the
        # environment's budget -- confirm against the environment implementation.
        value = alg.environment(point, debug_mode=True).item()

        # Two independent checks: on the very first call (inf / -inf seeds) the
        # same point becomes both the minimum and the maximum.
        if value > self.max_point_value:
            self.max_point_value = value
            self.max_point = point.tolist()
        if value < self.min_point_value:
            self.min_point_value = value
            self.min_point = point.tolist()

    def on_algorithm_start(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Seed the extremes from the store for this problem space, then record the start point."""
        (
            self.min_point_value,
            self.max_point_value,
            self.min_point,
            self.max_point,
        ) = self.result_store.min_max_from_space(problem_space_from_alg(alg))
        self.update_best_worst(alg, *args, **kwargs)

    def on_epoch_end(self, alg, *args, **kwargs):
        """Record the algorithm's current point at the end of an epoch."""
        self.update_best_worst(alg, *args, **kwargs)

    def on_algorithm_update(self, alg, *args, **kwargs):
        """Record the algorithm's current point after an update step."""
        self.update_best_worst(alg, *args, **kwargs)

    def on_algorithm_end(self, alg, *args, **kwargs):
        """Write the final minimum and maximum (value and point) back to the store."""
        space = problem_space_from_alg(alg)
        self.result_store.update_min(space, self.min_point_value, self.min_point)
        self.result_store.update_max(space, self.max_point_value, self.max_point)


class SaveRunToFileHandler(AlgorithmCallbackHandler):
    """Record the best-so-far trajectory of a run and persist it to a result store.

    Every callback appends a row
    ``(used_budget, best_point_value, best_point_as_list, algorithm_class_name)``
    to ``self.results``. The collected rows are stored under the key
    ``(algorithm, run_name)`` when the algorithm ends.
    """

    def __init__(
        self,
        run_name: str,
        algorithm: Algorithms,
        result_store: ResultStore,
    ):
        super().__init__()
        self.algorithm = algorithm
        self.run_name = run_name
        self.result_store = result_store
        # Rows for the current run; a fresh list is created in on_algorithm_start.
        self.results = []

    def point_to_add(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Return the point to log for this callback: the algorithm's best point so far."""
        return alg.best_point_until_now

    def add_new_data(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Evaluate the point to log and append one result row for it."""
        best_point = self.point_to_add(alg, *args, **kwargs)
        # NOTE(review): debug_mode=True presumably evaluates without charging the
        # environment's budget -- confirm against the environment implementation.
        best_point_value = alg.environment(best_point, debug_mode=True).item()
        self.results.append(
            (
                alg.environment.used_budget,
                best_point_value,
                best_point.tolist(),
                alg.__class__.__name__,
            )
        )

    def on_algorithm_start(self, alg, *args, **kwargs):
        """Discard any previously stored run with this key and start a new trajectory."""
        # Rebind rather than clear(): a list handed to store_run by an earlier run
        # may still be referenced by the store and must not be mutated.
        self.results = []
        self.result_store.remove_run(
            (self.algorithm, self.run_name), problem_space_from_alg(alg)
        )
        self.add_new_data(alg, *args, **kwargs)

    def on_epoch_end(self, alg, *args, **kwargs):
        """Log the best point at the end of an epoch."""
        self.add_new_data(alg, *args, **kwargs)

    def on_algorithm_update(self, alg, *args, **kwargs):
        """Log the best point after an update step."""
        self.add_new_data(alg, *args, **kwargs)

    def on_algorithm_end(self, alg, *args, **kwargs):
        """Log the final best point, then persist the whole trajectory."""
        self.add_new_data(alg, *args, **kwargs)
        self.result_store.store_run(
            (self.algorithm, self.run_name), problem_space_from_alg(alg), self.results
        )


class SaveRunLosses(AlgorithmCallbackHandler):
    """Collect the ``test_losses`` passed to each epoch-end callback and store them at the end.

    Only ``on_epoch_end`` records data. The start and update callbacks do nothing.
    """

    def __init__(
        self,
        run_name: str,
        algorithm: Algorithms,
        result_store: ResultStore,
    ):
        super().__init__()
        # One entry per on_epoch_end call; an entry is None when no test_losses kwarg was given.
        self.losses = []
        self.result_store = result_store
        self.run_name = run_name
        self.algorithm = algorithm

    def on_algorithm_start(self, alg, *args, **kwargs):
        """Do nothing when the algorithm starts."""

    def on_epoch_end(self, alg, *args, test_losses=None, **kwargs):
        """Append this epoch's ``test_losses`` value (``None`` if not supplied)."""
        self.losses += [test_losses]

    def on_algorithm_update(self, alg, *args, **kwargs):
        """Do nothing on intermediate updates."""

    def on_algorithm_end(self, alg, *args, **kwargs):
        """Hand every collected loss entry to the store under ``(algorithm, run_name)``."""
        run_key = (self.algorithm, self.run_name)
        space = problem_space_from_alg(alg)
        self.result_store.store_loss_data(run_key, space, self.losses)
