import dataclasses
import math
from typing import List

import torch
from torch import Tensor
from torch.nn import Sequential, Linear, ReLU, SmoothL1Loss
from torch.optim import Adam

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.cma import CMA
from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from algorithms.convergence_algorithms.egl import EGL
from algorithms.convergence_algorithms.egl_scheduler import EGLScheduler
from algorithms.convergence_algorithms.exceptions import AlgorithmFinish
from algorithms.convergence_algorithms.typing import (
    BoundedEvaluatedSamplerIdentifiableSpace,
)
from algorithms.convergence_algorithms.utils import (
    GOAL_IS_REACHED_STOPPING_CONDITION,
    NO_MORE_BUDGET_STOPPING_CONDITION,
)
from algorithms.nn.modules import BaseSequentialModel
from algorithms.convergence_algorithms.multiple_algorithm import (
    MultipleAlgorithmsWithEGLConvergence,
)
from algorithms.mapping.value_normalizers import AdaptedOutputUnconstrainedMapping
from algorithms.space.exceptions import NoMoreBudgetError
from algorithms.stopping_condition.base import AlgorithmStopCondition
from algorithms.stopping_condition.budget import EarlyBudgetStop
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable
from handlers.nested_algorithm_handler import NoAlgorithmStartEndWrapper


class MixinConvergenceAlgorithm(MultipleAlgorithmsWithEGLConvergence):
    """Convergence driver that pairs a CMA stage with an EGLScheduler stage.

    Supplies three hooks (presumably invoked by
    ``MultipleAlgorithmsWithEGLConvergence`` — confirm against the base class):
    running the CMA ("genetic") stage, creating it, and building the EGL-based
    convergence algorithm.
    """

    ALGORITHM_NAME = "new"

    def _start_genetic_alg(
        self,
        genetic_algorithm,
        config: dataclasses.dataclass,
        callback_handlers: List[AlgorithmCallbackHandler],
    ):
        """Train ``genetic_algorithm`` with wrapped callback handlers.

        Each handler is wrapped in ``NoAlgorithmStartEndWrapper``. Judging by
        the wrapper's name, this keeps the nested run from emitting its own
        start/end events; confirm this. ``config.iteration_with_no_improvement_es``
        is passed as the first positional argument of ``train``.
        """
        patience = config.iteration_with_no_improvement_es
        nested_handlers = list(map(NoAlgorithmStartEndWrapper, callback_handlers))
        genetic_algorithm.train(patience, callback_handlers=nested_handlers)

    def _create_genetic_alg(self, init_point: Tensor, config: dataclasses.dataclass):
        """Create a CMA instance over ``self.space`` seeded at ``init_point``.

        The tensor is moved to CPU and converted to a NumPy array before it is
        handed to ``CMA.from_space``. ``config`` is accepted for interface
        compatibility and is not read here.
        """
        search_space = self.space
        seed_array = init_point.cpu().numpy()
        mapping = self.input_mapping
        return CMA.from_space(search_space, seed_array, mapping, logger=self.logger)

    def _create_convergence_algorithm(
        self, config: dataclasses.dataclass
    ) -> ConvergenceAlgorithm:
        """Build the ``EGLScheduler`` that drives the gradient-learning stage.

        - Gradient network: an MLP ``dims -> 10 -> 15 -> dims`` with ReLU
          activations, optimised by Adam.
        - Trained model: a single bias-free float64 ``Linear(dims, 1)``
          wrapped in ``BaseSequentialModel``, optimised by
          ``config.model_optimizer``.

        The trained model is re-seeded from ``self.current_point`` when that
        attribute is not ``None``.
        """
        dims = self.space.dimension

        # Gradient surrogate network. Layers are built in this order on
        # purpose so that random weight initialisation stays the same.
        grad_layers = [
            Linear(dims, 10),
            ReLU(),
            Linear(10, 15),
            ReLU(),
            Linear(15, dims),
        ]
        grad_net = Sequential(*grad_layers).to(device=self.device)

        # Model whose parameters are trained: one bias-free float64 linear map.
        model_layer = Linear(dims, 1, bias=False, dtype=torch.float64)
        model_net = Sequential(model_layer).to(device=self.device)

        trained_model = BaseSequentialModel(model_net)
        if self.current_point is not None:
            # NOTE(review): the model optimizer below is built over
            # ``model_net.parameters()``. If ``from_parameter_tensor`` returns
            # a model with fresh parameters, that optimizer would not update
            # it — confirm.
            trained_model = trained_model.from_parameter_tensor(self.current_point)

        # NOTE(review): ``output_outlier`` is read from ``self.config``, whereas
        # every other setting here comes from the ``config`` argument. Confirm
        # this is intentional.
        outputs_mapping = AdaptedOutputUnconstrainedMapping(self.config.output_outlier)
        grad_optimizer = Adam(
            grad_net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-04
        )
        trained_model_optimizer = config.model_optimizer(
            model_net.parameters(), lr=config.model_lr, momentum=config.momentum
        )
        # The configured epsilon is scaled by sqrt(dims).
        scaled_epsilon = config.eps * math.sqrt(dims)

        return EGLScheduler(
            self.space,
            grad_net,
            model_lr_factor=config.lr_factor,
            train_quantile=config.quantile,
            weights_creator=config.w_creator,
            model_to_train=trained_model,
            value_optimizer=grad_optimizer,
            model_to_train_optimizer=trained_model_optimizer,
            epsilon=scaled_epsilon,
            epsilon_factor=config.eps_factor,
            min_epsilon=config.min_eps,
            perturb=config.perturb,
            grad_loss=SmoothL1Loss(reduction="none"),
            output_mapping=outputs_mapping,
            input_mapping=self.input_mapping,
            device=self.device,
            logger=self.logger,
        )


class EGLCMA(ConvergenceDrawable, Algorithm):
    """Run CMA for a bounded budget, then continue with EGL from CMA's best point.

    The drawing and inspection properties (input mapping, logger, current
    point, best point, environment) are all forwarded to the wrapped ``egl``
    instance.
    """

    ALGORITHM_NAME = "cma_egl"

    def __init__(
        self, egl: EGL, cma: CMA, env: BoundedEvaluatedSamplerIdentifiableSpace
    ):
        self.egl = egl
        self.cma = cma
        # NOTE(review): stored for external use. This class itself reads the
        # environment through ``self.egl.environment``, not ``self.env``.
        self.env = env

    @property
    def input_mapping(self):
        """Input mapping of the wrapped EGL algorithm."""
        return self.egl.input_mapping

    @property
    def logger(self):
        """Logger of the wrapped EGL algorithm."""
        return self.egl.logger

    @property
    def curr_point_to_draw(self):
        """Current point to draw, as reported by the wrapped EGL algorithm."""
        return self.egl.curr_point_to_draw

    @property
    def environment(self):
        """Environment of the wrapped EGL algorithm."""
        return self.egl.environment

    @property
    def best_point_until_now(self):
        """Best point so far, as reported by the wrapped EGL algorithm."""
        return self.egl.best_point_until_now

    def train(
        self,
        epochs: int,
        exploration_size: int,
        num_loop_without_improvement: int,
        min_iteration_before_shrink: int,
        num_of_epoch_with_no_improvement: int = None,
        shrink_trust_region: bool = False,
        max_num_of_shrink: int = None,
        helper_model_training_epochs: int = 60,
        warmup_minibatch: int = 5,
        warmup_loops: int = 6,
        budget_for_cma: int = 30_000,
        stopping_conditions: List[AlgorithmStopCondition] = None,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
    ):
        """Run the CMA stage, then the EGL stage starting from CMA's best point.

        CMA is trained with the single stopping condition
        ``EarlyBudgetStop(budget_for_cma)``. Its ``best_point_until_now`` is
        mapped back through ``cma.input_mapping.inverse`` and then
        ``environment.denormalize``, set as EGL's start point, and EGL is then
        trained.

        Args:
            epochs, exploration_size, num_loop_without_improvement,
            min_iteration_before_shrink, max_num_of_shrink,
            helper_model_training_epochs, warmup_minibatch, warmup_loops:
                Forwarded positionally to ``self.egl.train``.
            num_of_epoch_with_no_improvement, shrink_trust_region:
                Forwarded by keyword to ``self.cma.train``.
            budget_for_cma: Budget given to ``EarlyBudgetStop`` for the CMA stage.
            stopping_conditions: Optional extra conditions. A new list is built
                from them plus the goal-reached and no-more-budget conditions;
                the caller's list is not modified. NOTE(review): the resulting
                list is currently not forwarded to either stage.
            callback_handlers: Handlers that receive ``on_algorithm_start`` and
                ``on_algorithm_end`` for the combined run. Each stage receives
                them wrapped in ``NoAlgorithmStartEndWrapper``. ``None`` means
                no handlers.

        ``NoMoreBudgetError`` and ``AlgorithmFinish`` end the run and are
        logged, not propagated. Any other exception is logged and re-raised.
        ``on_algorithm_end`` is always called on every handler.
        """
        # Normalise before the ``try`` so ``finally`` can always iterate it.
        # Iterating the default ``None`` used to raise TypeError in the body
        # and again inside ``finally``, which masked the original error.
        callback_handlers = list(callback_handlers or [])
        # Build a new list instead of appending to the caller's list, which
        # previously grew by two entries on every call.
        stopping_conditions = [
            *(stopping_conditions or []),
            GOAL_IS_REACHED_STOPPING_CONDITION,
            NO_MORE_BUDGET_STOPPING_CONDITION,
        ]

        try:
            for handler in callback_handlers:
                handler.on_algorithm_start(self)

            self.cma.train(
                num_of_epoch_with_no_improvement=num_of_epoch_with_no_improvement,
                shrink_trust_region=shrink_trust_region,
                stopping_conditions=[EarlyBudgetStop(budget_for_cma)],
                callback_handlers=[
                    NoAlgorithmStartEndWrapper(handler) for handler in callback_handlers
                ],
            )
            curr_point = self.environment.denormalize(
                self.cma.input_mapping.inverse(self.cma.best_point_until_now)
            )
            self.egl.set_start_point(curr_point)
            self.logger.info(f"Starting EGL from {curr_point} with {self.egl}")
            # NOTE(review): these arguments are passed positionally and rely on
            # EGL.train's parameter order — confirm if that signature changes.
            self.egl.train(
                epochs,
                exploration_size,
                num_loop_without_improvement,
                min_iteration_before_shrink,
                max_num_of_shrink,
                helper_model_training_epochs,
                warmup_minibatch,
                warmup_loops,
                callback_handlers=[
                    NoAlgorithmStartEndWrapper(handler) for handler in callback_handlers
                ],
            )
        except NoMoreBudgetError as e:
            self.logger.warning("CMAEGL is out of budget", exc_info=e)
        except AlgorithmFinish as e:
            self.logger.info(f"{self.__class__.__name__} stopped {e}")
        except Exception:
            self.logger.exception("Unexpected exception occurred")
            raise
        finally:
            for handler in callback_handlers:
                handler.on_algorithm_end(self)

    def set_start_point(self, start_point: Tensor):
        """Set the start point of the EGL stage (the CMA stage is unaffected)."""
        self.egl.set_start_point(start_point)
