import abc
from typing import List, Tuple, Union

from matplotlib.figure import Figure

from algorithms.space.base_space import BudgetLimitedSpace, EvaluatedSpace
from handlers.drawers.base_drawer import StaticPLTDrawer, Drawer
from handlers.drawers.drawable_algorithms import ConvergenceDrawable


class ValuePerBudget(StaticPLTDrawer):
    """Base drawer that plots a point's objective value against the budget used.

    Subclasses pick which point of the algorithm is evaluated by implementing
    ``best_point``, and name the plotted series through ``GRAPH_NAME``.
    """

    # Name of the plotted series; concrete subclasses override it.
    GRAPH_NAME = ""

    @abc.abstractmethod
    def best_point(self, alg: ConvergenceDrawable):
        """Return the point of ``alg`` whose value should be plotted.

        The returned object must provide ``.cpu()``, because ``draw`` calls it
        before evaluating the point in the environment.
        """
        raise NotImplementedError()

    def draw(self, alg: ConvergenceDrawable, *args, **kwargs):
        """Evaluate the selected point and pair its value with the budget spent.

        Args:
            alg: Algorithm whose ``environment`` is evaluated and whose point is
                chosen by ``best_point``.

        Returns:
            ``[((value, env.used_budget), GRAPH_NAME)]``, or ``[]`` when the
            environment is not a ``BudgetLimitedSpace`` and therefore has no
            budget to plot against.
        """
        env = alg.environment
        # `env.used_budget` is read below, and only BudgetLimitedSpace is known to
        # provide it. Skip every other environment instead of failing with an
        # AttributeError. This includes EvaluatedSpace environments that are not
        # budget-limited, which were always skipped here.
        if not isinstance(env, BudgetLimitedSpace):
            return []
        # NOTE(review): debug_mode=True presumably evaluates without charging the
        # budget, so plotting does not distort `used_budget` -- confirm in the space.
        best_model_value = env(self.best_point(alg).cpu(), debug_mode=True)
        return [((best_model_value, env.used_budget), self.GRAPH_NAME)]


class BestPerBudget(ValuePerBudget):
    """Plot the best point found so far against the budget consumed."""

    GRAPH_NAME = "Optimum Per Budget"

    def best_point(self, alg: ConvergenceDrawable):
        """Return the algorithm's best point found up to now."""
        best_so_far = alg.best_point_until_now
        return best_so_far


class CurrentValuePerBudget(ValuePerBudget):
    """Plot the algorithm's current point against the budget consumed."""

    GRAPH_NAME = "Current Per Budget"

    def best_point(self, alg: ConvergenceDrawable):
        """Return the point the algorithm currently exposes for drawing."""
        current = alg.curr_point_to_draw
        return current


class EndPointValue(Drawer):
    """Drawer that reports only the final value of the best point found.

    It produces nothing at the start of a run or during it. ``end_drawing``
    emits a single ``(value, "result")`` pair.
    """

    def start_drawing(self, alg, *args, **kwargs):
        """Produce no output when the run starts."""
        return list()

    def end_drawing(self, alg, *args, **kwargs) -> List[Tuple[Union[Figure], str]]:
        """Evaluate the best point found so far and report its value.

        Returns ``[]`` when the environment is an ``EvaluatedSpace`` that is not
        a ``BudgetLimitedSpace``; otherwise ``[(value, "result")]``.
        """
        # NOTE(review): the annotation advertises a Figure, but the first tuple
        # element is the result of `.item()` (presumably a Python scalar).
        # It is presumably kept to match the base Drawer signature -- confirm.
        environment = alg.environment
        unsupported = isinstance(environment, EvaluatedSpace) and not isinstance(
            environment, BudgetLimitedSpace
        )
        if unsupported:
            return []
        final_point = alg.best_point_until_now.cpu()
        final_value = environment(final_point, debug_mode=True)
        return [(final_value.item(), "result")]

    def update_data(self, alg, *args, **kwargs):
        """Collect nothing between iterations."""
        return list()

    def draw_data(
        self, alg: ConvergenceDrawable, *args, **kwargs
    ) -> List[Tuple[Union[Figure], str]]:
        """Draw nothing during the run; only the end result is reported."""
        return list()
