import copy
from functools import partial
import glob
import math
import os
import sys

import numpy as np
import pickle5 as pickler
import scipy.linalg
import scipy.stats
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import tqdm
import yaml
from hyperopt import Trials, fmin, hp, tpe

import utils


def make_convex_quadratic(num_vars=500, seed=1, x0_std_dev=10, save_to_disk=True):
    """
    Build a random strictly convex quadratic problem f(x) = 0.5 * x^T A x.

    A is assembled from a random orthogonal eigenbasis and eigenvalues drawn
    as ``10 * U[0, 1) + 1e-3``, so it is symmetric with a strictly positive
    spectrum. Its inverse is built from the same decomposition.

    Returns the tuple ``(x0, optimal_x, optimal_val, A, A_inv)``: ``x0`` is a
    Gaussian starting point with standard deviation ``x0_std_dev`` and the
    optimum is the origin with value 0. When ``save_to_disk`` is true the
    tuple is also written to
    ``experiments/function_experiments/quadratic/problem.pth``.
    """

    # Seed both libraries: scipy draws the eigenbasis, torch draws the rest.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Random orthogonal matrix whose columns serve as eigenvectors
    basis = torch.tensor(scipy.stats.ortho_group.rvs(dim=num_vars), dtype=torch.double)

    # Strictly positive spectrum
    spectrum = torch.rand(num_vars, dtype=torch.double) * 10 + 1e-3

    basis_t = basis.T
    A = basis @ torch.diag(spectrum) @ basis_t
    A_inv = basis @ torch.diag(1 / spectrum) @ basis_t

    # Random start; the minimizer is the origin, where the objective is 0
    x0 = torch.normal(0, x0_std_dev, size=(num_vars,), dtype=torch.double)
    problem = (x0, torch.zeros((num_vars,), dtype=torch.double), 0.0, A, A_inv)

    if not save_to_disk:
        return problem

    # Persist the problem so the experiment drivers can reload it
    os.makedirs("experiments/function_experiments/quadratic", exist_ok=True)
    torch.save(problem, "experiments/function_experiments/quadratic/problem.pth")
    return problem


def quadratic(A, x):
    """Evaluate the quadratic form 0.5 * x^T A x."""
    half_row = 0.5 * x.T
    return half_row @ A @ x


def quadratic_conjugate_grad(A_inv, p):
    """Return the matrix-vector product ``A_inv @ p``.

    For a symmetric ``A_inv`` this is the gradient of 0.5 * p^T A_inv p.
    """
    mapped = A_inv @ p
    return mapped


def make_convex_quartic(num_vars=25, seed=1, x0_std_dev=3, save_to_disk=True):
    """
    Build a random quartic problem f(x) = 0.25 * (x^T Q0 x) * (x^T Q1 x).

    Both Q matrices are symmetric with eigenvalues drawn as
    ``3 * U[0, 1) + 1e-2`` in a random orthogonal eigenbasis.

    Args:
        num_vars: problem dimension.
        seed: seed applied to both torch and numpy before sampling.
        x0_std_dev: standard deviation of the Gaussian starting point.
        save_to_disk: when true, write the problem tuple to
            ``experiments/function_experiments/quartic/problem.pth``.

    Returns:
        The tuple ``(x0, optimal_x, optimal_val, Qs)``. The optimum is the
        origin with value 0. Called without arguments, the function samples
        and saves the same problem as before; the return value and the
        keyword parameters mirror ``make_convex_quadratic``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    Qs = []
    for _ in range(2):
        # Orthogonal eigenbasis from scipy, spectrum from torch
        eig_vecs = torch.tensor(
            scipy.stats.ortho_group.rvs(dim=num_vars), dtype=torch.double
        )
        eig_vals = torch.rand(num_vars, dtype=torch.double) * 3 + 1e-2
        Qs.append(eig_vecs @ torch.diag(eig_vals) @ eig_vecs.T)

    x0 = torch.normal(0, x0_std_dev, size=(num_vars,), dtype=torch.double)
    optimal_x = torch.zeros((num_vars,), dtype=torch.double)
    optimal_val = 0.0

    problem = (x0, optimal_x, optimal_val, Qs)

    if save_to_disk:
        os.makedirs("experiments/function_experiments/quartic", exist_ok=True)
        torch.save(problem, "experiments/function_experiments/quartic/problem.pth")

    return problem


def quartic(Qs, x):
    """Evaluate 0.25 * (x^T Qs[0] x) * (x^T Qs[1] x)."""
    first_form = x.T @ Qs[0] @ x
    second_form = x.T @ Qs[1] @ x
    return 0.25 * first_form * second_form


def quartic_hess(Qs, x):
    """Return the Hessian of ``quartic(Qs, .)`` evaluated at ``x``.

    The result is the sum of a term that rescales each Q by the other's
    quadratic form and a symmetric rank-two term built from Q0 x and Q1 x.
    """
    first, second = Qs[0], Qs[1]
    first_x = first @ x
    second_x = second @ x

    # Each matrix weighted by half of the other matrix's quadratic form
    weighted = 0.5 * (x.T @ second_x) * first + 0.5 * (x.T @ first_x) * second

    # Symmetrized outer product of the two matrix-vector products
    cross = torch.outer(first_x, second_x) + torch.outer(second_x, first_x)

    return weighted + cross


def rosenbrock(num_vars=100, optimal_at=1):
    """Build a shifted Rosenbrock problem.

    The objective, gradient and Hessian first translate their input by
    ``1 - optimal_at`` so that the minimizer sits at ``x_i = optimal_at``.

    Returns ``(x0, obj_function, optimal_x, optimal_val, grad, hess)``.
    """
    # Start at alternating -2 / +2 (even index -> -2), shifted like the optimum
    alternating = torch.tensor(
        [2 if idx % 2 else -2 for idx in range(num_vars)], dtype=torch.double
    )
    x0 = alternating - 1 + optimal_at

    def obj_function(x):
        z = x - optimal_at + 1
        head = z[:-1]
        tail = z[1:]
        return torch.sum(100.0 * (tail - head ** 2.0) ** 2.0 + (1 - head) ** 2.0)

    def grad(x):
        z = x - optimal_at + 1
        mid = z[1:-1]
        left = z[:-2]
        right = z[2:]
        out = torch.zeros_like(z)
        # First coordinate, interior coordinates, last coordinate
        out[0] = -400 * z[0] * (z[1] - z[0] ** 2) - 2 * (1 - z[0])
        out[1:-1] = (
            200 * (mid - left ** 2) - 400 * (right - mid ** 2) * mid - 2 * (1 - mid)
        )
        out[-1] = 200 * (z[-1] - z[-2] ** 2)
        return out

    def hess(x):
        z = x - optimal_at + 1
        # Tridiagonal structure: identical super- and sub-diagonal entries
        off_diagonals = torch.diag(-400 * z[:-1], 1) - torch.diag(400 * z[:-1], -1)
        main = torch.zeros_like(z)
        main[0] = 1200 * z[0] ** 2 - 400 * z[1] + 2
        main[-1] = 200
        main[1:-1] = 202 + 1200 * z[1:-1] ** 2 - 400 * z[2:]
        return off_diagonals + torch.diag(main)

    # Every coordinate equal to optimal_at gives f(x) = 0
    optimal_x = torch.zeros(num_vars) + optimal_at
    optimal_val = 0

    return x0, obj_function, optimal_x, optimal_val, grad, hess


def zakharov(initial_value=1):
    """Build the 5-dimensional Zakharov problem.

    The objective is ``sum(x^2) + s^2 + s^4`` with ``s = 0.5 * sum(i * x_i)``
    for ``i = 1..5``.

    Returns ``(x0, obj_function, optimal_x, optimal_val, grad, hess)`` where
    every coordinate of ``x0`` equals ``initial_value``.
    """
    num_vars = 5
    x0 = torch.full((num_vars,), fill_value=initial_value, dtype=torch.double)

    def obj_function(x):
        weights = torch.arange(1, num_vars + 1)
        s = 0.5 * torch.sum(weights * x)
        return torch.sum(x ** 2) + s ** 2 + s ** 4

    def grad(x):
        weights = torch.arange(1, num_vars + 1)
        s = 0.5 * torch.sum(weights * x)
        return 2 * x + (s + 2 * s ** 3) * weights

    def hess(x):
        weights = torch.arange(1, num_vars + 1, dtype=x.dtype)
        half_outer = 0.5 * torch.outer(weights, weights)
        identity = torch.eye(num_vars, dtype=x.dtype)
        quartic_scale = 6 * torch.sum(0.5 * weights * x) ** 2
        return 2 * identity + half_outer + quartic_scale * half_outer

    # The origin gives an objective value of 0
    optimal_x = torch.zeros(num_vars)
    optimal_val = 0

    return x0, obj_function, optimal_x, optimal_val, grad, hess


def hump_camel(humps):
    """Build the three-hump or six-hump camel test problem.

    Args:
        humps: 3 or 6, selecting the variant.

    Returns:
        ``(x0, obj_function, optimal_x, optimal_value, grad, hess)``.
        For the six-hump variant the objective is offset by 1.0316 so that
        its minimum value is approximately 0, ``optimal_x`` is a pair holding
        the two symmetric global minimizers, and ``grad`` / ``hess`` are
        ``None`` (no analytic derivatives are provided).

    Raises:
        ValueError: if ``humps`` is neither 3 nor 6.
    """
    if humps == 3:
        x0 = torch.tensor([5, 5], dtype=torch.double)

        def obj_function(x):
            return (
                2 * x[0] ** 2
                - 1.05 * x[0] ** 4
                + x[0] ** 6 / 6
                + x[0] * x[1]
                + x[1] ** 2
            )

        def grad(x):
            return torch.tensor(
                [4 * x[0] - 4.2 * x[0] ** 3 + x[0] ** 5 + x[1], x[0] + 2 * x[1]],
                dtype=x.dtype,
            )

        def hess(x):
            return torch.tensor(
                [[4 - 12.6 * x[0] ** 2 + 5 * x[0] ** 4, 1], [1, 2]], dtype=x.dtype
            )

        optimal_x = torch.tensor([0, 0])
        optimal_value = 0

    elif humps == 6:
        x0 = torch.tensor([3, 2], dtype=torch.double)

        def obj_function(x):
            return (
                (4 - 2.1 * x[0] ** 2 + x[0] ** 4 / 3) * x[0] ** 2
                + x[0] * x[1]
                + (4 * x[1] ** 2 - 4) * x[1] ** 2
                + 1.0316
            )

        grad, hess = None, None

        # The function is even (f(-x) = f(x)), so its two global optima are
        # mirror images of each other.
        optimal_x = (
            torch.tensor([0.0898, -0.7126], dtype=torch.double),
            torch.tensor([-0.0898, 0.7126], dtype=torch.double),
        )
        optimal_value = 0

    else:
        # Previously this fell through to an UnboundLocalError at the return.
        raise ValueError(f"humps must be 3 or 6, got {humps!r}")

    return x0, obj_function, optimal_x, optimal_value, grad, hess


def phase_transition_function():
    """Build a 1-D piecewise-quadratic problem that switches at x = 0.75.

    The objective is ``x^2 + 1`` for ``x <= 0.75`` and ``(x - 2)^2``
    otherwise. The accompanying conjugate pair switches at ``p = -0.5``.

    Returns ``(x0, obj_function, optimal_x, optimal_value, grad, hess,
    convex_conjugate, convex_conjugate_grad)``.
    """

    def obj_function(x):
        return x ** 2 + 1 if x <= 0.75 else (x - 2) ** 2

    def grad(x):
        return 2 * x if x <= 0.75 else 2 * (x - 2)

    def hess(x):
        # Both pieces have curvature 2
        return torch.tensor([[2]], dtype=x.dtype)

    def convex_conjugate(p):
        quarter_sq = p ** 2 / 4
        return quarter_sq - 1 if p <= -0.5 else quarter_sq + 2 * p

    def convex_conjugate_grad(p):
        half = p / 2
        return half if p <= -0.5 else half + 2

    x0 = torch.tensor([-1], dtype=torch.double)
    optimal_x = torch.tensor([2], dtype=x0.dtype)
    optimal_value = 0

    return (
        x0,
        obj_function,
        optimal_x,
        optimal_value,
        grad,
        hess,
        convex_conjugate,
        convex_conjugate_grad,
    )


def monomial(degree):
    """Build the 1-D problem f(x) = |x| ** degree together with its conjugate.

    Args:
        degree: exponent of the monomial. The conjugate formulas divide by
            ``degree - 1``, so they require ``degree > 1``.

    Returns:
        ``(x0, obj_function, optimal_x, optimal_value, grad, hess,
        convex_conjugate, convex_conjugate_grad)``.
    """

    def obj_function(x):
        return x.abs() ** degree

    def grad(x):
        return degree * x.sign() * x.abs() ** (degree - 1)

    def hess(x):
        # A leftover debugging print of this value was removed.
        return degree * (degree - 1) * x.abs() ** (degree - 2)

    # With b = degree / (degree - 1), the convex conjugate of |x| ** degree is
    #   f*(p) = (degree - 1) * (|p| / degree) ** b
    # and its gradient sign(p) * (|p| / degree) ** (b - 1) inverts grad().
    # The previous formulas omitted the division by degree; for degree 2 they
    # gave p ** 2 and 2p instead of p ** 2 / 4 and p / 2.
    def convex_conjugate(p):
        b = degree / (degree - 1)
        return (degree - 1) * (p.abs() / degree) ** b

    def convex_conjugate_grad(p):
        b = degree / (degree - 1)
        return p.sign() * (p.abs() / degree) ** (b - 1)

    x0 = torch.tensor([-1], dtype=torch.double)
    optimal_x = torch.tensor([0], dtype=x0.dtype)
    optimal_value = 0

    return (
        x0,
        obj_function,
        optimal_x,
        optimal_value,
        grad,
        hess,
        convex_conjugate,
        convex_conjugate_grad,
    )


def noisy_parabola():
    """Build the 1-D problem f(x) = x^2 + 0.1 * x * sin(20 * pi * x).

    Returns:
        ``(x0, obj_function, optimal_x, optimal_value, hess)``. ``hess``
        returns the second derivative with a leading dimension added by
        ``unsqueeze(0)``.

    NOTE(review): ``optimal_x = 0`` / ``optimal_value = 0`` are kept as in
    the original; the sinusoidal term can push f slightly below 0 near the
    origin, so these look like nominal rather than exact values - confirm.
    """
    twenty_pi = torch.tensor(20 * math.pi)

    def obj_function(x):
        return x ** 2 + 0.1 * x * torch.sin(twenty_pi * x)

    def hess(x):
        # With w = 20 * pi:
        #   f'(x)  = 2x + 0.1 * sin(wx) + 0.1 * w * x * cos(wx)
        #   f''(x) = 2 + 0.1 * w * (2 * cos(wx) - w * x * sin(wx))
        # The previous version omitted the factor x on the sin term.
        phase = twenty_pi * x
        return (
            2
            + 0.1
            * twenty_pi
            * (2 * torch.cos(phase) - twenty_pi * x * torch.sin(phase))
        ).unsqueeze(0)

    x0 = torch.tensor([2.95], dtype=torch.double)
    optimal_x = torch.tensor([0], dtype=x0.dtype)
    optimal_value = 0

    return x0, obj_function, optimal_x, optimal_value, hess


def p_norm_regression():
    """Build a fixed 10-dimensional quartic regression problem.

    The objective is ``0.25 * sum((a @ x - b) ** 4)`` with a seeded random
    10x10 matrix ``a`` and targets ``b`` of five zeros followed by five ones.

    Returns ``(x0, obj_function, optimal_x, optimal_value)``. ``optimal_x``
    is ``None``.
    """
    rng = torch.manual_seed(42)

    # Draw order matters for reproducibility: design matrix first, then x0
    a = torch.randn((10, 10), generator=rng, dtype=torch.double)
    b = torch.tensor([0] * 5 + [1] * 5, dtype=torch.double)
    x0 = 4 * torch.randn(10, generator=rng, dtype=torch.double)

    def obj_function(x):
        residual = a @ x - b
        return 0.25 * (residual ** 4).sum()

    return x0, obj_function, None, 0


def deep_regression(
    depth=6,
    h_size=10,
    init="gaussian",
    activation=None,
    balanced=False,
    nonregression=False,
    x0=None,
    x_and_y=None,
):
    """Build a deep (optionally nonlinear) least-squares problem.

    The weights are one matrix whose column blocks of width ``h_size`` are
    the successive layers; with ``balanced`` a single block is reused for
    every layer. The objective is ``0.5 * sum((net(x) - y) ** 2)``, where
    ``activation`` (if given) is applied after every layer except the last.

    Data come from ``x_and_y`` if provided, otherwise all-ones inputs with
    zero targets (``nonregression``), a fixed scalar pair (depth 2, size 1),
    or seeded Gaussian inputs with half-zero / half-one targets. ``x0`` is
    used as given when provided; an unrecognised ``init`` leaves it ``None``.

    Returns ``(x0, obj_function)``.
    """
    rng = torch.manual_seed(42)
    scalar_two_layer = depth == 2 and h_size == 1

    # --- data ---
    if x_and_y:
        x, y = x_and_y
    elif nonregression:
        x = torch.ones((h_size, h_size), dtype=torch.double)
        y = torch.zeros(h_size, dtype=torch.double)
    elif scalar_two_layer:
        x = torch.tensor([[1]], dtype=torch.double)
        y = torch.tensor([1], dtype=torch.double)
    else:
        x = torch.randn((h_size, h_size), generator=rng, dtype=torch.double)
        zeros_count = h_size // 2
        y = torch.tensor(
            [0] * zeros_count + [1] * (h_size - zeros_count), dtype=torch.double
        )

    # --- initial weights (only drawn when the caller did not supply them) ---
    num_blocks = 1 if balanced else depth
    if x0 is None:
        if scalar_two_layer:
            x0 = 5e2 * torch.tensor([[1, 1]], dtype=torch.double)
        elif h_size == 1:  # Scalar initialization
            x0 = 10 * torch.randn((1, num_blocks), generator=rng, dtype=torch.double)
        elif init == "xavier":  # Pytorch's linear layer initialization
            uniform = torch.rand(
                (h_size, num_blocks * h_size), generator=rng, dtype=torch.double
            )
            x0 = (2 * uniform - 1) / math.sqrt(h_size)
        elif init == "gaussian":
            x0 = torch.randn(
                (h_size, num_blocks * h_size), generator=rng, dtype=torch.double
            )

    def obj_function(w):
        out = x
        last = depth - 1
        for k in range(depth):
            # Balanced networks reuse the whole matrix at every layer
            block = w if balanced else w[:, h_size * k : h_size * (k + 1)]
            out = block @ out
            if activation is not None and k < last:
                out = activation(out)
        return 0.5 * ((out - y) ** 2).sum()

    return x0, obj_function


def mnist(depth=6, activation=None, arch="linear"):
    """Build an MNIST classification problem with a fully connected network.

    Args:
        depth: currently unused; the network below always has six linear
            layers. Kept for interface compatibility.
        activation: module inserted between the linear layers. ``None``
            means ``nn.Identity()``. The same module instance is reused at
            every position.
        arch: network architecture; only ``"linear"`` is supported.

    Returns:
        ``(model0, obj_function, train_data, test_data)`` where each data
        item is an ``(images, labels)`` pair of tensors and
        ``obj_function(model, data)`` is the mean cross-entropy loss.

    Raises:
        ValueError: if ``arch`` is not supported.
        FileNotFoundError: if the hard-coded dataset directory pattern
            matches nothing.
    """
    if arch != "linear":
        # Previously this surfaced as a NameError on ``model0`` at the end.
        raise ValueError(f"Unsupported arch {arch!r}; only 'linear' is available.")

    # Seed the global generator so the layer initialization is reproducible
    torch.manual_seed(42)

    # Load MNIST dataset
    data_dir_pattern = "/export/io*/data/sslocum/data"
    matches = glob.glob(data_dir_pattern)
    if not matches:
        raise FileNotFoundError(
            f"No MNIST data directory matches {data_dir_pattern!r}"
        )
    local_data_dir = matches[0]
    train_ds = torchvision.datasets.MNIST(
        local_data_dir, train=True, transform=torchvision.transforms.ToTensor()
    )
    test_ds = torchvision.datasets.MNIST(
        local_data_dir, train=False, transform=torchvision.transforms.ToTensor()
    )
    train_x, train_y = zip(*train_ds)
    test_x, test_y = zip(*test_ds)
    train_data = torch.stack(train_x), torch.tensor(train_y)
    test_data = torch.stack(test_x), torch.tensor(test_y)

    if activation is None:
        activation = nn.Identity()

    model0 = nn.Sequential(
        nn.Linear(784, 512),
        activation,
        nn.Linear(512, 256),
        activation,
        nn.Linear(256, 256),
        activation,
        nn.Linear(256, 128),
        activation,
        nn.Linear(128, 64),
        activation,
        nn.Linear(64, 10),
    )

    def obj_function(model, data):
        x, y = data
        if arch == "linear":
            x = x.flatten(start_dim=1)
        # F.cross_entropy expects raw logits and applies log-softmax itself.
        # Feeding it softmax output, as before, applied softmax twice and
        # flattened the loss.
        return F.cross_entropy(model(x), y)

    return model0, obj_function, train_data, test_data


def flat_params_to_list(model, flat_params):
    """Split a flat vector into tensors shaped like ``model``'s parameters.

    Consecutive slices of ``flat_params`` are taken in the order of
    ``model.parameters()`` and viewed with each parameter's shape.
    """
    pieces = []
    start = 0
    for param in model.parameters():
        stop = start + param.numel()
        pieces.append(flat_params[start:stop].view(param.shape))
        start = stop

    return pieces


def run_optimizer(
    make_optimizer,
    x0,
    obj_function,
    iterations,
    hyperparams,
    keep_traj=False,
    pbar=False,
):
    """Minimize ``obj_function`` from ``x0`` with a closure-based optimizer.

    ``x0`` may be a tensor or an ``nn.Module``; either way a copy is
    optimized, so the argument itself is left untouched. Each step calls
    ``optimizer.step(closure)``; the run stops early as soon as the returned
    objective value is not finite (that value is not recorded).

    Returns the list ``[values, trajectory, grads, optimizer]``. ``values``
    is a numpy array of the objective values returned by the steps.
    ``trajectory`` always holds the starting iterate and, with ``keep_traj``,
    one snapshot per completed step. ``grads`` holds the matching gradients.
    """
    is_model = isinstance(x0, nn.Module)

    # Work on a private copy of the starting point
    if is_model:
        x = copy.deepcopy(x0)
        params = x.parameters()
    else:
        x = x0.clone().requires_grad_()
        params = [x]
    optimizer = make_optimizer(params, **hyperparams)

    def snapshot():
        # Detached copy of the current iterate
        if is_model:
            return copy.deepcopy(x).requires_grad_(False)
        return x.detach().clone()

    def current_grads():
        # Detached copy of the gradient(s) left by the last closure call
        if is_model:
            return [p.grad.detach().clone() for p in x.parameters()]
        return x.grad.detach().clone()

    def closure():
        optimizer.zero_grad()
        loss = obj_function(x)
        # The graph of the gradient is kept during the backward pass
        loss.backward(create_graph=True)
        return loss

    trajectory = [snapshot()]
    values = []
    grads = []

    steps = tqdm.trange(iterations) if pbar else range(iterations)
    for _ in steps:
        loss = optimizer.step(closure)
        if not torch.isfinite(loss):
            break
        values.append(loss.item())

        if not keep_traj:
            continue
        trajectory.append(snapshot())
        grads.append(current_grads())

    return [np.array(values), trajectory, grads, optimizer]


def tune_algos(
    x0,
    obj_function,
    algo_iters,
    num_samples,
    hyperparam_space,
    save_dir,
    algos,
    optimizer_specific_args=None,
):
    """Tune each algorithm's hyperparameters with hyperopt's TPE search.

    For every name in ``algos`` this removes any existing
    ``{save_dir}/{algo}/trials.hp``, runs ``num_samples`` trials of
    ``algo_iters`` optimizer steps each, and writes the best hyperparameters
    to ``{save_dir}/{algo}/hyperparams.yaml``. A trial's loss is the last
    objective value recorded by ``run_optimizer``.

    Args:
        x0: starting point forwarded to ``run_optimizer``.
        obj_function: objective forwarded to ``run_optimizer``.
        algo_iters: number of optimizer steps per trial.
        num_samples: number of hyperopt evaluations per algorithm.
        hyperparam_space: search space, filtered per algorithm by
            ``utils.make_config``.
        save_dir: directory under which per-algorithm results are written.
        algos: iterable of algorithm names.
        optimizer_specific_args: optional mapping from algorithm name to
            extra keyword arguments merged over the sampled hyperparameters.
    """
    # Avoid a shared mutable default argument
    if optimizer_specific_args is None:
        optimizer_specific_args = {}

    for algo in tqdm.tqdm(algos):
        print(f"Tuning {algo} and saving results to {save_dir}...")
        os.makedirs(f"{save_dir}/{algo}", exist_ok=True)
        if os.path.exists(f"{save_dir}/{algo}/trials.hp"):
            os.remove(f"{save_dir}/{algo}/trials.hp")

        def experiment(hyperparams):
            args = optimizer_specific_args.get(algo, {})
            hyperparams = {**hyperparams, **args}
            vals, _traj, _grads, _optimizer = run_optimizer(
                utils.make_optimizer(algo), x0, obj_function, algo_iters, hyperparams
            )
            if len(vals) == 0:
                # run_optimizer stops before recording anything when the very
                # first step is non-finite; score that as the worst possible
                # loss instead of failing on vals[-1].
                return float("inf")
            return vals[-1]

        tpe_best = fmin(
            fn=experiment,
            space=utils.make_config(hyperparam_space, algo),
            algo=tpe.suggest,
            trials=Trials(),
            max_evals=num_samples,
            trials_save_file=f"{save_dir}/{algo}/trials.hp",
        )
        # The best-hyperparameter values are unwrapped to plain Python
        # scalars with .item() before being dumped to YAML.
        tpe_best = {k: v.item() for k, v in tpe_best.items()}
        with open(f"{save_dir}/{algo}/hyperparams.yaml", "w") as f:
            yaml.dump(tpe_best, f)
        print()


def quadratic_experiment(i):
    """Tune optimizers on the saved convex quadratic problem for run ``i``."""
    results_dir = f"experiments/function_experiments/quadratic/run{i}"

    # Reload the problem tuple from disk
    x0, optimal_x, optimal_value, A, A_inv = torch.load(
        "experiments/function_experiments/quadratic/problem.pth"
    )
    print(f"Objective function minimum: {optimal_value}")
    num_steps = 800

    def constant_hessian(x):
        # The Hessian of a quadratic does not depend on x
        return A

    search_space = {
        "lr": hp.uniform("lr", 0, 2),
        "minus_momentum": hp.uniform("minus_momentum", 0, 0.8),
        "delta": hp.uniform("delta", 0, 15),
        "alpha": hp.uniform("alpha", 0, 1),
        "minus_beta1": hp.uniform("minus_beta1", 0, 0.2),
        "minus_beta2": hp.uniform("minus_beta2", 0, 0.2),
        "eps": hp.uniform("eps", 0, 1),
        "little_a": hp.uniform("little_a", 0.5, 5),
        "big_a": hp.uniform("big_a", 0.5, 5),
        "piecewise_at": hp.uniform("piecewise_at", 0, 20),
        "num_inner_loops": hp.quniform("num_inner_loops", 1, 5, 1),
        "alpha_newton": hp.uniform("alpha_newton", 0, 200),
        "centered": hp.choice("centered", [True, False]),
        "eta1": hp.uniform("eta1", 0, 1),
        "eta2": hp.uniform("eta2", 1, 2.5),
        "step_min": hp.uniform("step_min", 1e-8, 1e-2),
        "step_max": hp.uniform("step_max", 1e-1, 10),
        "newton_lr": hp.uniform("newton_lr", 0, 5),
    }

    # Other algorithm names previously tuned here (currently disabled):
    # gd, cm, nag, adam, adam3, adam4, adam_centered, amsgrad, rprop, rmsprop,
    # adadelta, rgd_eu, rgd, pk, piecewise_pk, exact_ckd, ckd, ckd_hess
    enabled_algos = ["md_exp"]

    extra_args = {
        "exact_ckd": {"kinetic_grad": partial(quadratic_conjugate_grad, A_inv)},
        "ckd_hess": {"hessian": constant_hessian},
        "hhd": {"hessian": constant_hessian},
        "hhd_quad": {"hessian": constant_hessian},
        "newton": {"hessian": constant_hessian},
    }

    # Tune hyperparameters
    tune_algos(
        x0,
        partial(quadratic, A),
        num_steps,
        num_samples=1600,
        hyperparam_space=search_space,
        save_dir=results_dir,
        algos=enabled_algos,
        optimizer_specific_args=extra_args,
    )


def quadratic_inner_loops_experiment(i):
    """Tune ``ckd`` with 1 to 4 inner loops on the quadratic problem, run ``i``."""
    results_dir = f"experiments/quadratic_inner_loops/run{i}"

    # NOTE(review): this path differs from the one make_convex_quadratic
    # writes to ("experiments/function_experiments/quadratic/problem.pth");
    # confirm the file exists at this location.
    x0, optimal_x, optimal_value, A, A_inv = torch.load(
        "experiments/quadratic/problem.pth"
    )
    print(f"Objective function minimum: {optimal_value}")
    num_steps = 800

    # One algorithm variant per inner-loop count. No per-algorithm extras are
    # passed (a {"hessian": lambda x: A} entry was tried and disabled).
    variants = [f"ckd-num_inner_loops={loops}" for loops in range(1, 5)]
    extra_args = {}

    search_space = {
        "lr": hp.uniform("lr", 0, 2),
        "minus_momentum": hp.uniform("minus_momentum", 0, 0.8),
        "alpha_newton": hp.uniform("alpha_newton", 0, 200),
    }

    # Tune hyperparameters
    tune_algos(
        x0,
        partial(quadratic, A),
        num_steps,
        num_samples=600,
        hyperparam_space=search_space,
        save_dir=results_dir,
        algos=variants,
        optimizer_specific_args=extra_args,
    )


def quadratic_delta_experiment(i):
    """Tune ``rgd_eu`` / ``rgd`` on the quadratic problem for fixed deltas.

    For each delta in ``[0, 1, 5, 10, 20]`` both algorithms are tuned with
    that delta encoded in the algorithm name. All results go under
    ``experiments/quadratic_delta/run{i}``.
    """
    # Load problem
    # NOTE(review): this path differs from the one make_convex_quadratic
    # writes to ("experiments/function_experiments/quadratic/problem.pth");
    # confirm the file exists at this location.
    problem = torch.load("experiments/quadratic/problem.pth")
    # make_convex_quadratic stores (x0, optimal_x, optimal_val, A, A_inv), so
    # unpacking exactly four names raised "too many values to unpack". The
    # starred tail accepts that five-element tuple as well as a four-element
    # one.
    x0, optimal_x, optimal_value, A, *_ = problem
    iterations = 800

    # Tune hyperparameters with different fixed values of delta
    for delta in [0, 1, 5, 10, 20]:
        tune_algos(
            x0,
            partial(quadratic, A),
            iterations,
            num_samples=400,
            hyperparam_space={
                "lr": hp.uniform("lr", 0, 1),
                "minus_momentum": hp.uniform("minus_momentum", 0, 0.5),
                "alpha": hp.uniform("alpha", 0, 1),
            },
            save_dir=f"experiments/quadratic_delta/run{i}",
            algos=[f"rgd_eu-delta={delta}", f"rgd-delta={delta}"],
        )


def quartic_experiment(i):
    """Tune optimizers on the saved quartic problem for run ``i``."""
    results_dir = f"experiments/function_experiments/quartic/run{i}"

    # Reload the problem tuple from disk
    x0, optimal_x, optimal_value, Qs = torch.load(
        "experiments/function_experiments/quartic/problem.pth"
    )
    hess = partial(quartic_hess, Qs)
    print(f"Objective function minimum: {optimal_value}")
    num_steps = 400

    search_space = {
        "lr": hp.uniform("lr", 0, 1.5),
        "minus_momentum": hp.uniform("minus_momentum", 0, 0.7),
        "delta": hp.uniform("delta", 0, 50),
        "alpha": hp.uniform("alpha", 0, 1),
        "minus_beta1": hp.uniform("minus_beta1", 0, 0.5),
        "minus_beta2": hp.uniform("minus_beta2", 0, 0.5),
        "eps": hp.uniform("eps", 0, 0.1),
        "little_a": hp.uniform("little_a", 1, 5),
        "big_a": hp.uniform("big_a", 0, 5),
        "piecewise_at": hp.uniform("piecewise_at", 0, 40),
        "newton_lr": hp.uniform("newton_lr", 0, 1.5),
    }

    # Other algorithm names previously tuned here (currently disabled):
    # gd, cm, nag, adam, adam3, adam4, rgd_eu, rgd, pk, piecewise_pk
    enabled_algos = ["rgd_newton", "rgd_nomo"]

    # Algorithms that need the analytic Hessian
    extra_args = {
        "hhd": {"hessian": hess},
        "rgd_newton": {"hessian": hess},
        "hhd_diag": {"hessian": hess},
        "newton": {"hessian": hess},
    }

    # Tune hyperparameters
    tune_algos(
        x0,
        partial(quartic, Qs),
        num_steps,
        num_samples=1600,
        hyperparam_space=search_space,
        save_dir=results_dir,
        algos=enabled_algos,
        optimizer_specific_args=extra_args,
    )


def rosenbrock_experiment(i):
    """Tune optimizers on the 2-D Rosenbrock function for run ``i``."""
    results_dir = f"experiments/function_experiments/rosenbrock_2/run{i}"

    # Define problem
    x0, obj_function, optimal_x, optimal_val, grad, hess = rosenbrock(num_vars=2)
    num_steps = 2000

    search_space = {
        "lr": hp.uniform("lr", 0, 0.1),
        "minus_momentum": hp.uniform("minus_momentum", 0, 0.2),
        "delta": hp.uniform("delta", 0, 30),
        "alpha": hp.uniform("alpha", 0, 1),
        "minus_beta1": hp.uniform("minus_beta1", 0, 0.2),
        "minus_beta2": hp.uniform("minus_beta2", 0, 0.1),
        "eps": hp.uniform("eps", 0, 1),
        "little_a": hp.uniform("little_a", 1, 2),
        "big_a": hp.uniform("big_a", 1, 5),
        "piecewise_at": hp.uniform("piecewise_at", 0, 20),
        "num_inner_loops": hp.quniform("num_inner_loops", 1, 5, 1),
        "alpha_newton": hp.uniform("alpha_newton", 0, 100),
        "hessian_eigenvalue_threshold": hp.uniform(
            "hessian_eigenvalue_threshold", 0, 20
        ),
        "newton_lr": hp.uniform("newton_lr", 0, 5),
        "max_iter": hp.quniform("max_iter", 1, 100, 1),
        # Coefficients c1..c4
        **{f"c{k}": hp.uniform(f"c{k}", 0, 3) for k in range(1, 5)},
        "armijo_c": hp.uniform("armijo_c", 0, 1),
        "tau": hp.uniform("tau", 0, 1),
    }

    # Other algorithm names previously tuned here (currently disabled):
    # gd, cm, rgd_eu, rgd, pk-big_a=1.333|little_a=2, nag, adam3, adam
    enabled_algos = ["backtracking_gd"]

    # Algorithms that need the analytic Hessian
    hessian_users = [
        "ckd_hess",
        "good_approx_ckd",
        "ckd_hess_pd_threshold",
        "ckd_hess_pd_additive",
        "hhd",
        "rgd_newton",
        "hhd_diag",
        "hhd_diag_corrected",
        "ckvd",
        "newton",
    ]
    extra_args = {name: {"hessian": hess} for name in hessian_users}

    # Tune hyperparameters
    tune_algos(
        x0,
        obj_function,
        num_steps,
        num_samples=1000,
        hyperparam_space=search_space,
        save_dir=results_dir,
        algos=enabled_algos,
        optimizer_specific_args=extra_args,
    )


def zakharov_experiment(i):
    """Tune optimizers on the Zakharov function (initial value 1), run ``i``.

    Results are written under
    ``experiments/function_experiments/zakharov_1/run{i}``.
    """
    save_dir = f"experiments/function_experiments/zakharov_1/run{i}"

    x0, obj_function, optimal_x, optimal_val, grad, hess = zakharov(initial_value=1)
    iterations = 200

    tune_algos(
        x0,
        obj_function,
        iterations,
        num_samples=1600,
        hyperparam_space={
            "lr": hp.uniform("lr", 0, 0.1),
            "minus_momentum": hp.uniform("minus_momentum", 0, 1),
            "delta": hp.uniform("delta", 0, 30),
            "alpha": hp.uniform("alpha", 0, 1),
            "minus_beta1": hp.uniform("minus_beta1", 0, 1),
            "minus_beta2": hp.uniform("minus_beta2", 0, 1),
            "eps": hp.uniform("eps", 0, 10),
            "little_a": hp.uniform("little_a", 2, 5),
            "big_a": hp.uniform("big_a", 0.5, 5),
            "num_inner_loops": hp.quniform("num_inner_loops", 1, 2, 1),
            "alpha_newton": hp.uniform("alpha_newton", 0, 200),
            "newton_lr": hp.uniform("newton_lr", 0, 5),
            "max_iter": hp.quniform("max_iter", 1, 100, 1),
            "hessian_eigenvalue_threshold": hp.uniform(
                "hessian_eigenvalue_threshold", 0, 20
            ),
            # hp.loguniform takes its bounds in log space (it returns
            # exp(uniform(low, high))). Passing 1e-10 and 1 directly sampled
            # from [1, e]; the logs below give the intended [1e-10, 1].
            "conj_hessian_lipschitz": hp.loguniform(
                "conj_hessian_lipschitz", math.log(1e-10), math.log(1)
            ),
            **{f"c{k}": hp.uniform(f"c{k}", 0, 3) for k in range(1, 8)},
            "armijo_c": hp.uniform("armijo_c", 0, 1),
            "tau": hp.uniform("tau", 0, 1),
        },
        save_dir=save_dir,
        # Other algorithm names previously tuned here (currently disabled):
        # mcd, pk-minus_momentum=1|big_a=1.333|little_a=2, gd, cm, nag, adam,
        # adam3, adam4, rgd_eu, rgd, pk, ckd, ckd_hess, good_approx_ckd
        algos=[
            "backtracking_gd",
        ],
        optimizer_specific_args={
            "ckd_hess": {"hessian": hess},
            "good_approx_ckd": {"hessian": hess},
            "hhd": {"hessian": hess},
            "rgd_newton": {"hessian": hess},
            "hhd_upper_bounded": {"hessian": hess},
            "hhd_critically_damped": {"hessian": hess},
            "hhd_diag": {"hessian": hess},
            "hhd_diag_corrected": {"hessian": hess},
            "ckvd": {"hessian": hess},
            "newton": {"hessian": hess},
            "newton_2nd_momentum": {"hessian": hess},
            "mcd": {"model_der": lambda x: x},
        },
    )


def hump_camel_experiment(i, humps):
    """Tune optimizers on the three- or six-hump camel function, run ``i``.

    Args:
        i: run identifier used in the results directory name.
        humps: 3 or 6, selecting the camel variant.

    Raises:
        ValueError: if ``humps`` is neither 3 nor 6 (previously this left
            ``save_dir`` unbound and failed with an UnboundLocalError).
    """
    if humps == 3:
        save_dir = f"experiments/function_experiments/three_hump_camel/run{i}"
    elif humps == 6:
        save_dir = f"experiments/function_experiments/six_hump_camel/run{i}"
    else:
        raise ValueError(f"humps must be 3 or 6, got {humps!r}")

    x0, obj_function, optimal_x, optimal_val, grad, hess = hump_camel(humps)
    iterations = 1000

    tune_algos(
        x0,
        obj_function,
        iterations,
        num_samples=800,
        hyperparam_space={
            "lr": hp.uniform("lr", 0, 1e-1),
            "minus_momentum": hp.uniform("minus_momentum", 0, 1),
            "delta": hp.uniform("delta", 0, 30),
            "alpha": hp.uniform("alpha", 0, 1),
            "minus_beta1": hp.uniform("minus_beta1", 0, 0.5),
            "minus_beta2": hp.uniform("minus_beta2", 0, 0.8),
            "eps": hp.uniform("eps", 0, 10),
            "little_a": hp.uniform("little_a", 1, 5),
            "big_a": hp.uniform("big_a", 1, 5),
            "piecewise_at": hp.uniform("piecewise_at", 0, 20),
            "num_inner_loops": hp.quniform("num_inner_loops", 1, 5, 1),
            "alpha_newton": hp.uniform("alpha_newton", 0, 100),
            "hessian_eigenvalue_threshold": hp.uniform(
                "hessian_eigenvalue_threshold", 0, 20
            ),
        },
        save_dir=save_dir,
        algos=["rgd_nomo", "gd"],
        # The Hessian-based entries below are not among the tuned algorithms
        # above; they only take effect if those algorithms are added.
        optimizer_specific_args={
            "ckd_hess": {"hessian": hess},
            "good_approx_ckd": {"hessian": hess},
            "ckd_hess_pd_threshold": {"hessian": hess},
            "ckd_hess_pd_additive": {"hessian": hess},
        },
    )


def p_norm_regression_experiment(i):
    """Tune optimizers on the quartic regression problem for run ``i``."""
    results_dir = f"experiments/function_experiments/p_norm_regression/run{i}"

    x0, obj_function, optimal_x, optimal_value = p_norm_regression()
    num_steps = 1000

    search_space = {
        "lr": hp.uniform("lr", 0, 0.25),
        "delta": hp.uniform("delta", 0, 30),
        "armijo_c": hp.uniform("armijo_c", 0, 1),
        "tau": hp.uniform("tau", 0, 1),
    }

    # Other algorithm names previously tuned here (currently disabled):
    # gd, rgd_nomo, pk-minus_momentum=1|big_a=1.333|little_a=1.333
    enabled_algos = ["backtracking_gd"]

    tune_algos(
        x0,
        obj_function,
        num_steps,
        num_samples=100,
        hyperparam_space=search_space,
        save_dir=results_dir,
        algos=enabled_algos,
    )


def deep_regression_experiment(i):
    """Tune optimizers on the depth-6, width-10 deep linear problem, run ``i``.

    The problem is built with ``nonregression=True`` (all-ones inputs, zero
    targets). Results are written under
    ``results/deep_linear_simple_regression/run{i}``.
    """
    results_dir = f"results/deep_linear_simple_regression/run{i}"

    x0, obj_function = deep_regression(depth=6, h_size=10, nonregression=True)
    num_steps = 1000

    def prod_k_grad(p):
        # |prod(p)| raised to 2 / (2 * numel - 1), divided elementwise by p
        exponent = 2 / (2 * p.numel() - 1)
        return p.prod().abs() ** exponent / p

    # "armijo_c" (uniform on [0, 2]) and "tau" (uniform on [0, 1]) were also
    # searched at some point and are currently disabled.
    search_space = {
        "lr": hp.uniform("lr", 0, 0.1),
        "delta": hp.uniform("delta", 0, 30),
    }

    # Other algorithm names previously tuned here (currently disabled):
    # gd, exact_ckd-minus_momentum=1, backtracking_gd
    enabled_algos = [
        "rgd_nomo",
        "pk-minus_momentum=1|big_a=1.0909|little_a=2",
        "pk-minus_momentum=1|big_a=1.0909|little_a=1.0909",
    ]

    extra_args = {"exact_ckd-minus_momentum=1": {"kinetic_grad": prod_k_grad}}

    tune_algos(
        x0,
        obj_function,
        num_steps,
        num_samples=600,
        hyperparam_space=search_space,
        save_dir=results_dir,
        algos=enabled_algos,
        optimizer_specific_args=extra_args,
    )


def load_hyperparams(save_dirs, first_n_trials=None):
    """Collect the lowest-loss hyperparameters per algorithm and experiment.

    Each directory in ``save_dirs`` is scanned for sub-directories that
    contain a ``trials.hp`` file; other entries are ignored. Returns one
    dict per directory, mapping the sub-directory name to the
    hyperparameter values of its lowest-loss trial. With ``first_n_trials``
    only that many leading trials are considered.
    """
    results = []
    for root in save_dirs:
        best_by_algo = {}

        for entry in os.listdir(root):
            trials_path = f"{root}/{entry}/trials.hp"
            # Skip plain files and directories without saved trials
            if os.path.isfile(f"{root}/{entry}") or not os.path.exists(trials_path):
                continue

            # NOTE: unpickling executes arbitrary code; only load trial files
            # that were produced by trusted runs.
            with open(trials_path, "rb") as handle:
                history = pickler.load(handle).trials

            if first_n_trials is not None:
                history = history[:first_n_trials]

            winner = min(history, key=lambda trial: trial["result"]["loss"])
            # Each sampled value is stored in a list; take its first element
            best_by_algo[entry] = {
                name: sampled[0] for name, sampled in winner["misc"]["vals"].items()
            }

        results.append(best_by_algo)

    return results


def main():
    """Command-line entry point: ``<experiment> [run_number]``.

    ``make_quadratic`` and ``make_quartic`` generate and save a problem and
    take no run number. Every other experiment name requires a run number,
    which is passed through (as a string) to the experiment function.

    Raises:
        SystemExit: if a required command-line argument is missing.
        ValueError: if the experiment name is not recognised.
    """
    usage = f"usage: {sys.argv[0]} <experiment> [run_number]"
    if len(sys.argv) < 2:
        raise SystemExit(usage)

    experiment = sys.argv[1]

    # Problem generators need no run number. Returning here avoids the old
    # fall-through, which went on to read sys.argv[2] and then raised
    # "Invalid experiment name." after the problem had been generated.
    if experiment == "make_quadratic":
        make_convex_quadratic()
        return
    if experiment == "make_quartic":
        make_convex_quartic()
        return

    if len(sys.argv) < 3:
        raise SystemExit(usage)
    run_number = sys.argv[2]

    if experiment == "quadratic":
        quadratic_experiment(run_number)
    elif experiment == "quadratic_inner_loops":
        quadratic_inner_loops_experiment(run_number)
    elif experiment == "quartic":
        quartic_experiment(run_number)
    elif experiment == "rosenbrock":
        rosenbrock_experiment(run_number)
    elif experiment == "zakharov":
        zakharov_experiment(run_number)
    elif experiment == "three_hump_camel":
        hump_camel_experiment(run_number, humps=3)
    elif experiment == "six_hump_camel":
        hump_camel_experiment(run_number, humps=6)
    elif experiment == "quadratic_delta":
        quadratic_delta_experiment(run_number)
    elif experiment == "p_norm_regression":
        p_norm_regression_experiment(run_number)
    elif experiment == "deep_regression":
        deep_regression_experiment(run_number)
    else:
        raise ValueError("Invalid experiment name.")


# Run the command-line dispatcher only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
