import dataclasses
from typing import List

import torch
from torch import Tensor

from algorithms.nn.modules import (
    DiscriminatorByInputSize,
    GeneratorByDimSize,
)

from algorithms.convergence_algorithms.multiple_algorithm import (
    MultipleAlgorithmsWithEGLConvergence,
)
from algorithms.convergence_algorithms.opt_gan import OptGAN
from algorithms.stopping_condition.trsut_region import StopAfterXTimes
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.nested_algorithm_handler import NoAlgorithmStartEndWrapper


class EGLOptGAN(MultipleAlgorithmsWithEGLConvergence):
    """``MultipleAlgorithmsWithEGLConvergence`` variant backed by an ``OptGAN``.

    ``_create_genetic_alg`` builds the ``OptGAN`` instance and
    ``_start_genetic_alg`` runs its ``train`` method. Both are presumably
    hooks invoked by the base class (not visible in this file).
    """

    ALGORITHM_NAME = "egl_opt"

    def _create_genetic_alg(self, init_point: Tensor, config: dataclasses.dataclass):
        """Construct the ``OptGAN`` together with its networks and seed database.

        Args:
            init_point: Accepted for interface compatibility; not read here.
            config: Accepted for interface compatibility; not read here. All
                settings are taken from ``self.config`` instead.

        Returns:
            An ``OptGAN`` wired with a generator, an explore and an exploit
            discriminator, and a normalized, input-mapped database holding
            ``self.config.database_size`` samples (assuming ``best_k_values``
            returns at most the requested number of samples).
        """
        placement = {"device": self.device, "dtype": torch.float64}
        n_dims = self.space.dimension

        # Networks are instantiated in a fixed order (explore, exploit,
        # generator) so any RNG-driven weight initialization stays stable.
        explore_net = DiscriminatorByInputSize(n_dims).to(**placement)
        exploit_net = DiscriminatorByInputSize(n_dims).to(**placement)
        generator_net = GeneratorByDimSize(n_dims).to(**placement)

        target_size = self.config.database_size
        seed_samples = self.space.free_to_check_data(self.device)

        # NOTE - This is a cheat used for research purposes only: the real
        #        algorithm cannot rely on debug mode and would have to track
        #        the values itself (perhaps inside the space class).
        available = len(seed_samples)
        if available > 0:
            seed_samples = self.space.best_k_values(
                seed_samples, min(target_size, available), debug_mode=True
            )

        # Top the database up with fresh samples from the space when the
        # already-known points do not fill it.
        shortfall = target_size - len(seed_samples)
        if shortfall > 0:
            filler = self.space.sample_from_space(shortfall, self.device)
            seed_samples = torch.cat((seed_samples, filler))

        gan_options = {
            "gradient_penalty_factor": self.config.g_factor,
            "discriminator_iter": self.config.d_iter,
            "generator_iter": self.config.g_iter,
            "ee_factor": self.config.ee_factor,
            "best_dataset_samples": self.input_mapping.map(
                self.space.normalize(seed_samples)
            ),
            "input_mapping": self.input_mapping,
            "device": self.device,
            "logger": self.logger,
        }
        return OptGAN(
            self.space, generator_net, explore_net, exploit_net, **gan_options
        )

    def _start_genetic_alg(
        self,
        genetic_algorithm: OptGAN,
        config: dataclasses.dataclass,
        callback_handlers: List[AlgorithmCallbackHandler],
    ):
        """Run ``genetic_algorithm.train`` with settings taken from ``config``.

        Args:
            genetic_algorithm: The ``OptGAN`` to train.
            config: Supplies ``pre_iter_loops``, ``batch_size``, ``epochs``,
                ``database_size``, ``best_data_multiplier`` and
                ``iteration_with_no_improvement_before_squeeze``.
            callback_handlers: Handlers forwarded to training, each wrapped
                in ``NoAlgorithmStartEndWrapper``.
        """
        patience = config.iteration_with_no_improvement_before_squeeze
        # The stop condition uses ``patience`` while the squeeze threshold
        # handed to ``train`` is ``patience + 1``.
        stop_rules = [StopAfterXTimes(patience)]
        wrapped_handlers = list(map(NoAlgorithmStartEndWrapper, callback_handlers))

        genetic_algorithm.train(
            pre_iter_loops=config.pre_iter_loops,
            batch_size=config.batch_size,
            epochs=config.epochs,
            k_0=int(config.database_size * 0.75),
            m=120,  # fixed value; its meaning is defined by OptGAN.train
            count_best_data_multiplier=config.best_data_multiplier,
            iteration_with_no_improvement_before_squeeze=patience + 1,
            stopping_conditions=stop_rules,
            callback_handlers=wrapped_handlers,
        )
