from typing import List, Optional

from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from handlers.drawers.base_drawer import StaticPLTDrawer


class GradientDrawer(StaticPLTDrawer):
    """Draw an arrow for the negated helper-network output at the current point.

    On every ``draw`` call, the algorithm's ``helper_network`` is evaluated at
    ``alg.curr_point_to_draw``. The result is negated, scaled to unit norm
    (over the full vector) and then projected onto the coordinates in
    ``dims``. An arrow from the point returned by :meth:`curr_point`
    (projected onto ``dims``) to that point plus the projected direction is
    added to ``self.ax``.

    NOTE(review): the attribute name suggests ``helper_network`` returns a
    gradient, in which case the arrow shows the descent direction -- confirm
    against ConvergenceAlgorithm.
    """

    def __init__(self, dims: Optional[List[int]] = None, **kwargs):
        """
        Args:
            dims: indices of the coordinates to plot. Defaults to ``[0, 1]``
                when ``None`` or empty.
            **kwargs: forwarded unchanged to ``StaticPLTDrawer.__init__``.
        """
        super().__init__(**kwargs)
        self.dims = dims or [0, 1]
        # Most recent arrow annotation created by ``draw`` (None before the first draw).
        self.gradient = None

    def curr_point(self, alg, **kwargs):
        """Return ``alg.curr_point_to_draw`` detached from the autograd graph.

        ``draw`` anchors the arrow at this point. The helper network itself is
        always evaluated at ``alg.curr_point_to_draw``, regardless of what this
        method returns.
        """
        return alg.curr_point_to_draw.detach()

    def draw(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Annotate ``self.ax`` with the unit-normalized, negated helper-network output.

        Args:
            alg: algorithm exposing ``curr_point_to_draw`` and ``helper_network``.
            *args: accepted for interface compatibility; unused.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``[(self.fig, "")]``.
        """
        direction = -alg.helper_network(alg.curr_point_to_draw.detach()).cpu().detach()

        # Normalize over the full vector before projecting onto ``dims``. Skip the
        # division when the norm is zero: 0 / 0 would turn the arrow into NaNs.
        norm = direction.norm()
        if norm > 0:
            direction = direction / norm

        start = self.curr_point(alg)[self.dims].cpu().detach()
        end = start + direction[self.dims]

        # Matplotlib documents ``xy``/``xytext`` as float pairs; hand it plain
        # Python floats instead of torch tensors.
        self.gradient = self.ax.annotate(
            "",
            xy=tuple(end.tolist()),
            xytext=tuple(start.tolist()),
            arrowprops=dict(arrowstyle="->"),
        )
        return [(self.fig, "")]
