import torch
from pandas import DataFrame

from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from algorithms.convergence_algorithms.egl import EGL
from algorithms.convergence_algorithms.utils import ball_perturb
from algorithms.space.tensor_space import TensorSpace
from compute_result.result_store.base import ResultStore
from handlers.base_handler import EpochEndCallbackHandler


class GradientErrorMetricHandler(EpochEndCallbackHandler):
    """Compare EGL's gradient estimates with the environment's true gradient.

    The handler is only active when the algorithm's environment is a
    ``TensorSpace``, whose ``g_func`` supplies the reference gradient. Each
    epoch it records two cosine similarities, keyed by ``env.used_budget``:

    * ``curr_grad_error``: similarity between ``egl.curr_gradient()`` and the
      true gradient at ``egl.curr_point_to_draw``.
    * ``mean_grad_error``: mean similarity between the helper network's
      predicted gradients and the true gradients on ``num_test`` points
      produced by ``ball_perturb(curr_point_to_draw, egl.epsilon, num_test)``
      (presumably a neighbourhood of the current point; confirm against
      ``ball_perturb``).

    Despite the ``error`` naming, both values are cosine similarities, so
    larger means better agreement. NaN entries in the true gradient are
    treated as 0 before comparing.
    """

    def __init__(self, result_store: ResultStore, num_test: int, run_name: str):
        self.result_store = result_store
        self.num_test = num_test
        self.run_name = run_name
        # Lists of (used_budget, cosine_similarity) tuples, one entry per epoch.
        self.curr_grad_error = []
        self.mean_grad_error = []

    @staticmethod
    def _nan_to_zero(tensor: torch.Tensor) -> torch.Tensor:
        """Return a copy of *tensor* with NaN entries set to 0 (inf is kept)."""
        # Out-of-place, so the tensor returned by g_func is never mutated.
        return tensor.masked_fill(torch.isnan(tensor), 0)

    def on_epoch_end(self, egl: EGL, *args, **kwargs):
        """Record both gradient-similarity metrics for the current epoch."""
        env = egl.env
        if not isinstance(env, TensorSpace):
            return
        cos_sim = torch.nn.functional.cosine_similarity

        # Estimated vs. true gradient at the current point. unsqueeze(0) adds a
        # leading batch dim so the vectors are compared along dim=1
        # (assumes curr_gradient() is 1-D; TODO confirm).
        # detach(): this is evaluation only, so no autograd graph is needed.
        gradient = egl.curr_gradient().detach()
        true_gradient = self._nan_to_zero(env.g_func(egl.curr_point_to_draw))
        main_grad_sim = cos_sim(gradient.unsqueeze(0), true_gradient.unsqueeze(0)).mean()

        # Helper-network predictions vs. true gradients on perturbed samples.
        test_sample = ball_perturb(
            egl.curr_point_to_draw, egl.epsilon, self.num_test, device=egl.device
        )
        true_test_gradient = self._nan_to_zero(env.g_func(test_sample))
        test_gradient = egl.helper_network(test_sample).detach()
        test_grad_sim = cos_sim(test_gradient, true_test_gradient).mean()

        # Both similarities are already scalars after .mean().
        self.curr_grad_error.append((env.used_budget, main_grad_sim.item()))
        self.mean_grad_error.append((env.used_budget, test_grad_sim.item()))

    def on_algorithm_end(self, alg, *args, **kwargs):
        """Store both collected metric series in the result store."""
        if not isinstance(alg.env, TensorSpace):
            return

        mean_grad_error = DataFrame(self.mean_grad_error)
        curr_grad_error = DataFrame(self.curr_grad_error)
        self.result_store.store_metric(
            f"grad_error_{repr(alg.env)}_{alg.ALGORITHM_NAME}_{self.run_name}", curr_grad_error
        )
        self.result_store.store_metric(
            f"mean_grad_error_{repr(alg.env)}_{alg.ALGORITHM_NAME}_{self.run_name}",
            mean_grad_error,
        )


class GradientMetric(EpochEndCallbackHandler):
    """Log the algorithm's gradient estimate and current point every epoch.

    Each epoch appends a ``(used_budget, gradient_str, point_str)`` tuple,
    where both tensors are serialised with ``str(tensor.tolist())``. When the
    algorithm has an ``input_mapping``, the point is first passed through
    ``input_mapping.inverse`` and then ``env.denormalize`` before logging.
    Everything collected is written as a single DataFrame when the algorithm
    ends.
    """

    def __init__(self, result_store: ResultStore, run_name: str):
        self.result_store = result_store
        self.run_name = run_name
        # One (used_budget, gradient_str, point_str) tuple per epoch.
        self.grads = []

    def on_epoch_end(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Append this epoch's budget, gradient estimate and point."""
        point = alg.curr_point_to_draw
        mapping = alg.input_mapping
        if mapping is not None:
            # Undo the input mapping and normalisation (presumably to report
            # the point in the environment's original coordinates).
            point = alg.env.denormalize(mapping.inverse(point))
        grad_estimate = alg.curr_gradient()
        budget = alg.env.used_budget
        record = (budget, str(grad_estimate.tolist()), str(point.tolist()))
        self.grads.append(record)

    def on_algorithm_end(self, alg: ConvergenceAlgorithm, *args, **kwargs):
        """Store every logged record under one metric name."""
        frame = DataFrame(self.grads)
        metric_name = f"curr_grad_{alg.ALGORITHM_NAME}_{self.run_name}_{repr(alg.env)}"
        self.result_store.store_metric(metric_name, frame)
