from typing import Callable, List

from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable


class NestedAlgorithmCallback(AlgorithmCallbackHandler):
    """
    Forward an inner algorithm's callback events to handlers, using the outer algorithm.

    Use this when an algorithm runs other algorithms internally but the wrapped
    callbacks should keep seeing the outer algorithm's data. On every event the
    ``algorithm`` factory is called to obtain the object handed to the wrapped
    handlers; the inner ``alg`` argument received by this wrapper is not used.

    Event mapping, as implemented:
        on_algorithm_start  -> handler.on_epoch_end
        on_epoch_end        -> handler.on_epoch_end
        on_algorithm_update -> handler.on_algorithm_update
        on_algorithm_end    -> self.on_algorithm_update (i.e. handler.on_algorithm_update)

    NOTE(review): start/end are presumably not forwarded as start/end so that the
    wrapped handlers are not re-initialised/finalised for each inner run — confirm.
    """

    def __init__(
        self,
        callback_handlers: List[AlgorithmCallbackHandler],
        algorithm: Callable[[], ConvergenceDrawable],
    ):
        # Handlers that receive the re-mapped events, notified in list order.
        self.callback_handlers = callback_handlers
        # Zero-argument factory returning the outer algorithm; invoked once per event.
        self.algorithm = algorithm

    def _relay(self, hook_name, args, kwargs):
        """Invoke ``hook_name`` on every wrapped handler with the current outer algorithm."""
        # args/kwargs are taken as plain containers so no caller keyword can clash
        # with this helper's own parameter names.
        # The outer algorithm is resolved once per event, even when there are no handlers.
        current_outer = self.algorithm()
        for handler in self.callback_handlers:
            getattr(handler, hook_name)(current_outer, *args, **kwargs)

    def on_algorithm_start(self, alg, *args, **kwargs):
        # Reported to the wrapped handlers as an epoch end rather than a start.
        self._relay("on_epoch_end", args, kwargs)

    def on_epoch_end(self, alg, *args, **kwargs):
        self._relay("on_epoch_end", args, kwargs)

    def on_algorithm_update(self, alg, *args, **kwargs):
        self._relay("on_algorithm_update", args, kwargs)

    def on_algorithm_end(self, alg, *args, **kwargs):
        # Treated as one more update; dispatched through self so subclass overrides apply.
        self.on_algorithm_update(alg, *args, **kwargs)


class NoAlgorithmStartEndWrapper(AlgorithmCallbackHandler):
    """
    Wrap a single callback so it never receives algorithm start/end events as such.

    Event mapping, as implemented:
        on_algorithm_start  -> callback.on_algorithm_update
        on_epoch_end        -> callback.on_epoch_end
        on_algorithm_update -> callback.on_algorithm_update
        on_algorithm_end    -> callback.on_epoch_end

    Each hook passes ``alg`` and all extra arguments through unchanged and returns
    whatever the wrapped callback's method returns.
    """

    def __init__(self, callback: AlgorithmCallbackHandler):
        # The single wrapped handler that every event is delegated to.
        self.callback = callback

    def _delegate(self, hook_name, alg, args, kwargs):
        """Call ``hook_name`` on the wrapped callback and hand back its result."""
        # args/kwargs are taken as plain containers so no caller keyword can clash
        # with this helper's own parameter names.
        target_hook = getattr(self.callback, hook_name)
        return target_hook(alg, *args, **kwargs)

    def on_algorithm_start(self, alg, *args, **kwargs):
        # A start is deliberately surfaced as an update.
        return self._delegate("on_algorithm_update", alg, args, kwargs)

    def on_epoch_end(self, alg, *args, **kwargs):
        return self._delegate("on_epoch_end", alg, args, kwargs)

    def on_algorithm_update(self, alg, *args, **kwargs):
        return self._delegate("on_algorithm_update", alg, args, kwargs)

    def on_algorithm_end(self, alg, *args, **kwargs):
        # An end is deliberately surfaced as an epoch end.
        return self._delegate("on_epoch_end", alg, args, kwargs)
