import logging
from logging import Logger
from typing import List, Dict

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from tqdm.autonotebook import trange, tqdm

from algorithms.convergence_algorithms.base import Algorithm
from algorithms.convergence_algorithms.basic_config import FuncConfig
from algorithms.convergence_algorithms.exceptions import AlgorithmFinish
from algorithms.convergence_algorithms.typing import BoundedEvaluatedSamplerSpace
from algorithms.convergence_algorithms.utils import (
    default_discriminator_optimizer,
    random_sampler_loader_from_tensor,
    sample_input_to_generator,
    default_generator_optimizer,
    GOAL_IS_REACHED_STOPPING_CONDITION,
)
from algorithms.mapping.base import InputMapping
from algorithms.mapping.trust_region import TanhTrustRegion
from algorithms.nn.losses import wgan_gradient_penalty_loss
from algorithms.nn.modules import (
    DiscriminatorByInputSize,
    GeneratorByDimSize,
)
from algorithms.space.exceptions import NoMoreBudgetError
from algorithms.stopping_condition.base import AlgorithmStopCondition
from handlers.base_handler import AlgorithmCallbackHandler
from handlers.drawers.drawable_algorithms import ConvergenceDrawable

# Default mini-batch size used by `OptGAN.train` and `OptGAN.pre_iter_training`.
DEFAULT_BATCH_SIZE: int = 30


class OptGAN(Module, ConvergenceDrawable, Algorithm):
    """OptGAN optimizer: a generator trained against two WGAN-GP critics.

    * ``explore_discriminator`` is trained to separate generator samples from
      samples drawn uniformly from the search space.
    * ``exploit_discriminator`` is trained to separate generator samples from
      rows of the current best-samples database (``best_dataset_samples``).

    The generator maximizes a weighted mix of both critics' scores
    (``ee_factor`` controls the explore/exploit weighting). After every epoch
    the database is rebuilt from the best of (old database + fresh generator
    candidates), and its size follows a shrinking schedule.
    """

    ALGORITHM_NAME = "opt_gan"

    def __init__(
        self,
        env: BoundedEvaluatedSamplerSpace,
        generator: Module,
        explore_discriminator: Module,
        exploit_discriminator: Module,
        best_dataset_samples: Tensor,
        gradient_penalty_factor: float = 0.1,
        explore_discriminator_opt: Optimizer = None,
        exploit_discriminator_opt: Optimizer = None,
        generator_opt: Optimizer = None,
        discriminator_iter: int = 4,
        generator_iter: int = 150,
        ee_factor: float = 0.3,
        input_mapping: InputMapping = None,
        device: int = None,
        logger: Logger = logging.getLogger(__name__),
    ):
        """Build the optimizer.

        Args:
            env: Search space that is sampled from and evaluated.
            generator: Network mapping generator noise to points; its output
                is passed through ``env.denormalize``.
            explore_discriminator: Critic trained against uniform space samples.
            exploit_discriminator: Critic trained against the best-samples
                database.
            best_dataset_samples: Initial database of good points. Row 0 is
                reported as the current best point.
            gradient_penalty_factor: Weight forwarded to
                ``wgan_gradient_penalty_loss``.
            explore_discriminator_opt: Optimizer for the explore critic. If
                omitted, the project's default discriminator optimizer is used.
            exploit_discriminator_opt: Optimizer for the exploit critic. If
                omitted, the project's default discriminator optimizer is used.
            generator_opt: Optimizer for the generator. If omitted, the
                project's default generator optimizer is used.
            discriminator_iter: Critic updates per generator update.
            generator_iter: Generator updates per pre-training loop. The
                per-epoch database loader is built with
                ``generator_iter * discriminator_iter``.
            ee_factor: Explore/exploit trade-off. The generator loss weights
                the exploit critic by ``1 / (1 + ee_factor)`` and the explore
                critic by ``ee_factor / (1 + ee_factor)``.
            input_mapping: Optional mapping (e.g. a trust region) between the
                database space and the env space.
            device: Device forwarded to the sampling helpers.
            logger: Logger used for progress messages.
        """
        super(OptGAN, self).__init__()
        self.env = env
        self.generator = generator
        self.explore_discriminator = explore_discriminator
        self.exploit_discriminator = exploit_discriminator
        self.gradient_penalty_factor = gradient_penalty_factor
        self.explore_discriminator_opt = (
            explore_discriminator_opt
            or default_discriminator_optimizer(explore_discriminator)
        )
        self.exploit_discriminator_opt = (
            exploit_discriminator_opt
            or default_discriminator_optimizer(exploit_discriminator)
        )
        self.generator_opt = generator_opt or default_generator_optimizer(generator)
        self.best_dataset_samples = best_dataset_samples
        self.discriminator_iter = discriminator_iter
        self.generator_iter = generator_iter
        self.ee_factor = ee_factor
        self.input_mapping = input_mapping
        self.device = device
        self.logger = logger

    def set_start_point(self, start_point: Tensor):
        """Append ``start_point`` to the database.

        The point is mapped through ``input_mapping`` first, when one is set.
        """
        if self.input_mapping:
            start_point = self.input_mapping.map(start_point)
        self.best_dataset_samples = torch.cat(
            (self.best_dataset_samples, start_point.unsqueeze(0))
        )

    @property
    def best_point_until_now(self):
        """Best point found so far, i.e. row 0 of the database."""
        return self.curr_point_to_draw

    @property
    def curr_point_to_draw(self):
        """Point to visualize: row 0 of the best-samples database."""
        return self.best_dataset_samples[0]

    @property
    def environment(self):
        """The search space this optimizer runs on."""
        return self.env

    def forward(self, batch: Tensor):
        """Generate points in the env space from a noise batch."""
        return self.query_generator_in_space(batch)

    def query_generator_in_space(self, batch: Tensor):
        """Run the generator on ``batch`` and denormalize the result with ``env``."""
        return self.env.denormalize(self.generator(batch))

    def _database_to_env_space(self, samples: Tensor) -> Tensor:
        """Convert database rows to the space in which ``env`` is evaluated.

        With an input mapping, rows go through ``input_mapping.inverse`` and
        then ``env.denormalize``. Without one, they are returned as-is. The
        result is always detached.
        """
        if self.input_mapping:
            return self.env.denormalize(self.input_mapping.inverse(samples.detach()))
        return samples.detach()

    def train(
        self,
        epochs: int = 2000,
        m: int = 30,
        a: float = 1.5,
        k_0: int = 150,
        max_fes: int = 5_000,
        batch_size: int = DEFAULT_BATCH_SIZE,
        count_best_data_multiplier: int = 1,
        iteration_with_no_improvement_before_squeeze: int = 10,
        stopping_conditions: List[AlgorithmStopCondition] = None,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
        *args,
        **kwargs,
    ):
        """Run pre-training, then distribution reshaping.

        This method shadows ``torch.nn.Module.train(mode)``. ``Module.eval()``
        calls ``self.train(False)``, so a boolean ``epochs`` is treated as a
        mode switch. It is delegated to ``Module.train`` instead of starting
        an optimization run.

        Extra ``*args`` / ``**kwargs`` are forwarded to
        :meth:`pre_iter_training`; the remaining parameters go to
        :meth:`distribution_reshaping`.

        ``NoMoreBudgetError`` and ``AlgorithmFinish`` end the run without
        propagating. Every handler receives ``on_algorithm_end`` in all cases.
        """
        if isinstance(epochs, bool):
            # Called as nn.Module.train(mode), e.g. via .eval(): switch modes only.
            return Module.train(self, epochs)

        # Normalize before `try` so the `finally` clause never iterates None.
        callback_handlers = callback_handlers or []
        # Work on a copy so repeated calls don't keep appending the goal
        # condition to the caller's list.
        stopping_conditions = list(stopping_conditions or [])
        stopping_conditions.append(GOAL_IS_REACHED_STOPPING_CONDITION)
        try:
            self.logger.info(f"Started running {self.__class__.__name__}")

            for handler in callback_handlers:
                handler.on_algorithm_start(self, database=self.best_dataset_samples)

            self.pre_iter_training(
                *args,
                batch_size=batch_size,
                callback_handlers=callback_handlers,
                **kwargs,
            )
            self.distribution_reshaping(
                epochs,
                m,
                a,
                k_0,
                max_fes,
                batch_size,
                count_best_data_multiplier,
                iteration_with_no_improvement_before_squeeze,
                stopping_conditions,
                callback_handlers,
            )
            self.logger.info(
                f"Finish running, best value is {self.best_point_until_now}"
            )
        except NoMoreBudgetError as e:
            self.logger.warning(
                f"You exceeded the budget allowed for running opt gan on {self.env}",
                exc_info=e,
            )
        except AlgorithmFinish as e:
            self.logger.info(f"OPT GAN Finish stopped {e}")
        finally:
            # Finish with handlers
            for handler in callback_handlers:
                handler.on_algorithm_end(self, database=self.best_dataset_samples)

    def distribution_reshaping(
        self,
        epochs: int,
        m: int,
        a: float,
        k_0: int,
        max_fes: int,
        batch_size: int,
        count_best_data_multiplier: int,
        iteration_with_no_improvement_before_squeeze: int,
        stopping_conditions: List[AlgorithmStopCondition] = None,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
    ):
        """Main OptGAN loop: train both critics and the generator, then refresh
        and shrink the best-samples database once per epoch.

        Args:
            epochs: Number of outer epochs.
            m: Number of fresh generator candidates added per epoch.
            a: Shrink-schedule rate. At epoch ``t`` the database keeps
                ``int(k_0 ** (1 - a * t / max_fes))`` points, clamped to
                ``[1, number of candidates]``.
            k_0: Shrink-schedule base (see ``a``).
            max_fes: Shrink-schedule horizon (see ``a``).
            batch_size: Mini-batch size for noise, uniform samples and the
                database loader.
            count_best_data_multiplier: Forwarded to
                ``random_sampler_loader_from_tensor``.
            iteration_with_no_improvement_before_squeeze: Number of
                non-improving epochs after which the input mapping (if any) is
                squeezed around the best point.
            stopping_conditions: Checked at the end of every epoch.
            callback_handlers: Receive ``on_epoch_end`` after every epoch.

        Raises:
            AlgorithmFinish: When a stopping condition is met.
        """
        stopping_conditions = stopping_conditions or []
        callback_handlers = callback_handlers or []
        self.logger.info(f"Started distribution reshaping for {epochs} epochs")
        dims = len(self.env.lower_bound)
        no_improvement_counter = 0
        # Track the best point in the same space used for ranking below, so
        # the first epoch's comparison is not between points in different spaces.
        best_point: Tensor = self._database_to_env_space(
            self.best_dataset_samples[:1]
        )[0].clone()
        last_disc_loss = None
        last_gan_loss = None
        optimum_values_loader = random_sampler_loader_from_tensor(
            self.best_dataset_samples,
            batch_size,
            self.generator_iter * self.discriminator_iter,
            count_best_data_multiplier,
        )
        for epoch in trange(epochs, position=0, leave=False):
            self.logger.info(
                f"Starting {epoch} / {epochs} with {self.generator_iter * self.discriminator_iter} "
                f"iterations on {self.env} with {self.input_mapping}"
            )
            for i, batch in tqdm(
                enumerate(optimum_values_loader), position=1, leave=False
            ):
                # --- Critic step: both discriminators, generator frozen.
                self.explore_discriminator.train()
                self.exploit_discriminator.train()
                self.generator.eval()

                self.explore_discriminator.zero_grad()
                self.exploit_discriminator.zero_grad()
                self.generator.zero_grad()
                noise = sample_input_to_generator(
                    batch_size, dims, device=self.device
                )
                explore_generator_output = self.query_generator_in_space(noise)
                exploit_generator_output = self.query_generator_in_space(noise)
                x_uniform = (
                    self.input_mapping.sample_from_unbounded(
                        batch_size, device=self.device
                    )
                    if self.input_mapping
                    else self.env.sample_from_space(batch_size, device=self.device)
                )
                x_sampling = batch[0]
                explore_disc_loss = wgan_gradient_penalty_loss(
                    self.explore_discriminator,
                    explore_generator_output,
                    x_uniform,
                    self.gradient_penalty_factor,
                    batch_size,
                )
                exploit_disc_loss = wgan_gradient_penalty_loss(
                    self.exploit_discriminator,
                    exploit_generator_output,
                    x_sampling,
                    self.gradient_penalty_factor,
                    batch_size,
                )
                total_disc_loss = explore_disc_loss + exploit_disc_loss
                last_disc_loss = total_disc_loss.detach().clone()
                total_disc_loss.backward()
                self.explore_discriminator_opt.step()
                self.exploit_discriminator_opt.step()

                # --- Generator step. With this condition the generator first
                # steps after discriminator_iter + 1 critic steps, then once
                # every discriminator_iter critic steps.
                if i % self.discriminator_iter == 0 and i != 0:
                    self.exploit_discriminator.eval()
                    self.explore_discriminator.eval()
                    self.generator.train()

                    # Clears the generator grads accumulated by the critic backward.
                    self.generator.zero_grad()
                    self.exploit_discriminator.zero_grad()
                    self.explore_discriminator.zero_grad()
                    noise = sample_input_to_generator(
                        batch_size, dims, device=self.device
                    )
                    generated_output = self.query_generator_in_space(noise)

                    generator_loss = -(
                        (1 / (1 + self.ee_factor))
                        * self.exploit_discriminator(generated_output)
                        + (self.ee_factor / (1 + self.ee_factor))
                        * self.explore_discriminator(generated_output)
                    ).mean()
                    last_gan_loss = generator_loss.detach().clone()
                    generator_loss.backward()
                    self.generator_opt.step()

                    self.generator.eval()
            self.logger.info(
                f"Trained the generator and discriminator, "
                f"gan loss: {last_gan_loss}, discriminator loss: {last_disc_loss}"
            )

            # Fresh candidates are only ranked, never back-propagated through.
            with torch.no_grad():
                noise = sample_input_to_generator(m, dims, device=self.device)
                x_g = self.query_generator_in_space(noise)
            new_x_opt = torch.cat((x_g, self.best_dataset_samples))

            # shrink
            scheduled_size = int(k_0 ** (1 - (a * epoch / max_fes)))
            # Keep at least one point: once a * epoch > max_fes the schedule
            # drops below 1, and an empty database would break best_indices[0].
            new_size = max(1, min(scheduled_size, new_x_opt.shape[0]))
            self.logger.info(
                f"Shrinking database with k0: {k_0}, max fes: {max_fes}, new size {new_size}"
            )

            real_new_x_opt = self._database_to_env_space(new_x_opt)
            best_indices = self.env.best_k_indices(real_new_x_opt, new_size)
            # A detached copy keeps the candidates' dtype and device. Rebuilding
            # through torch.tensor(list) would fall back to torch's default dtype.
            self.best_dataset_samples = new_x_opt[best_indices].detach().clone()

            new_best_point = real_new_x_opt[best_indices[0]].clone()

            # NOTE - We should not use debug mode but until we go public I use this hack
            best_value = self.env(best_point, debug_mode=True)
            new_best_value = self.env(new_best_point, debug_mode=True)
            self.logger.info(
                f"New point found {new_best_point} with value {new_best_value} "
                f"compare to old - {best_value} on {self.env}"
            )
            if best_value <= new_best_value:
                self.logger.info(
                    f"No new point found, counter {no_improvement_counter}"
                )
                no_improvement_counter += 1
            else:
                self.logger.info(
                    f"new point was found, distance: {((best_point - new_best_point) ** 2).sum().sqrt()}"
                )
                best_point = new_best_point
                # NOTE(review): the counter is not reset here, so it counts
                # all non-improving epochs since the last squeeze, not
                # consecutive ones. Confirm this is intended.

            if no_improvement_counter >= iteration_with_no_improvement_before_squeeze and self.input_mapping:
                self.logger.info(f"Squeeze trust region around {best_point}")
                no_improvement_counter = 0
                # Move the database out of the old mapping, squeeze it around
                # the best point, then map the database back in.
                real_best_values = self.input_mapping.inverse(self.best_dataset_samples)
                unmapped_best_point = self.env.normalize(best_point)
                self.input_mapping.squeeze(unmapped_best_point)
                unreal_best_values = self.input_mapping.map(real_best_values)
                self.best_dataset_samples = unreal_best_values

            optimum_values_loader = random_sampler_loader_from_tensor(
                self.best_dataset_samples,
                batch_size,
                self.generator_iter * self.discriminator_iter,
                count_best_data_multiplier,
            )

            # Send to handler
            for handler in callback_handlers:
                handler.on_epoch_end(self, database=self.best_dataset_samples)

            for stop_condition in stopping_conditions:
                if stop_condition.should_stop(self, counter=no_improvement_counter):
                    raise AlgorithmFinish(
                        stop_condition.REASON.format(
                            alg="OPT GAN",
                            env=self.env,
                            best_point=self.best_point_until_now,
                            tr=self.input_mapping,
                        )
                    )

    def pre_iter_training(
        self,
        pre_iter_loops: int = 20,
        batch_size: int = DEFAULT_BATCH_SIZE,
        callback_handlers: List[AlgorithmCallbackHandler] = None,
    ):
        """Warm up the generator against the explore critic only.

        Runs ``pre_iter_loops`` loops. Each loop performs ``generator_iter``
        generator updates, and each update is preceded by
        ``discriminator_iter`` explore-critic updates (generator samples vs.
        uniform space samples). The exploit critic is not touched here.
        Handlers receive ``on_epoch_end`` after every loop.
        """
        callback_handlers = callback_handlers or []
        self.logger.info(f"Started pre training for {pre_iter_loops} loops")
        dims = len(self.env.lower_bound)
        last_explore_loss = None
        last_gan_loss = None
        for loop in trange(pre_iter_loops, leave=False, position=0, desc="Loops"):
            for _ in trange(
                self.generator_iter, leave=False, position=1, desc="Generator iteration"
            ):
                self.explore_discriminator.train()
                self.generator.eval()
                for _ in trange(
                    self.discriminator_iter,
                    leave=False,
                    position=2,
                    desc="Discriminator iteration",
                ):
                    self.explore_discriminator_opt.zero_grad()
                    gen_random_input = sample_input_to_generator(
                        batch_size, dims, device=self.device
                    )
                    x_g = self.query_generator_in_space(gen_random_input)
                    x_uniform = self.env.sample_from_space(
                        batch_size, device=self.device
                    )
                    loss = wgan_gradient_penalty_loss(
                        self.explore_discriminator,
                        x_g,
                        x_uniform,
                        self.gradient_penalty_factor,
                        batch_size,
                    )
                    loss.backward()
                    last_explore_loss = loss.detach().clone()
                    self.explore_discriminator_opt.step()
                self.generator.train()
                self.explore_discriminator.eval()
                self.generator_opt.zero_grad()
                gen_random_input = sample_input_to_generator(
                    batch_size, dims, device=self.device
                )
                explore_disc_output = self.explore_discriminator(
                    self.query_generator_in_space(gen_random_input)
                )
                gen_loss = -explore_disc_output.mean()
                gen_loss.backward()
                last_gan_loss = gen_loss.detach().clone()
                self.generator_opt.step()

            self.logger.info(
                f"{loop} - Pre loop with {last_explore_loss} explore loss and gan loss {last_gan_loss}"
            )

            # Send to handler
            for handler in callback_handlers:
                handler.on_epoch_end(self, database=self.best_dataset_samples)

    def sample_from_best_samples(self, n_samples: int = 1):
        """Return up to ``n_samples`` distinct database rows, chosen at random
        without replacement."""
        index_permutations = torch.randperm(len(self.best_dataset_samples))
        samples_idx = index_permutations[:n_samples]
        return self.best_dataset_samples[samples_idx]

    @classmethod
    def object_default_values(cls) -> dict:
        """Factories for constructor arguments that have no plain default.

        Each network factory builds the module for ``dims`` and casts it to
        float64 on ``device``. The database factory samples and normalizes
        candidate optimum points from ``env``. (Presumably consumed by the
        project's config/builder code; verify against the caller.)
        """
        return {
            "explore_discriminator": FuncConfig(
                lambda device, dims, **kwargs: DiscriminatorByInputSize(dims).to(
                    device=device, dtype=torch.float64
                )
            ),
            "exploit_discriminator": FuncConfig(
                lambda device, dims, **kwargs: DiscriminatorByInputSize(dims).to(
                    device=device, dtype=torch.float64
                )
            ),
            "generator": FuncConfig(
                lambda device, dims, **kwargs: GeneratorByDimSize(dims).to(
                    device=device, dtype=torch.float64
                )
            ),
            "best_dataset_samples": FuncConfig(
                lambda device, env, **kwargs: env.normalize(
                    env.sample_for_optimum_points(1000, 200, device=device)
                )
                # lambda device, input_mapping, env, **kwargs: input_mapping.map(
                #     env.normalize(env.sample_for_optimum_points(1000, 200, device=device))
                # )
            ),
        }

    @classmethod
    def _default_types(cls) -> Dict[str, type]:
        """Default concrete types for abstractly-typed constructor arguments."""
        return {"input_mapping": TanhTrustRegion}
