from enum import Enum
from functools import cached_property
from typing import List

import torch
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from torch import Tensor
from torch.nn import Module

from algorithms.convergence_algorithms.opt_gan import OptGAN
from algorithms.convergence_algorithms.utils import sample_input_to_generator
from algorithms.space.base_space import EvaluatedSpace
from handlers.drawers.base_drawer import StaticPLTDrawer
from handlers.drawers.utils import pad_2d_to_nd_graph


class Discriminators(Enum):
    """Selector for which OptGAN discriminator a contour drawer visualises."""

    # Plotted from ``alg.explore_discriminator``.
    EXPLORATION = "exploration"
    # Plotted from ``alg.exploit_discriminator``.
    EXPLOITATION = "exploitation"


class DiscriminatorContourDrawer(StaticPLTDrawer):
    """Draw a filled contour map of one OptGAN discriminator over a 2-D slice.

    The slice is the meshgrid of ``ranges[dims[0]]`` x ``ranges[dims[1]]``.
    It is expanded to ``dim_size`` coordinates with ``pad_2d_to_nd_graph``
    before being fed to the discriminator.
    """

    def __init__(
        self,
        space: EvaluatedSpace,
        ranges: List[Tensor],
        discriminator: Discriminators,
        dims: list = None,
    ):
        """
        Args:
            space: Evaluated search space; only used by ``func_grid_value``.
            ranges: One coordinate tensor per dimension of the space; the two
                selected by ``dims`` are passed to ``torch.meshgrid``.
            discriminator: Which discriminator of the algorithm to plot.
            dims: The two dimension indices spanning the plotted plane.
                Defaults to ``[0, 1]``.
        """
        super(DiscriminatorContourDrawer, self).__init__()
        self.dims = dims or [0, 1]
        self.dim_size = len(ranges)
        x_grid, y_grid = torch.meshgrid(ranges[self.dims[0]], ranges[self.dims[1]])
        self.x_grid = x_grid
        self.y_grid = y_grid
        self.space = space
        self.discriminator = discriminator

    def draw(self, alg: OptGAN, *args, **kwargs):
        """Render the selected discriminator of ``alg``.

        Returns:
            A one-element list containing a ``(figure, title)`` pair.
        """
        disc = (
            alg.exploit_discriminator
            if self.discriminator == Discriminators.EXPLOITATION
            else alg.explore_discriminator
        )
        return [
            (
                self.create_counter_map_for_disc(alg, disc),
                f"{self.discriminator.value} discriminator",
            )
        ]

    @cached_property
    def grid_points(self):
        """The 2-D grid lifted to ``dim_size`` dimensions (computed once, cached).

        NOTE(review): assumed to be one row per grid cell, i.e. shape
        (x_grid.numel(), dim_size) -- confirm against ``pad_2d_to_nd_graph``.
        """
        return pad_2d_to_nd_graph(self.x_grid, self.y_grid, tuple(self.dims), self.dim_size)

    @cached_property
    def func_grid_value(self):
        """``self.space`` evaluated at ``grid_points`` (cached; unused by ``draw``)."""
        return self.space(self.grid_points)

    def scatter_data_on_fig(self, ax: Axes, opt_gan: OptGAN):
        """Overlay sample points and the best point so far onto ``ax``.

        Draws 20 points obtained by calling ``opt_gan`` on fresh generator input
        (blue crosses) and ``opt_gan.best_point_until_now`` (yellow circle).
        Both are projected onto ``self.dims``.
        """
        generator_input = sample_input_to_generator(
            20, len(opt_gan.env.lower_bound), device=opt_gan.device
        )
        # Detach before the device copy so no autograd node is recorded for it.
        gan_result = opt_gan(generator_input).detach().cpu()
        ax.scatter(
            gan_result[:, self.dims[0]], gan_result[:, self.dims[1]], marker="x", color="b"
        )
        ax.scatter(
            opt_gan.best_point_until_now[self.dims[0]].cpu(),
            opt_gan.best_point_until_now[self.dims[1]].cpu(),
            marker="o",
            color="yellow",
        )

    def create_counter_map_for_disc(self, opt_gan: OptGAN, discriminator: Module):
        """Build a figure with a grey contour map of ``discriminator`` over the grid.

        The discriminator is evaluated in eval mode. Its previous
        train/eval mode is restored afterwards, even if the forward pass raises.

        Returns:
            The matplotlib figure.
        """
        fig = plt.figure(figsize=[12.8, 9.6])
        ax = fig.add_axes([0, 0, 1, 1])
        points = self.grid_points.to(device=opt_gan.device)
        # Remember the caller's mode instead of unconditionally switching to
        # train mode afterwards (which would wrongly flip an eval-mode module).
        was_training = discriminator.training
        discriminator.eval()
        try:
            # Pure visualisation: no autograd graph needed for the whole grid.
            with torch.no_grad():
                disc_values = discriminator(points).detach().cpu()
        finally:
            discriminator.train(was_training)
        ax.contourf(
            self.x_grid,
            self.y_grid,
            disc_values.reshape(self.x_grid.shape),
            cmap="Greys",
        )
        self.scatter_data_on_fig(ax, opt_gan)
        return fig


class GanDistributionDrawer(StaticPLTDrawer):
    """Plot a 2-D histogram of points produced by an OptGAN generator.

    The histogram is taken over the two coordinates selected by ``dims``.
    """

    def __init__(
        self,
        ranges: List[Tensor],
        dims: list = None,
        sample_size: int = 10_000,
        *args,
        **kwargs,
    ):
        """
        Args:
            ranges: One coordinate tensor per dimension of the space; the two
                selected by ``dims`` are passed to ``torch.meshgrid``.
            dims: The two dimension indices to histogram. Defaults to ``[0, 1]``.
            sample_size: Number of generator inputs sampled per ``draw`` call.
            *args, **kwargs: Forwarded to ``StaticPLTDrawer``.
        """
        super(GanDistributionDrawer, self).__init__(*args, **kwargs)
        self.dims = dims or [0, 1]
        self.dim_size = len(ranges)
        # The grid is kept as public attributes for compatibility; ``draw``
        # does not need it.
        x_grid, y_grid = torch.meshgrid(ranges[self.dims[0]], ranges[self.dims[1]])
        self.x_grid = x_grid
        self.y_grid = y_grid
        self.sample_size = sample_size

    def draw(self, alg: OptGAN, *args, **kwargs):
        """Sample the generator of ``alg`` and histogram the result.

        The generator is queried in eval mode. Its previous train/eval mode is
        restored afterwards, even if querying raises.

        Returns:
            A one-element list containing a ``(figure, title)`` pair.
        """
        fig = plt.figure(figsize=[12.8, 9.6])
        ax = fig.add_axes([0, 0, 1, 1])
        gan_input = sample_input_to_generator(
            self.sample_size, len(alg.env.lower_bound), device=alg.device
        )
        generator = alg.generator
        # Previously the generator was left in eval mode after drawing, which
        # silently changed its behaviour for subsequent training steps.
        was_training = generator.training
        generator.eval()
        try:
            # Pure visualisation: skip building an autograd graph.
            with torch.no_grad():
                gan_points = alg.query_generator_in_space(gan_input).detach().cpu()
        finally:
            generator.train(was_training)

        # Honour ``dims`` like DiscriminatorContourDrawer does.
        # NOTE(review): assumes gan_points has one column per space dimension.
        ax.hist2d(
            gan_points[:, self.dims[0]].numpy(), gan_points[:, self.dims[1]].numpy()
        )
        return [(fig, "Gan distribution hist")]
