import logging
from logging import Logger
from typing import List

from torch import Tensor

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.cma import CMA
from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from algorithms.convergence_algorithms.exceptions import AlgorithmFinish
from algorithms.space.exceptions import NoMoreBudgetError
from algorithms.stopping_condition.budget import EarlyBudgetStop
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable
from handlers.nested_algorithm_handler import NoAlgorithmStartEndWrapper


class EglCmaFinisher(ConvergenceDrawable, Algorithm):
    """Run an EGL convergence algorithm, then hand off to CMA to finish.

    ``train`` first trains ``egl_algorithm`` with an ``EarlyBudgetStop``
    stopping condition. The best point EGL found is then used to re-center the
    CMA input mapping and to seed a freshly constructed ``CMA`` instance,
    which is trained with ``shrink_trust_region=True``.

    While ``egl_running`` is True, the ``input_mapping``,
    ``curr_point_to_draw`` and ``best_point_until_now`` properties delegate to
    the EGL algorithm. After the hand-off they delegate to the CMA algorithm.
    """

    ALGORITHM_NAME = "egl_cma_finisher"

    # Value passed to ``EarlyBudgetStop`` for the EGL phase. This was
    # previously hard-coded in ``train``; it is kept as the default.
    DEFAULT_EGL_BUDGET = 30_000

    def __init__(
        self,
        egl_algorithm: ConvergenceAlgorithm,
        cma_algorithm: CMA,
        logger: Logger = logging.getLogger(__name__),
        egl_budget: int = DEFAULT_EGL_BUDGET,
    ):
        """
        :param egl_algorithm: algorithm trained in the first phase.
        :param cma_algorithm: CMA instance whose ``input_mapping`` is reused
            for the second phase. The instance itself is replaced by a new
            ``CMA`` built via ``CMA.from_space`` once EGL finishes.
        :param logger: logger used for phase-transition and termination messages.
        :param egl_budget: argument passed to ``EarlyBudgetStop`` for the EGL
            phase (defaults to the previously hard-coded 30_000).
        """
        self.egl_algorithm = egl_algorithm
        self.cma_algorithm = cma_algorithm
        self.logger = logger
        self.egl_budget = egl_budget
        # Flips to False only after the EGL phase returns normally; selects
        # which sub-algorithm the delegating properties read from.
        self.egl_running = True

    def set_start_point(self, start_point: Tensor):
        """Not supported: the start point of each phase is chosen internally."""
        raise NotImplementedError()

    @property
    def input_mapping(self):
        """Input mapping of the currently active phase."""
        if self.egl_running:
            return self.egl_algorithm.input_mapping
        return self.cma_algorithm.input_mapping

    @staticmethod
    def _wrap_handlers(
        callback_handlers: List[AlgorithmCallbackHandler],
    ) -> List[NoAlgorithmStartEndWrapper]:
        """Wrap each handler in ``NoAlgorithmStartEndWrapper`` for a nested phase.

        A fresh list of wrappers is built on every call, so each phase gets
        its own wrapper objects, as in the original code. The wrapper
        presumably suppresses the nested algorithm's start/end callbacks;
        ``train`` calls ``on_algorithm_end`` itself once at the end.
        """
        return [NoAlgorithmStartEndWrapper(handler) for handler in callback_handlers]

    def train(self, egl_parameters: dict, callback_handlers: List[AlgorithmCallbackHandler]):
        """Train EGL, then CMA seeded from EGL's best point, then notify handlers.

        :param egl_parameters: keyword arguments forwarded to
            ``egl_algorithm.train``. They must not include ``callback_handlers``
            or ``stopping_conditions``, because this method supplies both.
        :param callback_handlers: handlers given (wrapped) to both phases.
            Each handler's ``on_algorithm_end`` is called once with this
            finisher after training.

        ``NoMoreBudgetError`` and ``AlgorithmFinish`` raised during either
        phase are logged and end training early. If one is raised during the
        EGL phase, the CMA phase never runs and ``egl_running`` stays True.
        Any other exception propagates, and in that case ``on_algorithm_end``
        is not called.
        """
        try:
            self.egl_algorithm.train(
                **egl_parameters,
                callback_handlers=self._wrap_handlers(callback_handlers),
                stopping_conditions=[EarlyBudgetStop(self.egl_budget)],
            )
            self.egl_running = False

            # Take EGL's best point out of EGL's input mapping (inverse) ...
            best_point = self.egl_algorithm.input_mapping.inverse(
                self.egl_algorithm.best_point_until_now
            ).cpu()
            # ... re-center CMA's mapping on it, then express it in CMA's space.
            self.cma_algorithm.input_mapping.move_center(best_point)
            cma_starting_point = self.cma_algorithm.input_mapping.map(best_point).cpu()

            # Replace the CMA instance with one seeded at EGL's best point,
            # reusing the (now re-centered) input mapping.
            self.cma_algorithm = CMA.from_space(
                self.egl_algorithm.env,
                cma_starting_point.numpy(),
                self.cma_algorithm.input_mapping,
            )
            self.cma_algorithm.train(
                shrink_trust_region=True,
                callback_handlers=self._wrap_handlers(callback_handlers),
            )
        except NoMoreBudgetError as e:
            self.logger.warning("No more Budget", exc_info=e)
        except AlgorithmFinish as e:
            self.logger.info("%s Finish stopped %s", self.__class__.__name__, e)

        # Emit a single end event for the combined algorithm, since the
        # nested phases ran with wrapped handlers.
        for handler in callback_handlers:
            self.logger.info("Calling upon %s finishing convergence", handler.on_algorithm_end)
            handler.on_algorithm_end(self)

    @property
    def curr_point_to_draw(self):
        """Point to draw, taken from the currently active phase."""
        if self.egl_running:
            return self.egl_algorithm.curr_point_to_draw
        return self.cma_algorithm.curr_point_to_draw

    @property
    def environment(self):
        """Environment of the EGL algorithm (shared with the CMA phase)."""
        # NOTE(review): ``train`` reads ``egl_algorithm.env``, but this property
        # reads ``egl_algorithm.environment``. Confirm that both attributes
        # exist and refer to the same object.
        return self.egl_algorithm.environment

    @property
    def best_point_until_now(self):
        """Best point reported by the currently active phase."""
        if self.egl_running:
            return self.egl_algorithm.best_point_until_now
        return self.cma_algorithm.best_point_until_now
