import click
import matplotlib.pyplot as plt
import torch
from torch.nn import SmoothL1Loss
from torch.optim import Adam

from algorithms.convergence_algorithms.utils import ball_perturb, reset_all_weights
from algorithms.nn.datasets import PairsInEpsRangeDataset
from algorithms.nn.distributions import SigmoidWeights
from algorithms.nn.losses import (
    NaturalHessianLoss,
    GradientLoss,
    loss_with_quantile,
    HessianWithDifferentNetwork,
)
from algorithms.nn.modules import BasicNetwork, BasicHessianNetwork, MultipleOptimizer
from algorithms.nn.trainer import train_gradient_network
from run_options import space_instance

# Render every figure produced by this script with matplotlib's "ggplot" style.
plt.style.use("ggplot")
# Floating-point precision shared by the perturbed samples, the networks and
# the center points created in this file.
dtype = torch.float64


def compute_accuracy(
    grad,
    env,
    loss,
    optimizer,
    tuples_trained,
    batch_size,
    center_points,
    epsilon,
    num_samples,
    device,
):
    """Return the mean cosine similarity between learned and true gradients.

    For every point in ``center_points`` the weights of ``grad`` are reset,
    samples are generated with ``ball_perturb(point, epsilon, num_samples,
    ...)``, evaluated with ``env`` and wrapped in a ``PairsInEpsRangeDataset``
    on which ``train_gradient_network`` is run.  The output ``grad(point)`` is
    then compared with ``env.g_func(point)`` through cosine similarity.

    Args:
        grad: Callable network whose weights are reset and whose output at a
            point is treated as the estimated gradient.
        env: Callable objective; must also expose ``g_func(point)`` returning
            the reference gradient as a tensor.
        loss: Loss object forwarded to ``train_gradient_network``.
        optimizer: Optimizer forwarded to ``train_gradient_network``.
        tuples_trained: Forwarded to ``PairsInEpsRangeDataset``.
        batch_size: Forwarded to ``train_gradient_network``.
        center_points: Iterable of points at which accuracy is measured.
        epsilon: Forwarded to ``ball_perturb`` and ``PairsInEpsRangeDataset``
            (presumably the perturbation radius -- confirm against helpers).
        num_samples: Number of samples requested from ``ball_perturb``.
        device: Device forwarded to ``ball_perturb``.

    Returns:
        The cosine similarity averaged over all center points as a float, or
        ``float("nan")`` when ``center_points`` is empty (the mean of zero
        measurements is undefined).
    """
    similarities = []
    for point in center_points:
        reset_all_weights(grad)
        samples = ball_perturb(point, epsilon, num_samples, dtype, device=device)
        values = env(samples)
        dataset = PairsInEpsRangeDataset(samples, values, epsilon, tuples_trained)
        train_gradient_network(loss, optimizer, dataset, batch_size, None)

        gradient = grad(point)
        true_gradient = env.g_func(point)
        # Replace NaN components by 0 on a copy so that a tensor owned by
        # ``env`` is never modified behind its back.
        true_gradient = true_gradient.masked_fill(torch.isnan(true_gradient), 0)
        similarity = torch.nn.functional.cosine_similarity(
            gradient.unsqueeze(0), true_gradient.unsqueeze(0)
        ).mean()
        similarities.append(similarity.item())

    if not similarities:
        # No center points: avoid a ZeroDivisionError and report "undefined".
        return float("nan")
    return sum(similarities) / len(similarities)


def _epsilon_schedule(epsilon_min, epsilon_max, steps):
    """Return a 1-D tensor of ``steps`` epsilons spaced exponentially.

    The values start at ``epsilon_min`` and end at ``epsilon_max``.  With
    ``steps == 1`` the single value is ``epsilon_min``; with ``steps == 0`` the
    tensor is empty.
    """
    linear_steps = torch.linspace(0, 1, steps=steps)
    growth = torch.exp(linear_steps) - 1  # Exponential growth, 0 at the first step
    if growth.numel() == 0:
        return growth
    peak = growth.max()
    # With a single step the peak is 0; dividing would produce 0/0 = NaN.
    if peak > 0:
        growth = growth / peak
    return growth * (epsilon_max - epsilon_min) + epsilon_min


def accuracy_based_epsilon(
    grad,
    env,
    loss,
    optimizer,
    tuples_trained,
    batch_size,
    center_points,
    epsilon_max,
    epsilon_min,
    steps,
    num_samples,
    device,
):
    """Measure gradient accuracy for a sweep of epsilon values.

    Epsilons are spaced exponentially between ``epsilon_min`` and
    ``epsilon_max``.  For each of them the weights of ``grad`` are reset and
    ``compute_accuracy`` is evaluated with the remaining arguments.

    Args:
        grad: Network forwarded to ``reset_all_weights`` and
            ``compute_accuracy``.
        env: Objective forwarded to ``compute_accuracy``.
        loss: Loss forwarded to ``compute_accuracy``.
        optimizer: Optimizer forwarded to ``compute_accuracy``.
        tuples_trained: Forwarded to ``compute_accuracy``.
        batch_size: Forwarded to ``compute_accuracy``.
        center_points: Points forwarded to ``compute_accuracy``.
        epsilon_max: Largest epsilon of the sweep.
        epsilon_min: Smallest epsilon of the sweep.
        steps: Number of epsilons to evaluate.
        num_samples: Forwarded to ``compute_accuracy``.
        device: Forwarded to ``compute_accuracy``.

    Returns:
        A list of ``(epsilon, accuracy)`` tuples in increasing epsilon order,
        where ``epsilon`` is a 0-dim tensor taken from the schedule.  The list
        is empty when ``steps`` is 0.
    """
    acc = []
    for epsilon in _epsilon_schedule(epsilon_min, epsilon_max, steps):
        reset_all_weights(grad)
        accuracy = compute_accuracy(
            grad,
            env,
            loss,
            optimizer,
            tuples_trained,
            batch_size,
            center_points,
            epsilon,
            num_samples,
            device,
        )
        acc.append((epsilon, accuracy))
    return acc


def accuracy_for_algs(
    env,
    algs,
    center_points,
    tuples,
    batch_size,
    epsilon_max,
    epsilon_min,
    steps,
    samples,
    device,
):
    """Plot accuracy against epsilon for every algorithm in ``algs``.

    ``algs`` maps a display name to a dict with the keys ``"grad"``,
    ``"loss"`` and ``"optimizer"``.  Each entry is evaluated with
    ``accuracy_based_epsilon`` and drawn as one labelled curve; the figure is
    shown once all algorithms have been processed.  Nothing is returned.
    """
    for label, components in algs.items():
        print(f"Computing accuracy for {label}")
        sweep = accuracy_based_epsilon(
            grad=components["grad"],
            env=env,
            loss=components["loss"],
            optimizer=components["optimizer"],
            tuples_trained=tuples,
            batch_size=batch_size,
            center_points=center_points,
            epsilon_max=epsilon_max,
            epsilon_min=epsilon_min,
            steps=steps,
            num_samples=samples,
            device=device,
        )
        # Each sweep entry is an (epsilon, accuracy) pair.
        scores = [score for _, score in sweep]
        radii = [eps for eps, _ in sweep]
        plt.plot(radii, scores, label=label)

    plt.legend()
    # plt.yscale("log")
    plt.tight_layout()
    plt.show()


# Called with parentheses so the decorator also works on click releases that
# predate support for the bare ``@click.command`` form.
@click.command()
# The type is given explicitly: without it click derives the option type from
# the default, so an integer default on CUDA machines would reject values such
# as "cpu" or "cuda:1".
@click.option(
    "--device",
    "device",
    type=str,
    required=True,
    help="Device to run the computations.",
    default="0" if torch.cuda.is_available() else "cpu",
)
@click.option("--tuples", "tuples", required=True, help="Tuples.", default=1000)
@click.option(
    "--batch_size", "batch_size", required=True, help="Batch size.", default=360
)
@click.option(
    "--epsilon_max", "epsilon_max", required=True, help="Epsilon max.", default=0.6
)
@click.option(
    "--epsilon_min", "epsilon_min", required=True, help="Epsilon min.", default=0.01
)
@click.option("--steps", "steps", required=True, help="Steps.", default=50)
@click.option("--samples", "samples", required=True, help="Samples.", default=200)
@space_instance
def main(device, tuples, batch_size, epsilon_max, epsilon_min, steps, samples, space, **kwargs):
    """Compare the gradient accuracy of H-EGL, EGL and HNet-EGL over a range of epsilons."""
    # A purely numeric device such as "0" is handed on as an int index, which
    # keeps the default on CUDA machines identical to the previous ``0``.
    if isinstance(device, str) and device.isdigit():
        device = int(device)

    # ``space`` is presumably injected by ``@space_instance`` -- confirm there.
    env = space
    dim = space.dimension
    hegl_grad_net = BasicNetwork(dim, device, dtype)
    egl_grad_net = BasicNetwork(dim, device, dtype)
    hnet_grad_net = BasicNetwork(dim, device, dtype)
    hessian_net = BasicHessianNetwork(dim, device, dtype)
    # One entry per algorithm: the gradient network, its optimizer and the
    # loss used to train it.
    algs = {
        "H-EGL": {
            "grad": hegl_grad_net,
            "optimizer": Adam(hegl_grad_net.parameters(), eps=1e-04),
            "loss": NaturalHessianLoss(hegl_grad_net, 0, SmoothL1Loss()),
        },
        "EGL": {
            "grad": egl_grad_net,
            "optimizer": Adam(egl_grad_net.parameters(), eps=1e-04),
            "loss": GradientLoss(egl_grad_net, 0, SmoothL1Loss()),
        },
        "HNet-EGL": {
            "grad": hnet_grad_net,
            # The separate Hessian network gets its own Adam instance.
            "optimizer": MultipleOptimizer(
                Adam(hnet_grad_net.parameters(), eps=1e-04), Adam(hessian_net.parameters(), eps=1e-04)
            ),
            "loss": HessianWithDifferentNetwork(
                hnet_grad_net, 0, SmoothL1Loss(), hessian_network=hessian_net
            ),
        },
    }

    # Ten random center points drawn uniformly from [-5, 5) in every dimension.
    center_points = (torch.rand(10, dim, device=device, dtype=dtype) - 0.5) * 10
    accuracy_for_algs(
        env,
        algs,
        center_points,
        tuples,
        batch_size,
        epsilon_max,
        epsilon_min,
        steps,
        samples,
        device,
    )


# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
