from typing import List, Optional

import torch
from matplotlib import pyplot as plt

from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from algorithms.mapping.trust_region import StaticShrinkingTrustRegion
from handlers.drawers.base_drawer import StaticPLTDrawer
from handlers.drawers.utils import pad_2d_to_nd_graph


class BallPerturbDrawer(StaticPLTDrawer):
    """Draw a circle of radius ``alg.epsilon`` around the model's parameters.

    Each ``draw`` call removes the circle added by the previous call, so at most
    one circle from this drawer is on ``self.ax`` at any time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Circle patch added by the most recent ``draw`` call; None before the first.
        self.circle = None

    def draw(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Replace the previous circle with a red, unfilled one at the current point.

        Args:
            alg: algorithm providing ``model_to_train.model_parameter_tensor()``
                (used as the circle centre) and ``epsilon`` (used as the radius).

        NOTE(review): assumes ``model_parameter_tensor()`` is a 1-D tensor holding
        exactly two values (an (x, y) point) -- TODO confirm.
        NOTE(review): this returns None, while ``TrustRegionDrawer.draw`` returns a
        list of (figure, name) pairs -- confirm which contract the base expects.
        """
        if self.circle is not None:
            self.circle.remove()
        # matplotlib documents the centre as an (x, y) pair of floats, so convert
        # the tensor explicitly rather than passing the tensor object through.
        center = tuple(
            alg.model_to_train.model_parameter_tensor().detach().cpu().tolist()
        )
        self.circle = plt.Circle(center, alg.epsilon, color="r", fill=False)
        self.ax.add_patch(self.circle)


class TrustRegionDrawer(StaticPLTDrawer):
    """Draw the environment's bounding box and its centre on two chosen dimensions.

    Drawing only happens when ``alg.input_mapping`` is a
    ``StaticShrinkingTrustRegion``. Each ``draw`` call adds new line artists to
    ``self.ax``; artists from earlier calls are not removed.
    """

    def __init__(
        self, dims: Optional[List[int]] = None, dim_size: int = 2, **kwargs
    ):
        """Configure which dimensions are drawn.

        Args:
            dims: indices of the dimensions used for the x and y coordinates
                (only the first two are read). ``None`` or an empty sequence
                selects ``[0, 1]``.
            dim_size: size argument forwarded to ``pad_2d_to_nd_graph``.
            **kwargs: forwarded to ``StaticPLTDrawer.__init__``.

        Raises:
            ValueError: if ``dims`` has fewer than two entries.
        """
        super().__init__(**kwargs)
        # Copy so later mutation of the caller's list does not affect this drawer.
        self.dims = list(dims) if dims else [0, 1]
        if len(self.dims) < 2:
            raise ValueError(
                f"dims must contain at least two indices, got {self.dims}"
            )
        self.dim_size = dim_size

    def _axis_bounds(self, alg: ConvergenceAlgorithm):
        """Return ``((x_lo, x_hi), (y_lo, y_hi))`` for the two drawn dimensions.

        Values are read from ``alg.environment.lower_bound`` / ``upper_bound``
        at indices ``self.dims[0]`` and ``self.dims[1]``.
        """
        env = alg.environment
        x_dim, y_dim = self.dims[0], self.dims[1]
        return (
            (env.lower_bound[x_dim], env.upper_bound[x_dim]),
            (env.lower_bound[y_dim], env.upper_bound[y_dim]),
        )

    def center_point(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Return the midpoint of the bounds, passed through ``pad_2d_to_nd_graph``.

        The padded result is squeezed before being returned.
        """
        (x_lo, x_hi), (y_lo, y_hi) = self._axis_bounds(alg)
        x = torch.tensor([(x_hi + x_lo) / 2], device=alg.device)
        y = torch.tensor([(y_hi + y_lo) / 2], device=alg.device)
        return pad_2d_to_nd_graph(x, y, tuple(self.dims), self.dim_size).squeeze()

    def points(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Return the five vertices of the closed bounding-box outline.

        In (x, y) order the vertices are (hi, hi), (hi, lo), (lo, lo), (lo, hi),
        (hi, hi); repeating the first vertex closes the plotted line.
        """
        (x_lo, x_hi), (y_lo, y_hi) = self._axis_bounds(alg)
        x_points = torch.tensor([x_hi, x_hi, x_lo, x_lo, x_hi], device=alg.device)
        y_points = torch.tensor([y_hi, y_lo, y_lo, y_hi, y_hi], device=alg.device)
        return pad_2d_to_nd_graph(x_points, y_points, tuple(self.dims), self.dim_size)

    def draw(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Plot the bounding box in red and its centre as a blue dot.

        Returns:
            ``[]`` when ``alg.input_mapping`` is not a
            ``StaticShrinkingTrustRegion``; otherwise ``[(self.fig, "")]``.
        """
        # isinstance() is False for None, so no separate None check is needed.
        if not isinstance(alg.input_mapping, StaticShrinkingTrustRegion):
            return []

        center = self.center_point(alg).cpu().detach()
        outline = self.points(alg).cpu().detach()
        # NOTE(review): columns 0 and 1 are plotted regardless of self.dims;
        # confirm this matches how pad_2d_to_nd_graph lays out its output.
        self.ax.plot(outline[:, 0], outline[:, 1], color="r")
        self.ax.plot([center[0]], [center[1]], color="b", marker="o")
        return [(self.fig, "")]
