import logging
from logging import Logger
from typing import List, Dict

import numpy as np
import torch
from cma import CMAEvolutionStrategy
from torch import Tensor

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.exceptions import AlgorithmFinish
from algorithms.convergence_algorithms.typing import (
    BoundedEvaluatedSamplerSpace,
    BoundedEvaluatedSpace,
)
from algorithms.convergence_algorithms.utils import (
    bind_space_with_input_mapping,
    TRUST_REGION_STOPPING_CONDITION,
)
from algorithms.mapping.base import InputMapping
from algorithms.mapping.trust_region import LinearTrustRegion
from algorithms.space.exceptions import NoMoreBudgetError
from algorithms.stopping_condition.base import AlgorithmStopCondition
from algorithms.stopping_condition.trsut_region import StopAfterXTimes
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable
from utils.python import distance_between_points


class CMA(ConvergenceDrawable, Algorithm):
    """CMA-ES driver that plugs ``cma.CMAEvolutionStrategy`` into the project API.

    The strategy proposes candidate points, the environment (optionally seen
    through ``input_mapping``) scores them, and the best point seen so far is
    tracked in ``curr_best_point`` as a NumPy array.
    """

    ALGORITHM_NAME = "cma"
    # Initial step size (``sigma0``) handed to every CMAEvolutionStrategy that
    # this class creates itself.
    INITIAL_SIGMA = 0.5

    def __init__(
        self,
        cma_es: CMAEvolutionStrategy,
        env: BoundedEvaluatedSpace,
        initial_point: np.ndarray,
        input_mapping: InputMapping = None,
        logger: Logger = None,
    ):
        """Store the collaborators; no evaluation happens here.

        Args:
            cma_es: Strategy instance that is asked for candidates and told
                their values.
            env: Space the candidates are evaluated on.
            initial_point: Starting value of the best-point tracker.
            input_mapping: Optional mapping bound in front of ``env`` when
                evaluating (see ``real_cma_evaluator``).
            logger: Logger to use; defaults to this module's logger.
        """
        self.curr_best_point = initial_point
        self.cma_es = cma_es
        self.env = env
        self.input_mapping = input_mapping
        self.logger = logger or logging.getLogger(__name__)

    @property
    def device(self):
        """Device name reported by this algorithm; always ``"cpu"``."""
        return "cpu"

    @property
    def best_point_until_now(self):
        """Best point found so far, as a tensor built from ``curr_best_point``."""
        return torch.from_numpy(self.curr_best_point)

    @property
    def curr_point_to_draw(self):
        """Point exposed for drawing; identical to the best point so far."""
        return torch.from_numpy(self.curr_best_point)

    @property
    def environment(self):
        """The space this algorithm optimises."""
        return self.env

    @property
    def real_cma_evaluator(self):
        """Callable used to score candidates.

        Returns ``env`` bound with ``input_mapping`` when a (truthy) mapping
        is set, otherwise ``env`` itself.
        """
        if self.input_mapping:
            return bind_space_with_input_mapping(self.env, self.input_mapping)
        return self.env

    @classmethod
    def from_space(
        cls,
        env: BoundedEvaluatedSamplerSpace,
        start_point: np.ndarray = None,
        input_mapping: InputMapping = None,
        logger: Logger = None,
    ):
        """Build an instance with a fresh strategy centred on a start point.

        Args:
            env: Space to optimise; also used to sample a start point when
                none is given.
            start_point: Optional explicit start point. When ``None`` a single
                sample is drawn from ``env``.
            input_mapping: Optional mapping forwarded to the constructor.
            logger: Optional logger forwarded to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        init_point = (
            start_point
            if start_point is not None
            else env.sample_from_space(1).squeeze().numpy()
        )
        return cls(
            CMAEvolutionStrategy(init_point, cls.INITIAL_SIGMA),
            env,
            init_point,
            input_mapping,
            logger=logger,
        )

    def _calculate_best_point(self, solutions):
        """Return the lowest-valued point among ``solutions`` and the current best.

        Every candidate is evaluated one by one with ``debug_mode=True``.
        The current best point is added last, so on a tie an earlier candidate
        from ``solutions`` wins.

        Args:
            solutions: Iterable of NumPy points proposed by the strategy.

        Returns:
            The winning point as a NumPy array.
        """
        evaluator = self.real_cma_evaluator
        candidates = [*solutions, self.curr_best_point]
        solutions_values = {
            tuple(x): evaluator(torch.from_numpy(x), debug_mode=True).cpu().numpy()
            for x in candidates
        }
        self.logger.info(f"Best solution value is {min(solutions_values.values())}")
        return np.array(min(solutions_values, key=solutions_values.get))

    def train(
        self,
        num_of_epoch_with_no_improvement: int = None,
        shrink_trust_region: bool = False,
        stopping_conditions: List[AlgorithmStopCondition] = None,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
    ):
        """Run ask/evaluate/tell iterations until a stop condition fires.

        Each iteration asks the strategy for candidates, evaluates them,
        feeds the values back, updates ``curr_best_point`` and notifies the
        callback handlers. When the strategy itself reports ``stop()`` and an
        input mapping is present, the mapping is squeezed around the best
        point; with ``shrink_trust_region`` a new strategy is then started
        from that point and the loop continues.

        Args:
            num_of_epoch_with_no_improvement: When truthy, adds a
                ``StopAfterXTimes`` condition with this value.
            shrink_trust_region: Restart the strategy after squeezing the
                mapping instead of finishing.
            stopping_conditions: Extra stop conditions. The list is copied,
                so the caller's list is never modified.
            callback_handlers: Handlers notified on start, epoch end, update
                and end.

        Returns:
            ``(points, values)`` where ``points`` holds the unique candidates
            evaluated during the run and ``values`` is the evaluator output
            for them, or two empty arrays when nothing was evaluated.

        Notes:
            ``NoMoreBudgetError`` and ``AlgorithmFinish`` end the run
            normally. Any other ``Exception`` is logged with its traceback
            and the points collected so far are still returned. Interrupts
            such as ``KeyboardInterrupt`` propagate after the end callbacks
            have run.
        """
        self.logger.info(f"starting algorithm cma for space {self.env}")
        callback_handlers = callback_handlers or []
        # Copy: the conditions appended below must not leak into the caller's
        # list (repeated calls would otherwise accumulate duplicates).
        stopping_conditions = list(stopping_conditions or [])

        # stopping_conditions.append(GOAL_IS_REACHED_STOPPING_CONDITION)
        stopping_conditions.append(TRUST_REGION_STOPPING_CONDITION)
        if num_of_epoch_with_no_improvement:
            stopping_conditions.append(StopAfterXTimes(num_of_epoch_with_no_improvement))

        for handler in callback_handlers:
            handler.on_algorithm_start(self)

        points_found = np.array([])
        try:
            no_improvement_counter = 0
            while True:
                while not self.cma_es.stop():
                    self.logger.info(f"CMA - new iteration for {self.env}")
                    solutions = np.array(self.cma_es.ask())
                    solutions_value = (
                        self.real_cma_evaluator(torch.from_numpy(solutions)).cpu().numpy()
                    )

                    # Keep every distinct candidate evaluated so far.
                    if len(points_found) > 0:
                        concat_solutions = np.concatenate((points_found, solutions))
                    else:
                        concat_solutions = solutions
                    points_found = np.unique(concat_solutions, axis=0)

                    self.cma_es.tell(solutions, solutions_value.tolist())
                    # self.cma_es.logger.add()  # write data to disc to be plotted
                    self.cma_es.disp()
                    best_point = self._calculate_best_point(solutions)
                    # Positive progress means the new best point has a lower value.
                    progress_value = self.real_cma_evaluator(
                        torch.from_numpy(self.curr_best_point), debug_mode=True
                    ) - self.real_cma_evaluator(torch.from_numpy(best_point), debug_mode=True)
                    if progress_value.item() <= 0:
                        self.logger.info(
                            f"No improvement {progress_value} "
                            f"counter {no_improvement_counter}/{num_of_epoch_with_no_improvement}"
                        )
                        no_improvement_counter += 1
                    else:
                        no_improvement_counter = 0
                        self.logger.info(
                            f"moved {distance_between_points(best_point.tolist(), self.curr_best_point.tolist())} "  # noqa
                            f"toward best point. progress {progress_value}"
                        )
                    self.curr_best_point = best_point
                    for handler in callback_handlers:
                        handler.on_epoch_end(
                            self,
                            database=torch.tensor(solutions),
                        )

                    for stop_condition in stopping_conditions:
                        if stop_condition.should_stop(self, counter=no_improvement_counter):
                            raise AlgorithmFinish(
                                stop_condition.REASON.format(
                                    alg="CMA",
                                    env=self.env,
                                    best_point=self.best_point_until_now,
                                    tr=self.input_mapping,
                                )
                            )

                # The strategy stopped on its own. Without a mapping there is
                # nothing to squeeze, so the run is over (previously this hit
                # an AttributeError on None that was silently discarded).
                if self.input_mapping is None:
                    break

                # NOTE(review): the mapping is squeezed even when
                # shrink_trust_region is False -- confirm this is intended.
                real_best_point = self.input_mapping.inverse(self.curr_best_point)
                self.logger.info(
                    f"Shrinking CMA around {real_best_point} with {self.input_mapping}"
                )
                self.input_mapping.squeeze(real_best_point)
                self.curr_best_point = self.input_mapping.map(real_best_point)
                if not shrink_trust_region:
                    break
                self.cma_es = CMAEvolutionStrategy(self.curr_best_point, self.INITIAL_SIGMA)
                for handler in callback_handlers:
                    handler.on_algorithm_update(self)
        except NoMoreBudgetError as e:
            self.logger.warning("Exceeded budget", exc_info=e)
        except AlgorithmFinish as e:
            self.logger.info(f"CMA Finish stopped {e}")
        except Exception:
            # Historically a ``return`` inside ``finally`` discarded every
            # error without a trace. Keep returning the collected points for
            # backward compatibility, but make the failure visible.
            self.logger.exception("CMA training aborted by an unexpected error")
        finally:
            for handler in callback_handlers:
                handler.on_algorithm_end(self)

        # Returning here (not inside ``finally``) lets KeyboardInterrupt and
        # SystemExit propagate instead of being swallowed.
        if points_found.size:
            return points_found, self.real_cma_evaluator(
                torch.from_numpy(points_found), debug_mode=True
            )
        return np.array([]), np.array([])

    def set_start_point(self, start_point: Tensor):
        """Restart the strategy from ``start_point``.

        When an input mapping is set, the point is first normalised by the
        environment and mapped. The result is converted to NumPy and becomes
        both the new strategy's start and the current best point.

        Args:
            start_point: New starting point as a tensor.
        """
        if self.input_mapping:
            start_point = self.input_mapping.map(self.environment.normalize(start_point))
        start_point = start_point.cpu().numpy()
        self.cma_es = CMAEvolutionStrategy(start_point, self.INITIAL_SIGMA)
        self.curr_best_point = start_point

    @classmethod
    def _default_types(cls) -> Dict[str, type]:
        """Default concrete type per constructor argument name."""
        return {"input_mapping": LinearTrustRegion}
