from abc import ABC
from typing import Tuple, List, Union, Optional

import torch
from matplotlib.collections import PathCollection
from matplotlib.figure import Figure
from torch import Tensor

from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from handlers.drawers.base_drawer import PLTDrawer, StaticPLTDrawer
from handlers.drawers.drawable_algorithms import ConvergenceDrawable
from handlers.drawers.utils import create_grid_points


class ConvergenceDrawer(PLTDrawer, ABC):
    """Draw an algorithm's point on top of a filled contour of a function grid.

    Subclasses provide the grid through ``create_grid``. This class draws the
    contour and marks ``new_point(alg)`` along the two coordinates in ``dims``:
    blue on ``start_drawing``, red on ``update_data`` (unless the whole figure is
    redrawn, see ``update_data``) and green on ``draw_data`` / ``end_drawing``.
    """

    def __init__(
        self, dims: Optional[List[int]] = None, dim_size: int = 2, map_output: bool = True, **kwargs
    ):
        """
        :param dims: indices of the two coordinates shown on the x / y axes (default ``[0, 1]``).
        :param dim_size: dimensionality of the points; stored for use by subclasses' grids.
        :param map_output: if True, ``alg.output_mapping`` (when present) is applied to the grid
            values, and ``update_data`` redraws the whole grid instead of only adding a point.
        :param kwargs: forwarded to ``PLTDrawer``.
        """
        super(ConvergenceDrawer, self).__init__(**kwargs)
        self.dims = dims or [0, 1]
        self.dim_size = dim_size
        self.map_output = map_output

    def _plot_point(self, alg, color: str) -> None:
        """Mark ``new_point(alg)`` on ``self.ax`` with an 'x' marker of the given color."""
        # detach() so tensors that track gradients can be converted to numpy by matplotlib.
        params = self.new_point(alg).detach().cpu()
        self.ax.plot(
            [params[self.dims[0]]], [params[self.dims[1]]], linestyle=":", marker="x", color=color
        )

    def start_drawing(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Draw the grid contour (optionally output-mapped) and the current point in blue."""
        x_grid, y_grid, grid_values = self.create_grid(alg)
        if self.map_output and getattr(alg, "output_mapping", None) is not None:
            grid_values = alg.output_mapping.map(grid_values.to(device=alg.device))
        self.ax.contourf(
            x_grid.detach().cpu(),
            y_grid.detach().cpu(),
            grid_values.reshape(x_grid.shape).detach().cpu(),
            cmap="Greys",
        )
        self._plot_point(alg, "b")
        return [(self.fig, "")]

    def update_data(self, egl: ConvergenceDrawable, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Add the current point in red, or redraw everything when ``map_output`` is set.

        With ``map_output`` the grid values may change between calls, so the contour is
        recomputed through ``start_drawing`` (which marks the point in blue).
        """
        if self.map_output:
            return self.start_drawing(egl, *args, **kwargs)
        self._plot_point(egl, "r")
        return [(self.fig, "")]

    def draw_data(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Mark the current point in green."""
        self._plot_point(alg, "g")
        return [(self.fig, "")]

    def end_drawing(self, alg, *args, **kwargs) -> List[Tuple[Figure, str]]:
        """Finish drawing by marking the final point (same as ``draw_data``)."""
        return self.draw_data(alg, *args, **kwargs)

    def new_point(self, alg: ConvergenceDrawable, *args, **kwargs) -> Tensor:
        """Return the point to mark; subclasses may override."""
        return alg.curr_point_to_draw

    def create_grid(self, alg) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(x_grid, y_grid, grid_values)``; must be implemented by subclasses."""
        raise NotImplementedError()


class FullFigConvergenceDrawer(ConvergenceDrawer):
    """Convergence drawer whose background is ``egl.environment`` evaluated on a 2-D grid.

    The grid spans ``[x_lower_bounds, x_upper_bounds] x [y_lower_bounds, y_upper_bounds]``
    with ``grid_size`` samples per axis.
    """

    def __init__(
        self,
        *args,
        x_lower_bounds: float = -5,
        x_upper_bounds: float = 5,
        y_lower_bounds: float = -5,
        y_upper_bounds: float = 5,
        grid_size: int = 1000,
        **kwargs
    ):
        """
        :param x_lower_bounds: / x_upper_bounds / y_lower_bounds / y_upper_bounds: grid extent.
        :param grid_size: number of grid samples per axis (total ``grid_size ** 2`` points).
        :param args: / kwargs: forwarded to ``ConvergenceDrawer``.
        """
        super(FullFigConvergenceDrawer, self).__init__(*args, **kwargs)
        self.x_lower_bounds = x_lower_bounds
        self.x_upper_bounds = x_upper_bounds
        self.y_lower_bounds = y_lower_bounds
        self.y_upper_bounds = y_upper_bounds
        self.grid_size = grid_size

    def create_grid(self, egl: ConvergenceDrawable) -> Tuple[Tensor, Tensor, Tensor]:
        """Evaluate the environment on the grid.

        :return: ``(x_grid, y_grid, grid_values)``; the coordinate grids are reshaped to
            ``(grid_size, grid_size)`` and ``grid_values`` is detached.
        """
        points = create_grid_points(
            self.x_lower_bounds,
            self.x_upper_bounds,
            self.y_lower_bounds,
            self.y_upper_bounds,
            self.grid_size,
            self.dim_size,
            self.dims,
            egl.device,
        ).detach()
        grid_values = egl.environment(points, debug_mode=True)
        shape = (self.grid_size, self.grid_size)
        # The grid varies along the coordinates in self.dims (which is also how the drawn
        # point is indexed), so read the axes from those columns rather than from 0 / 1.
        return (
            points[:, self.dims[0]].reshape(shape),
            points[:, self.dims[1]].reshape(shape),
            grid_values.detach(),
        )


class StudiedFunctionConvergenceDrawer(ConvergenceDrawer):
    """Convergence drawer whose background is ``alg.helper_network`` evaluated on a 2-D grid.

    The grid spans ``[x_lower_bounds, x_upper_bounds] x [y_lower_bounds, y_upper_bounds]``
    with ``grid_size`` samples per axis.
    """

    def __init__(
        self,
        *args,
        x_lower_bounds: float = -5,
        x_upper_bounds: float = 5,
        y_lower_bounds: float = -5,
        y_upper_bounds: float = 5,
        grid_size: int = 1000,
        **kwargs
    ):
        """
        :param x_lower_bounds: / x_upper_bounds / y_lower_bounds / y_upper_bounds: grid extent.
        :param grid_size: number of grid samples per axis (total ``grid_size ** 2`` points).
        :param args: / kwargs: forwarded to ``ConvergenceDrawer``.
        """
        super(StudiedFunctionConvergenceDrawer, self).__init__(*args, **kwargs)
        self.x_lower_bounds = x_lower_bounds
        self.x_upper_bounds = x_upper_bounds
        self.y_lower_bounds = y_lower_bounds
        self.y_upper_bounds = y_upper_bounds
        self.grid_size = grid_size

    def create_grid(self, alg: ConvergenceAlgorithm) -> Tuple[Tensor, Tensor, Tensor]:
        """Evaluate the helper network on the grid.

        :return: ``(x_grid, y_grid, grid_values)``, all detached and on CPU; the coordinate
            grids are reshaped to ``(grid_size, grid_size)``.
        """
        points = create_grid_points(
            self.x_lower_bounds,
            self.x_upper_bounds,
            self.y_lower_bounds,
            self.y_upper_bounds,
            self.grid_size,
            self.dim_size,
            self.dims,
            device=alg.device,
        ).detach()
        # The values are only plotted, so skip building an autograd graph over
        # grid_size ** 2 points.
        with torch.no_grad():
            grid_values = alg.helper_network(points)
        shape = (self.grid_size, self.grid_size)
        # The grid varies along the coordinates in self.dims (which is also how the drawn
        # point is indexed), so read the axes from those columns rather than from 0 / 1.
        return (
            points[:, self.dims[0]].reshape(shape).cpu(),
            points[:, self.dims[1]].reshape(shape).cpu(),
            grid_values.cpu().detach(),
        )


class GradNetVectorFieldEpochDrawer(StaticPLTDrawer):
    """Draw the unit-normalized vector field predicted by ``alg.grad_network`` on a square grid."""

    def __init__(
        self,
        lower_bounds: int = -5,
        upper_bounds: int = 5,
        dims: List[int] = None,
        grid_size: int = 20,
    ):
        """
        :param lower_bounds: / upper_bounds: extent of the grid on both axes.
        :param dims: stored coordinate indices (default ``[0, 1]``).
            NOTE(review): ``draw`` builds 2-D points and does not use ``dims``.
        :param grid_size: number of arrows per axis.
        """
        super(GradNetVectorFieldEpochDrawer, self).__init__()
        self.dims = dims or [0, 1]
        self.lower_bounds = lower_bounds
        self.upper_bounds = upper_bounds
        self.grid_size = grid_size

    def _field_axes(self):
        """Return a full-figure axes for the field, reusing (and clearing) it across draws."""
        ax = getattr(self, "_vector_field_ax", None)
        # Re-create only if never created or the figure no longer holds it; calling
        # add_axes on every draw would stack a new axes per epoch.
        if ax is None or ax not in self.fig.axes:
            ax = self.fig.add_axes([0, 0, 1, 1])
            self._vector_field_ax = ax
        else:
            ax.cla()
        return ax

    def draw(self, alg, *args, **kwargs):
        """Quiver-plot ``alg.grad_network`` on the grid, each arrow scaled to unit length."""
        n = self.grid_size
        # Build the grid where the network lives (same convention as the other drawers).
        device = getattr(alg, "device", None)
        axis = torch.linspace(self.lower_bounds, self.upper_bounds, n, device=device)
        x_grid, y_grid = torch.meshgrid(axis, axis, indexing="ij")
        points = torch.stack((x_grid.reshape(-1), y_grid.reshape(-1)), dim=1)
        self.ax = self._field_axes()
        vectors = alg.grad_network(points).detach()
        # clamp so zero vectors stay zero instead of becoming NaN.
        vectors = vectors / vectors.norm(dim=1, keepdim=True).clamp_min(1e-12)
        self.ax.quiver(
            x_grid.cpu(),
            y_grid.cpu(),
            vectors[:, 0].reshape(n, n).cpu(),
            vectors[:, 1].reshape(n, n).cpu(),
        )
        return [(self.fig, "")]


class BestModelDrawer(StaticPLTDrawer):
    """Highlight the best point found so far as a small yellow dot."""

    def __init__(self, dims: List[int] = None, **kwargs):
        """
        :param dims: indices of the two coordinates shown on the x / y axes (default ``[0, 1]``).
        :param kwargs: forwarded to ``StaticPLTDrawer``.
        """
        super().__init__(**kwargs)
        self.dims = [0, 1] if not dims else dims

    def best_point(self, alg: ConvergenceDrawable, *args, **kwargs):
        """Return the point to highlight; subclasses may override."""
        return alg.best_point_until_now

    def draw(self, alg: ConvergenceDrawable, *args, **kwargs):
        """Scatter the best point on ``self.ax`` and return the figure."""
        best = self.best_point(alg).cpu().detach()
        x_dim, y_dim = self.dims[0], self.dims[1]
        marker_style = dict(linestyle=":", marker="o", color="y", s=4)
        self.ax.scatter([best[x_dim]], [best[y_dim]], **marker_style)
        return [(self.fig, "")]


class SamplesDrawer(StaticPLTDrawer):
    """Scatter the sample database, replacing the previously drawn samples on each call."""

    def __init__(self, *args, dims: List[int] = None, color: str = "g", **kwargs):
        """
        :param dims: indices of the two columns shown on the x / y axes (default ``(0, 1)``).
        :param color: marker color of the samples.
        :param args: / kwargs: forwarded to ``StaticPLTDrawer``.
        """
        super().__init__(*args, **kwargs)
        self.color = color
        self.dims = dims or (0, 1)
        # Artist from the last scatter, kept so it can be removed before redrawing.
        self.database_points: Optional[PathCollection] = None

    def draw(self, alg, *args, database: Tensor = None, **kwargs):
        """Redraw ``database`` rows as points; leave the figure untouched when it is None."""
        if database is None:
            return [(self.fig, "")]
        if self.database_points is not None:
            self.database_points.remove()
        samples = database.cpu().detach()
        x_col, y_col = self.dims[0], self.dims[1]
        self.database_points = self.ax.scatter(
            samples[:, x_col],
            samples[:, y_col],
            marker="o",
            color=self.color,
            s=1,
        )
        return [(self.fig, "")]


class BestModelsDrawer(BestModelDrawer):
    """Mark the current parameters of every model in ``alg.models_to_train`` as yellow dots."""

    def all_best_point(self, alg, idx: int, *args, **kwargs):
        """Return the parameter tensor of the ``idx``-th model being trained."""
        return alg.models_to_train[idx].model_parameter_tensor()

    def draw(self, alg, *args, **kwargs):
        """Scatter every model's parameters (along ``self.dims``) and return the figure."""
        points = [
            self.all_best_point(alg, idx).cpu().detach()
            for idx in range(len(alg.models_to_train))
        ]
        if not points:
            return [(self.fig, "")]
        # A single scatter call looks the same as one call per model but adds only one
        # artist per draw instead of one per model.
        self.ax.scatter(
            [float(point[self.dims[0]]) for point in points],
            [float(point[self.dims[1]]) for point in points],
            linestyle=":",
            marker="o",
            color="y",
            s=4,
        )
        return [(self.fig, "")]
