import dataclasses
import logging
from logging import Logger
from typing import List

from torch import Tensor

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.cma import CMA
from algorithms.convergence_algorithms.exceptions import AlgorithmFinish
from algorithms.convergence_algorithms.typing import BoundedEvaluatedSamplerIdentifiableSpace
from algorithms.convergence_algorithms.utils import (
    GOAL_IS_REACHED_STOPPING_CONDITION,
    NO_MORE_BUDGET_STOPPING_CONDITION,
)
from algorithms.stopping_condition.base import AlgorithmStopCondition
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable
from handlers.nested_algorithm_handler import NoAlgorithmStartEndWrapper


class LoopCMA(ConvergenceDrawable, Algorithm):
    """Run CMA repeatedly, restarting each run from the previous run's best point.

    Every iteration builds a fresh ``CMA`` via ``CMA.from_space`` seeded with
    ``current_point``, trains it, and stores its ``best_point_until_now`` back
    into ``current_point``. The loop ends when a stopping condition fires or
    when ``AlgorithmFinish`` is raised from inside the loop.
    """

    ALGORITHM_NAME = "cma_loop"

    def __init__(
        self,
        space: BoundedEvaluatedSamplerIdentifiableSpace,
        init_point: Tensor,
        config: dataclasses.dataclass,
        logger: Logger = None,
    ):
        """Store the search space, the initial point and the configuration.

        Args:
            space: Space handed to ``CMA.from_space`` on every iteration.
            init_point: Point the first CMA run starts from.
            config: Stored on the instance as ``self.config``.
            logger: Logger to use; falls back to this module's logger when
                omitted.
        """
        self.space = space
        self.current_point = init_point
        self.config = config
        self.logger = logger or logging.getLogger(__name__)

    def set_start_point(self, start_point: Tensor):
        """Not supported by this algorithm; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    @property
    def curr_point_to_draw(self):
        """Point exposed for drawing: the current point."""
        return self.current_point

    @property
    def environment(self):
        """The space this algorithm operates on."""
        return self.space

    @property
    def best_point_until_now(self):
        """Best point so far, i.e. the best point of the last finished CMA run."""
        return self.current_point

    def train(
        self,
        config: dataclasses.dataclass,
        stopping_conditions: List[AlgorithmStopCondition] = None,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
    ):
        """Loop CMA runs until a stopping condition is met.

        Args:
            config: Not read by this method.
            stopping_conditions: Extra conditions checked after every CMA run.
                The goal-reached and no-more-budget conditions are always
                checked in addition. The caller's list is not modified.
            callback_handlers: Handlers forwarded (wrapped in
                ``NoAlgorithmStartEndWrapper``) to every inner CMA run and
                notified through ``on_algorithm_end`` once the loop is over.
                May be omitted.
        """
        # Work on private copies: the defaults are None, and appending to the
        # caller's list would leak the built-in conditions into later calls.
        active_conditions = [
            *(stopping_conditions or []),
            GOAL_IS_REACHED_STOPPING_CONDITION,
            NO_MORE_BUDGET_STOPPING_CONDITION,
        ]
        handlers = list(callback_handlers or [])

        try:
            while True:
                self.logger.info("New iteration of cma")
                cma = CMA.from_space(
                    self.space,
                    self.current_point.cpu().numpy(),
                    logger=self.logger,
                )
                cma.train(
                    callback_handlers=[
                        NoAlgorithmStartEndWrapper(handler) for handler in handlers
                    ],
                )
                # The next run (if any) restarts from this run's best point.
                self.current_point = cma.best_point_until_now

                for stop_condition in active_conditions:
                    if stop_condition.should_stop(self):
                        raise AlgorithmFinish(
                            stop_condition.REASON.format(
                                alg=self.__class__.__name__,
                                env=self.space,
                                best_point=self.best_point_until_now,
                            )
                        )
        except AlgorithmFinish as e:
            # Also catches an AlgorithmFinish raised by the inner CMA run.
            self.logger.info(f"{self.__class__.__name__} stopped {e}")

        for handler in handlers:
            handler.on_algorithm_end(self)
