from torch import Tensor, no_grad
from torch.nn import Module

from algorithms.convergence_algorithms.opt_gan import OptGAN
from algorithms.convergence_algorithms.utils import sample_input_to_generator
from handlers.drawers.base_drawer import StaticDrawer


class WassersteinDistanceDrawer(StaticDrawer):
    """Plot a critic-score gap between generated and reference samples.

    One value is produced per OptGAN discriminator:

    * the explore discriminator is compared against samples drawn from the
      environment's search space (``alg.env.sample_from_space``);
    * the exploit discriminator is compared against samples drawn from the
      algorithm's best samples (``alg.sample_from_best_samples``).
    """

    def __init__(self, sample_size: int = 300, *args, **kwargs):
        """
        Args:
            sample_size: Number of reference samples requested for each
                discriminator on every ``draw`` call.
            *args, **kwargs: Forwarded unchanged to ``StaticDrawer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.sample_size = sample_size

    def draw(self, alg: OptGAN, *args, **kwargs):
        """Return ``[(exploit_distance, label), (explore_distance, label)]`` for ``alg``.

        Extra positional/keyword arguments are accepted for interface
        compatibility and ignored.
        """
        # The explore value is computed first so random draws keep the same order.
        explore_distance = self.calculate_discriminator_distance(
            alg,
            alg.explore_discriminator,
            alg.env.sample_from_space(self.sample_size).to(device=alg.device),
        )
        exploit_distance = self.calculate_discriminator_distance(
            alg, alg.exploit_discriminator, alg.sample_from_best_samples(self.sample_size)
        )
        return [
            (exploit_distance, "Exploit Wasserstein Distance"),
            (explore_distance, "Explore Wasserstein Distance"),
        ]

    def calculate_discriminator_distance(
        self, alg: OptGAN, discriminator: Module, true_distribution: Tensor
    ) -> float:
        """Return ``mean(discriminator(G(z)) - discriminator(true_distribution))``.

        A batch of generator noise with the same length as ``true_distribution``
        is drawn, pushed through ``alg.generator`` and scored by
        ``discriminator``. The elementwise score difference (fake minus real)
        is averaged and returned as a Python float.

        NOTE(review): WGAN write-ups usually report ``E[D(real)] - E[D(fake)]``;
        confirm that the fake-minus-real sign used here matches how OptGAN
        trains its critics.

        Args:
            alg: Algorithm providing the generator, device and environment.
            discriminator: Critic to evaluate.
            true_distribution: Reference batch; its length fixes the
                generated batch size.

        Returns:
            The mean score gap as a float.
        """
        # Monitoring only: no gradients are needed, so skip building an autograd graph.
        with no_grad():
            # Second argument is presumably the search-space dimensionality — TODO confirm.
            gan_random_input = sample_input_to_generator(
                len(true_distribution), len(alg.env.lower_bound), device=alg.device
            )
            fake_scores = discriminator(alg.generator(gan_random_input))
            real_scores = discriminator(true_distribution)
            distribution_diff = fake_scores - real_scores
        return distribution_diff.mean().item()


class OptGanModelsLossDrawer(StaticDrawer):
    """Plot discriminator and generator loss-style metrics for an OptGAN.

    Produces five values per ``draw`` call: the explore/exploit discriminator
    score gaps, the ``ee_factor``-weighted generator objective, and the raw
    mean exploit/explore discriminator scores on generated samples.
    """

    def __init__(self, sample_size: int = 256, *args, **kwargs):
        """
        Args:
            sample_size: Batch size used for every metric computed by ``draw``.
            *args, **kwargs: Forwarded unchanged to ``StaticDrawer.__init__``.
        """
        # Initialize the base drawer like the sibling drawers do.
        super().__init__(*args, **kwargs)
        self.sample_size = sample_size

    def draw(self, alg: OptGAN, *args, **kwargs):
        """Return a list of ``(value, label)`` pairs describing the current losses.

        Extra positional/keyword arguments are accepted for interface
        compatibility and ignored.
        """
        exploit_loss = self.calculate_discriminator_loss(
            alg, alg.exploit_discriminator, alg.sample_from_best_samples(self.sample_size)
        )
        explore_loss = self.calculate_discriminator_loss(
            alg,
            alg.explore_discriminator,
            alg.env.sample_from_space(self.sample_size, alg.device),
        )
        gen_loss, gan_exp_loss, gan_explore_loss = self.calculate_gan_loss(alg)
        return [
            (explore_loss, "Explore Loss"),
            (exploit_loss, "Exploit Loss"),
            (gen_loss, "Gan Loss"),
            (gan_exp_loss, "Gan Exploit Loss"),
            (gan_explore_loss, "Gan Explore Loss"),
        ]

    def calculate_gan_loss(self, alg: OptGAN):
        """Score a batch of generated samples with both discriminators in eval mode.

        The generator and both discriminators are switched to eval mode only
        for the duration of this call. Each model's previous train/eval mode
        is restored afterwards, so monitoring does not alter training behavior.

        Returns:
            A tuple ``(weighted, exploit, explore)`` of Python floats:
            ``weighted`` is the mean of
            ``exploit_scores / (1 + ee_factor) + explore_scores * ee_factor / (1 + ee_factor)``,
            and ``exploit`` / ``explore`` are the mean raw scores of each
            discriminator on the same generated batch.
        """
        models = (alg.generator, alg.explore_discriminator, alg.exploit_discriminator)
        # Remember each model's mode so it can be restored below.
        previous_modes = [model.training for model in models]
        for model in models:
            model.eval()
        try:
            with no_grad():
                noise = sample_input_to_generator(
                    self.sample_size, len(alg.env.lower_bound), device=alg.device
                )
                generated_output = alg.query_generator_in_space(noise)
                # Evaluate each discriminator once and reuse the scores below.
                exploit_scores = alg.exploit_discriminator(generated_output)
                explore_scores = alg.explore_discriminator(generated_output)
        finally:
            for model, was_training in zip(models, previous_modes):
                model.train(was_training)

        exploit_weight = 1 / (1 + alg.ee_factor)
        explore_weight = alg.ee_factor / (1 + alg.ee_factor)
        weighted_scores = exploit_weight * exploit_scores + explore_weight * explore_scores
        return (
            weighted_scores.mean().item(),
            exploit_scores.mean().item(),
            explore_scores.mean().item(),
        )

    def calculate_discriminator_loss(
        self, alg: OptGAN, discriminator: Module, real_data: Tensor
    ) -> float:
        """Return ``mean(D(fake)) - mean(D(real_data))`` as a Python float.

        ``fake`` is ``alg.generator`` applied to noise with the same batch size
        as ``real_data``. The discriminator is evaluated in whatever train/eval
        mode it is currently in.

        NOTE(review): this feeds raw ``alg.generator`` output, whereas
        ``calculate_gan_loss`` feeds ``alg.query_generator_in_space`` output.
        Confirm which space the discriminators expect.
        """
        # Monitoring only: no gradients are needed, so skip building an autograd graph.
        with no_grad():
            sampled_data_for_gan = sample_input_to_generator(
                len(real_data), len(alg.env.lower_bound), device=alg.device
            )
            fake_data = alg.generator(sampled_data_for_gan)
            loss = discriminator(fake_data).mean() - discriminator(real_data).mean()
        return loss.item()


class OptGANDatabaseSize(StaticDrawer):
    """Plot how many entries the OptGAN currently holds in ``best_dataset_samples``."""

    def draw(self, alg: OptGAN, *args, **kwargs):
        """Return a one-element list holding the dataset size (as a float) and its label.

        Extra positional/keyword arguments are accepted for interface
        compatibility and ignored.
        """
        database_size = float(len(alg.best_dataset_samples))
        entry = (database_size, "Opt gan database size")
        return [entry]
