import abc
import copy
import dataclasses
import logging
import math
from abc import ABC
from logging import Logger
from typing import List

import torch
from torch import Tensor
from torch.nn import SmoothL1Loss, Sequential, Linear, ReLU
from torch.optim import Adam

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from algorithms.convergence_algorithms.egl import EGL
from algorithms.convergence_algorithms.exceptions import AlgorithmFinish
from algorithms.convergence_algorithms.typing import BoundedEvaluatedSamplerIdentifiableSpace
from algorithms.convergence_algorithms.utils import (
    GOAL_IS_REACHED_STOPPING_CONDITION,
    NO_MORE_BUDGET_STOPPING_CONDITION,
)
from algorithms.mapping.trust_region import StaticShrinkingTrustRegion
from algorithms.mapping.value_normalizers import AdaptedOutputUnconstrainedMapping
from algorithms.nn.modules import BaseSequentialModel
from algorithms.space.exceptions import NoMoreBudgetError
from algorithms.stopping_condition.base import AlgorithmStopCondition
from algorithms.stopping_condition.trsut_region import TrustRegionStopCondition
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable
from handlers.nested_algorithm_handler import NoAlgorithmStartEndWrapper


class MultipleAlgorithms(ConvergenceDrawable, Algorithm, ABC):
    """Alternate a genetic/evolutionary search with a convergence algorithm.

    Every epoch of :meth:`train` seeds a freshly created genetic algorithm from
    ``self.current_point``, runs it, then feeds its best point to a single,
    reused convergence algorithm. The convergence algorithm's best point becomes
    the new ``self.current_point``. Both stages share ``self.input_mapping``, a
    ``StaticShrinkingTrustRegion`` over the space bounds.

    Subclasses implement the four abstract ``_create_*`` / ``_start_*`` hooks.
    """

    def __init__(
        self,
        space: BoundedEvaluatedSamplerIdentifiableSpace,
        init_point: Tensor,
        config: dataclasses.dataclass,
        device: int = None,
        logger: Logger = None,
    ):
        """Store the search state and build the shared trust region.

        Args:
            space: Search space; its ``lower_bound`` / ``upper_bound`` define the
                initial trust region.
            init_point: Starting value of ``self.current_point``.
            config: Hyper-parameter container. ``tr_shrink_factor`` and
                ``min_tr_size`` are read here. The object is also kept as
                ``self.config``.
            device: Forwarded to the trust region (and used by subclasses).
            logger: Logger to use; defaults to this module's logger.
        """
        self.space = space
        self.input_mapping = StaticShrinkingTrustRegion(
            space.lower_bound,
            space.upper_bound,
            config.tr_shrink_factor,
            min_trust_region_size=config.min_tr_size,
            device=device,
        )
        self.current_point = init_point
        self.config = config
        self.device = device
        self.logger = logger or logging.getLogger(__name__)

    def set_start_point(self, start_point: Tensor):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    @property
    def curr_point_to_draw(self):
        """Point shown by drawers: the current point."""
        return self.current_point

    @property
    def environment(self):
        """The search space this algorithm optimises over."""
        return self.space

    @property
    def best_point_until_now(self):
        """Best point so far, which this class tracks as ``current_point``."""
        return self.current_point

    @abc.abstractmethod
    def _create_convergence_algorithm(self, config: dataclasses.dataclass):
        """Build the convergence algorithm that is reused across all epochs."""
        raise NotImplementedError()

    @abc.abstractmethod
    def _create_genetic_alg(self, init_point: Tensor, config: dataclasses.dataclass):
        """Build a new genetic algorithm seeded at ``init_point`` (called each epoch)."""
        raise NotImplementedError()

    @abc.abstractmethod
    def _start_genetic_alg(
        self,
        genetic_algorithm,
        config: dataclasses.dataclass,
        callback_handlers: List[AlgorithmCallbackHandler],
    ):
        """Run ``genetic_algorithm``; its ``best_point_until_now`` is read afterwards."""
        raise NotImplementedError()

    @abc.abstractmethod
    def _start_convergence_algorithm(
        self,
        convergence_algorithm,
        config: dataclasses.dataclass,
        callback_handlers: List[AlgorithmCallbackHandler],
    ):
        """Run ``convergence_algorithm`` starting from ``self.current_point``."""
        raise NotImplementedError()

    def train(
        self,
        config: dataclasses.dataclass,
        stopping_conditions: List[AlgorithmStopCondition] = None,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
    ):
        """Run up to ``config.total_epochs`` rounds of genetic search + convergence.

        Args:
            config: Hyper-parameter container. ``total_epochs`` is read here, and
                the whole object is forwarded to the create/start hooks.
            stopping_conditions: Extra stop conditions, checked after every
                epoch. The goal-reached and no-more-budget conditions are always
                added. The caller's list is not modified.
            callback_handlers: Handlers notified on start/end and forwarded to
                the nested algorithms. ``None`` means no handlers.

        Side effects:
            Sets the global torch default dtype to ``torch.float64``.

        Raises:
            Any exception other than ``NoMoreBudgetError`` / ``AlgorithmFinish``
            is logged and re-raised. ``on_algorithm_end`` callbacks still run.
        """
        torch.set_default_dtype(torch.float64)
        # Work on a copy: appending to the caller's list would mutate it and
        # accumulate duplicate default conditions across repeated calls.
        stopping_conditions = list(stopping_conditions or [])
        stopping_conditions.append(GOAL_IS_REACHED_STOPPING_CONDITION)
        stopping_conditions.append(NO_MORE_BUDGET_STOPPING_CONDITION)
        # The default is None; iterating it below would raise TypeError.
        callback_handlers = callback_handlers or []

        convergence_alg = self._create_convergence_algorithm(config)
        self.logger.info(f"Creating {convergence_alg.__class__.__name__} algorithm")

        self.logger.info(f"Started training {self.__class__.__name__} algorithm")
        for c in callback_handlers:
            c.on_algorithm_start(self)

        try:
            for epoch in range(config.total_epochs):
                self.logger.info(f"starting {epoch} iteration")
                # Convert the current point through the trust region before
                # seeding the genetic stage. NOTE(review): presumably this goes
                # from normalized to "real" coordinates; confirm against
                # StaticShrinkingTrustRegion.
                best_point_for_cma = self.input_mapping.map(
                    self.input_mapping.inverse(self.current_point), normalize=False
                )
                genetic_evolution_alg = self._create_genetic_alg(best_point_for_cma, config)
                self._start_genetic_alg(genetic_evolution_alg, config, callback_handlers)
                self.logger.info(
                    f"{epoch} Finish working using ES algorithm, starting convergence {self.space}"
                )

                # Mirror of the conversion above: bring the genetic stage's best
                # point back into the convergence algorithm's coordinates.
                unreal_best_point_normalized = self.input_mapping.map(
                    self.input_mapping.inverse(
                        genetic_evolution_alg.best_point_until_now, normalize=False
                    )
                )
                self.current_point = unreal_best_point_normalized
                self._start_convergence_algorithm(convergence_alg, config, callback_handlers)
                self.current_point = convergence_alg.best_point_until_now

                for stop_condition in stopping_conditions:
                    if stop_condition.should_stop(self):
                        raise AlgorithmFinish(
                            stop_condition.REASON.format(
                                alg=self.__class__.__name__,
                                env=self.space,
                                best_point=self.best_point_until_now,
                                tr=self.input_mapping,
                            )
                        )
        except NoMoreBudgetError as e:
            # Budget exhaustion ends training without propagating.
            self.logger.warning("Mixin Got a weird error", exc_info=e)
        except AlgorithmFinish as e:
            self.logger.info(f"{self.__class__.__name__} stopped {e}")
        except Exception:
            self.logger.exception("Unexpected exception occurred")
            raise
        finally:
            for c in callback_handlers:
                c.on_algorithm_end(self)


class MultipleAlgorithmsWithEGLConvergence(MultipleAlgorithms, ABC):
    """``MultipleAlgorithms`` whose convergence stage is an ``EGL`` instance.

    The genetic-algorithm hooks are still abstract and must be provided by a
    subclass.
    """

    def _create_convergence_algorithm(
        self, config: dataclasses.dataclass
    ) -> ConvergenceAlgorithm:
        """Build the EGL convergence algorithm for ``self.space``.

        Two networks are created:

        * A gradient network, ``dim -> 10 -> 15 -> dim`` with ReLUs, trained
          with Adam.
        * A single bias-free ``Linear(dim, 1)`` model wrapped in
          ``BaseSequentialModel``. If ``self.current_point`` is set, the wrapper
          is initialised from it. This model is trained with
          ``config.model_optimizer``.

        NOTE(review): some hyper-parameters are read from ``config`` and others
        from ``self.config`` (``output_outlier``, ``eps_factor``, ``min_eps``,
        ``perturb``). Confirm this mix is intentional.
        """
        dim = self.space.dimension

        # Layer construction order is kept as-is: each Linear draws its initial
        # weights from the global torch RNG, so reordering changes results.
        gradient_layers = [
            Linear(dim, 10),
            ReLU(),
            Linear(10, 15),
            ReLU(),
            Linear(15, dim),
        ]
        gradient_network = Sequential(*gradient_layers).to(device=self.device)

        linear_model = Sequential(
            Linear(dim, 1, bias=False, dtype=torch.float64)
        ).to(device=self.device)
        base_model = BaseSequentialModel(linear_model)
        trainable_model = (
            base_model
            if self.current_point is None
            else base_model.from_parameter_tensor(self.current_point)
        )

        output_mapping = AdaptedOutputUnconstrainedMapping(self.config.output_outlier)
        gradient_optimizer = Adam(
            gradient_network.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-04
        )
        # NOTE(review): the optimizer targets ``linear_model``'s parameters. If
        # ``from_parameter_tensor`` returns a model that does not share them,
        # this optimizer would not update ``trainable_model``. Verify.
        model_optimizer = config.model_optimizer(
            linear_model.parameters(), lr=config.model_lr, momentum=config.momentum
        )
        return EGL(
            self.space,
            gradient_network,
            trainable_model,
            gradient_optimizer,
            model_optimizer,
            epsilon=config.eps * math.sqrt(dim),
            epsilon_factor=self.config.eps_factor,
            min_epsilon=self.config.min_eps,
            perturb=self.config.perturb,
            grad_loss=SmoothL1Loss(),
            output_mapping=output_mapping,
            input_mapping=self.input_mapping,
            device=self.device,
            logger=self.logger,
        )

    def _start_convergence_algorithm(
        self,
        convergence_algorithm: ConvergenceAlgorithm,
        config: dataclasses.dataclass,
        callback_handlers: List[AlgorithmCallbackHandler],
    ):
        """Reset the EGL model to ``self.current_point`` and train it.

        ``model_to_train`` is rebuilt from ``self.current_point``, and
        ``best_model`` is set to a deep copy of it. The algorithm is then trained
        with the ``config`` hyper-parameters, at most one trust-region shrink,
        and a ``TrustRegionStopCondition``. Each handler is wrapped in
        ``NoAlgorithmStartEndWrapper``. Presumably this keeps the nested run's
        start/end events from reaching the handlers; confirm in that wrapper.
        """
        algorithm = convergence_algorithm
        algorithm.model_to_train = algorithm.model_to_train.from_parameter_tensor(
            self.current_point
        )
        algorithm.best_model = copy.deepcopy(algorithm.model_to_train)

        trust_region_stop = [TrustRegionStopCondition()]
        nested_handlers = [
            NoAlgorithmStartEndWrapper(outer) for outer in callback_handlers
        ]
        algorithm.train(
            config.epochs,
            config.explore_size,
            config.num_loop_without_improvement,
            config.min_iter,
            warmup_minibatch=config.warmup_minibatch,
            warmup_loops=config.warmup_loops,
            use_existed_data=config.use_existed_data,
            max_num_of_shrink=1,
            stopping_conditions=trust_region_stop,
            callback_handlers=nested_handlers,
        )
