import torch

from algorithms.convergence_algorithms.egl import EGL
from algorithms.convergence_algorithms.utils import ball_perturb
from handlers.drawers.base_drawer import StaticPLTDrawer


class HelperNetworkLossDrawer(StaticPLTDrawer):
    """Plot the loss of ``alg.helper_network`` on fresh samples near the current point."""

    def draw(self, alg: EGL, *args, **kwargs):
        """Compute ``alg.grad_loss`` for the helper network on random sample pairs.

        100 points are produced by ``ball_perturb`` from
        ``alg.curr_point_to_draw`` and ``alg.epsilon`` (presumably points
        inside a ball of that radius -- confirm against ``ball_perturb``).
        Each point is paired with another one chosen by a random permutation.
        The helper network output at the first point of the pair is projected
        on the displacement between the two points and compared, through
        ``alg.grad_loss``, with the difference of the mapped environment
        values of the pair.

        Returns:
            A single-element list ``[(loss, "grad_loss")]`` where ``loss`` is
            a Python float.
        """
        center = alg.curr_point_to_draw
        points = ball_perturb(center, alg.epsilon, 100, alg.device)
        mapped = alg.output_mapping.map(alg.env(points, debug_mode=True))

        # Pair every sample with a randomly chosen partner from the same batch.
        perm = torch.randperm(len(points))
        point_deltas = points[perm] - points
        target = mapped[perm] - mapped

        predicted_grads = alg.helper_network(points)
        # Row-wise inner product <x_j - x_i, g(x_i)>.
        value = torch.sum(point_deltas * predicted_grads, dim=1)

        loss = alg.grad_loss(value, target)
        return [(loss.item(), "grad_loss")]


class GradientLossDrawer(StaticPLTDrawer):
    """Plot how far ``alg.grad_network`` is from a least-squares gradient estimate."""

    def draw(self, alg: EGL, *args, **kwargs):
        """Compare the gradient network with a local least-squares fit.

        20 points are produced by ``ball_perturb`` from
        ``alg.curr_point_to_draw`` and ``alg.epsilon``. Each point is paired
        with another one through a random permutation, and a reference
        gradient is obtained by solving ``(x_j - x_i) @ g = y_j - y_i`` in the
        least-squares sense.

        Returns:
            A list of ``(value, label)`` tuples:

            * ``"grad distance from real"`` -- mean row norm of the difference
              between the least-squares solution and the network output.
            * ``"x distance"`` -- mean Euclidean distance between paired
              points.
            * ``"value difference"`` -- mean absolute difference between the
              mapped values of paired points.
        """
        x = ball_perturb(alg.curr_point_to_draw, alg.epsilon, 20, alg.device)
        values = alg.env(x, debug_mode=True)
        values = alg.output_mapping.map(values)
        grad_on_perturb = alg.grad_network(x)

        shuffled_idx = torch.randperm(len(x))
        x_i, x_j = x, x[shuffled_idx]
        y_i, y_j = values, values[shuffled_idx]

        value = x_j - x_i
        target = y_j - y_i

        # NOTE(review): the subtraction below relies on the solution
        # broadcasting against the (n_samples, dim) network output, i.e. on
        # ``target`` being 1-D -- confirm the shape returned by
        # ``alg.output_mapping.map``.
        real_grad, _, _, _ = torch.linalg.lstsq(value, target, rcond=None)
        grad_loss = ((real_grad - grad_on_perturb).norm(dim=1)).mean()

        # ``x_j``/``y_j`` are permutations of ``x_i``/``y_i``, so the signed
        # differences cancel out when averaged over the batch and would always
        # plot ~0. Report magnitudes instead so the curves carry information.
        x_distance = value.norm(dim=1).mean()
        value_difference = target.abs().mean()
        return [
            (grad_loss.item(), "grad distance from real"),
            (x_distance.item(), "x distance"),
            (value_difference.item(), "value difference"),
        ]


class GradSizeDrawer(StaticPLTDrawer):
    """Plot the magnitude of ``alg.grad_network`` at the current point."""

    def draw(self, alg: EGL, *args, **kwargs):
        """Return ``[(size, "grad size")]`` for the current point.

        ``size`` is the square root of the norm of the gradient network
        output evaluated at ``alg.curr_point_to_draw``, as a Python float.
        """
        # NOTE(review): the reported value is sqrt(norm), not the norm itself
        # -- confirm this scaling is intended.
        predicted_grad = alg.grad_network(alg.curr_point_to_draw)
        grad_size = torch.sqrt(predicted_grad.norm())
        return [(grad_size.item(), "grad size")]


class StepSizeDrawer(StaticPLTDrawer):
    """Plot the distance moved between consecutive ``draw`` calls.

    The point seen on the previous call is remembered in ``self.last_point``
    (``None`` until the first call). ``get_last_point`` and ``get_curr_point``
    are the hooks used to read the two points that are compared.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Snapshot of ``alg.curr_point_to_draw`` from the previous draw call.
        self.last_point = None

    def get_last_point(self, *args, **kwargs):
        """Return the point remembered from the previous ``draw`` call."""
        return self.last_point

    def get_curr_point(self, alg, *args, **kwargs):
        """Return the point the algorithm is currently at."""
        return alg.curr_point_to_draw

    def draw(self, alg: EGL, *args, **kwargs):
        """Report the step taken since the previous call.

        The first call only records the current point and returns an empty
        list. Every later call returns ``(value, label)`` tuples for the step
        size, the current point as seen through ``get_curr_point``, the raw
        ``alg.curr_point_to_draw`` and the detached gradient network output,
        and then remembers the current point for the next call.
        """
        if self.last_point is None:
            self.last_point = self._snapshot(alg)
            return []
        last_point = self.get_last_point(alg)
        curr_point = self.get_curr_point(alg)

        gradient = alg.grad_network(alg.curr_point_to_draw)
        # NOTE(review): the reported value is sqrt(norm), not the norm itself
        # -- confirm this scaling is intended.
        results = [
            ((curr_point - last_point).norm().sqrt().item(), "step size"),
            (curr_point, "curr_point_real"),
            (alg.curr_point_to_draw, "curr_point"),
            (gradient.detach(), "grad"),
        ]
        # Advance the reference point; without this every step would be
        # measured against the very first point ever drawn.
        self.last_point = self._snapshot(alg)
        return results

    @staticmethod
    def _snapshot(alg):
        """Return a detached copy of the raw current point of ``alg``."""
        # Cloning protects the stored point from later in-place updates of
        # ``alg.curr_point_to_draw``.
        return alg.curr_point_to_draw.detach().clone()
