import dataclasses
import logging
from logging import Logger
from typing import List

import torch
from torch import Tensor

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.cma import CMA
from algorithms.convergence_algorithms.opt_gan import OptGAN
from algorithms.convergence_algorithms.typing import BoundedEvaluatedSamplerIdentifiableSpace
from algorithms.mapping.trust_region import StaticShrinkingTrustRegion
from algorithms.nn.modules import (
    GeneratorByDimSize,
    DiscriminatorByInputSize,
)
from algorithms.space.exceptions import NoMoreBudgetError
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable
from handlers.nested_algorithm_handler import NoAlgorithmStartEndWrapper


class CmaOptGANAlgorithm(ConvergenceDrawable, Algorithm):
    """Two-phase optimizer: a CMA search followed by an OptGAN convergence phase.

    ``train`` first runs a CMA algorithm starting from ``init_point`` and then
    trains an ``OptGAN`` whose "best" dataset is seeded from samples already
    evaluated in ``space``. Both phases receive the same
    ``StaticShrinkingTrustRegion`` input mapping, built from the space bounds.
    """

    ALGORITHM_NAME = "cma_opt"

    def __init__(
        self,
        space: BoundedEvaluatedSamplerIdentifiableSpace,
        init_point: Tensor,
        config: dataclasses.dataclass,
        device: int = None,
        logger: Logger = None,
    ):
        """Store the search space, starting point and configuration.

        :param space: Space to optimize over; its ``lower_bound`` and
            ``upper_bound`` define the shared input mapping.
        :param init_point: Starting point for the CMA phase; also reported as
            the best point until the OptGAN phase records one.
        :param config: Configuration object; ``_create_opt_gan`` reads
            ``base_database_size``, ``g_factor``, ``d_iter``, ``g_iter`` and
            ``ee_factor`` from it.
        :param device: Torch device for the networks and the input mapping.
        :param logger: Logger to use; defaults to this module's logger.
        """
        self.space = space
        # Single instance shared by the CMA phase and the OptGAN phase.
        self.input_mapping = StaticShrinkingTrustRegion(
            space.lower_bound, space.upper_bound, device=device
        )
        self.current_point = init_point
        self.config = config
        self.device = device
        self.logger = logger or logging.getLogger(__name__)

    def set_start_point(self, start_point: Tensor):
        """Not supported: the start point is fixed at construction time."""
        raise NotImplementedError()

    @property
    def curr_point_to_draw(self):
        """Point exposed for drawing (same as ``current_point``)."""
        return self.current_point

    @property
    def environment(self):
        """The space this algorithm optimizes over."""
        return self.space

    @property
    def best_point_until_now(self):
        """Best point recorded so far.

        This is the OptGAN result once that phase has produced one, otherwise
        the initial point passed to the constructor.
        """
        return self.current_point

    def _create_opt_gan(self):
        """Build an ``OptGAN`` seeded with the best samples evaluated so far.

        The seed database holds up to ``config.base_database_size`` of the best
        samples returned by ``space.free_to_check_data``. If there are fewer,
        it is padded with new samples from ``space.sample_from_space``. The
        samples are normalized and passed through the shared input mapping
        before being handed to ``OptGAN``.

        :return: A configured ``OptGAN`` instance (``train`` is not called here).
        """
        dims = self.space.dimension
        explore_disc = DiscriminatorByInputSize(dims).to(
            device=self.device, dtype=torch.float64
        )
        exploit_disc = DiscriminatorByInputSize(dims).to(
            device=self.device, dtype=torch.float64
        )
        generator = GeneratorByDimSize(dims).to(device=self.device, dtype=torch.float64)

        database_size = self.config.base_database_size
        best_database_samples = self.space.free_to_check_data(self.device)

        # NOTE: This is a cheat used for research purposes only; the real
        #       algorithm cannot use debug mode. Instead, the values would have
        #       to be saved explicitly (perhaps in the space class).
        best_database_samples = self.space.best_k_values(
            best_database_samples,
            min(database_size, len(best_database_samples)),
            debug_mode=True,
        )
        if len(best_database_samples) < database_size:
            # Not enough evaluated samples: pad with new samples from the space.
            best_database_samples = torch.cat(
                (
                    best_database_samples,
                    self.space.sample_from_space(
                        database_size - len(best_database_samples), self.device
                    ),
                )
            )

        return OptGAN(
            self.space,
            generator,
            explore_disc,
            exploit_disc,
            gradient_penalty_factor=self.config.g_factor,
            discriminator_iter=self.config.d_iter,
            generator_iter=self.config.g_iter,
            ee_factor=self.config.ee_factor,
            best_dataset_samples=self.input_mapping.map(
                self.space.normalize(best_database_samples)
            ),
            input_mapping=self.input_mapping,
            device=self.device,
            logger=self.logger,
        )

    def _create_genetic_alg(self, init_point: Tensor, with_trust_region: bool = True):
        """Create the CMA algorithm for the first phase.

        :param init_point: Starting point; converted to a CPU numpy array.
        :param with_trust_region: Whether to pass the shared input mapping to
            CMA (``None`` is passed otherwise).
        """
        return CMA.from_space(
            self.space,
            init_point.cpu().numpy(),
            self.input_mapping if with_trust_region else None,
            logger=self.logger,
        )

    def train(
        self,
        config: dataclasses.dataclass,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
    ):
        """Run the CMA phase, then the OptGAN phase.

        Handlers receive ``on_algorithm_start``/``on_algorithm_end`` for this
        algorithm. The nested CMA and OptGAN runs get them wrapped in
        ``NoAlgorithmStartEndWrapper`` (presumably so that the nested runs do
        not emit their own start/end events).

        ``NoMoreBudgetError`` ends training early without propagating. If the
        OptGAN phase had already been created, its best point is still
        recorded. Any other exception is logged and re-raised.
        ``on_algorithm_end`` is called once the start notifications were sent,
        even when an exception escapes.

        :param config: Training configuration. Reads
            ``iteration_with_no_improvement_es``, ``pre_iter_loops``,
            ``batch_size``, ``epochs``, ``database_size``,
            ``best_data_multiplier`` and
            ``iteration_with_no_improvement_before_squeeze``.
        :param callback_handlers: Handlers to notify; ``None`` means none.
        """
        if callback_handlers is None:
            callback_handlers = []

        # NOTE(review): this changes torch's process-wide default dtype and it
        # is never restored; callers may rely on that, so it is kept as is.
        torch.set_default_dtype(torch.float64)
        self.logger.info("Creating CMA algorithm")
        genetic_evolution_alg = self._create_genetic_alg(self.current_point)

        self.logger.info("Started training mixin algorithm")
        for c in callback_handlers:
            c.on_algorithm_start(self)

        opt_gan = None
        try:
            genetic_evolution_alg.train(
                config.iteration_with_no_improvement_es,
                callback_handlers=[
                    NoAlgorithmStartEndWrapper(handler) for handler in callback_handlers
                ],
            )
            self.logger.info(
                f"Finish working using CMA algorithm, starting convergence {self.space}"
            )

            opt_gan = self._create_opt_gan()
            opt_gan.train(
                pre_iter_loops=config.pre_iter_loops,
                batch_size=config.batch_size,
                epochs=config.epochs,
                # NOTE(review): _create_opt_gan sizes its seed database from
                # self.config.base_database_size, while k_0 is derived from
                # config.database_size. Confirm the two are meant to differ.
                k_0=int(config.database_size * 0.75),
                # NOTE(review): hard-coded OptGAN parameter; consider moving it
                # into the config.
                m=120,
                count_best_data_multiplier=config.best_data_multiplier,
                iteration_with_no_improvement_before_squeeze=config.iteration_with_no_improvement_before_squeeze,
                callback_handlers=[
                    NoAlgorithmStartEndWrapper(handler) for handler in callback_handlers
                ],
            )
            self.current_point = opt_gan.best_point_until_now
        except NoMoreBudgetError as e:
            self.logger.warning("Mixin ran out of budget, stopping training", exc_info=e)
            if opt_gan is not None:
                # Keep what OptGAN found before the budget ran out instead of
                # silently reporting the initial point as the best one.
                self.current_point = opt_gan.best_point_until_now
            # NOTE(review): if the budget runs out during the CMA phase, the
            # CMA result is not recorded in current_point.
        except Exception:
            self.logger.exception("Unexpected exception occurred")
            raise
        finally:
            for c in callback_handlers:
                c.on_algorithm_end(self)
