import math
import os

import GPUtil
import torch

import rgd


def make_sgd(params, lr):
    """Return a ``torch.optim.SGD`` optimizer over ``params`` with learning rate ``lr``."""
    optimizer = torch.optim.SGD(params, lr=lr)
    return optimizer


def make_cm(params, lr, minus_momentum):
    """Return ``torch.optim.SGD`` with momentum.

    The momentum coefficient is given as its complement: the optimizer
    receives ``momentum = 1 - minus_momentum``.
    """
    momentum = 1 - minus_momentum
    return torch.optim.SGD(params, lr=lr, momentum=momentum)


def make_nag(params, lr, minus_momentum):
    """Return ``torch.optim.SGD`` with Nesterov momentum enabled.

    The optimizer receives ``momentum = 1 - minus_momentum`` and
    ``nesterov=True``.
    """
    momentum = 1 - minus_momentum
    return torch.optim.SGD(params, lr=lr, momentum=momentum, nesterov=True)


def make_adam(params, lr, minus_beta1=0.9, minus_beta2=0.999, eps=1e-8):
    """Return ``torch.optim.Adam`` configured with only a learning rate.

    ``minus_beta1``, ``minus_beta2`` and ``eps`` are accepted but NOT
    forwarded; the optimizer always uses torch's own defaults for them.
    """
    # NOTE(review): the extra parameters are intentionally left unused here to
    # keep existing behavior; confirm whether they were meant to be forwarded.
    optimizer = torch.optim.Adam(params, lr=lr)
    return optimizer


def make_adam3(params, lr, minus_beta1, minus_beta2):
    """Return ``torch.optim.Adam`` with explicit betas.

    Betas are given as complements: the optimizer receives
    ``(1 - minus_beta1, 1 - minus_beta2)``.
    """
    betas = (1 - minus_beta1, 1 - minus_beta2)
    return torch.optim.Adam(params, lr=lr, betas=betas)


def make_adam4(params, lr, minus_beta1, minus_beta2, eps):
    """Return ``torch.optim.Adam`` with explicit betas and ``eps``.

    Betas are given as complements: the optimizer receives
    ``(1 - minus_beta1, 1 - minus_beta2)``.
    """
    betas = (1 - minus_beta1, 1 - minus_beta2)
    return torch.optim.Adam(params, lr=lr, betas=betas, eps=eps)


def make_rprop(params, lr, eta1, eta2, step_min, step_max):
    """Return ``torch.optim.Rprop``.

    ``(eta1, eta2)`` is passed as ``etas`` and ``(step_min, step_max)`` as
    ``step_sizes``.
    """
    etas = (eta1, eta2)
    step_sizes = (step_min, step_max)
    return torch.optim.Rprop(params, lr=lr, etas=etas, step_sizes=step_sizes)


def make_rmsprop(params, lr, minus_beta2, centered):
    """Return ``torch.optim.RMSprop``.

    The smoothing constant ``alpha`` is given as its complement
    (``alpha = 1 - minus_beta2``); ``centered`` is forwarded unchanged.
    """
    alpha = 1 - minus_beta2
    return torch.optim.RMSprop(params, lr=lr, alpha=alpha, centered=centered)


def make_adadelta(params, lr, minus_beta1):
    """Return ``torch.optim.Adadelta``.

    The decay coefficient ``rho`` is given as its complement
    (``rho = 1 - minus_beta1``).
    """
    rho = 1 - minus_beta1
    return torch.optim.Adadelta(params, lr=lr, rho=rho)


def make_adam_centered(params, lr, minus_beta1, minus_beta2):
    """Return ``rgd.CustomAdam`` with ``centered=True``.

    Betas are given as complements: the optimizer receives
    ``(1 - minus_beta1, 1 - minus_beta2)`` as its third positional argument.
    """
    betas = (1 - minus_beta1, 1 - minus_beta2)
    return rgd.CustomAdam(params, lr, betas, centered=True)


def make_amsgrad(params, lr, minus_beta1, minus_beta2):
    """Return ``torch.optim.Adam`` with ``amsgrad=True``.

    Betas are given as complements: the optimizer receives
    ``(1 - minus_beta1, 1 - minus_beta2)``.
    """
    betas = (1 - minus_beta1, 1 - minus_beta2)
    return torch.optim.Adam(params, lr=lr, betas=betas, amsgrad=True)


def make_rgd_eu(params, lr, minus_momentum, delta):
    """Return ``rgd.RGD`` using the ``"symplectic_euler"`` integrator.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    return rgd.RGD(params, lr, momentum, delta, integrator="symplectic_euler")


def make_rgd(params, lr, minus_momentum, delta, alpha):
    """Return ``rgd.RGD`` with an explicit ``alpha``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    return rgd.RGD(params, lr, momentum, delta, alpha=alpha)


def make_power_kinetic(params, lr, minus_momentum, delta, little_a, big_a):
    """Return ``rgd.PowerKinetic``.

    Arguments are forwarded positionally in the order
    ``(params, lr, 1 - minus_momentum, delta, little_a, big_a)``.
    """
    momentum = 1 - minus_momentum
    return rgd.PowerKinetic(params, lr, momentum, delta, little_a, big_a)


def make_piecewise_pk(params, lr, minus_momentum, delta, little_a, big_a, piecewise_at):
    """Return ``rgd.PowerKinetic`` with ``piecewise_at`` as the 7th positional argument.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    positional = (params, lr, momentum, delta, little_a, big_a, piecewise_at)
    return rgd.PowerKinetic(*positional)


def make_mcd(params, lr, minus_momentum, model_der):
    """Return ``rgd.ModelConjugateDescent``.

    Momentum is given as its complement (``1 - minus_momentum``);
    ``model_der`` is forwarded as the fourth positional argument.
    """
    momentum = 1 - minus_momentum
    return rgd.ModelConjugateDescent(params, lr, momentum, model_der)


def make_poly_kd(params, lr, minus_momentum, **kwargs):
    """Return ``rgd.PowerKinetic`` configured through ``poly_coefficients``.

    The values of all extra keyword arguments, in the order they were
    passed, become the ``poly_coefficients`` list. Momentum is given as its
    complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    # Keyword order is significant: it determines the coefficient order.
    coefficients = [*kwargs.values()]
    return rgd.PowerKinetic(params, lr, momentum, poly_coefficients=coefficients)


def make_exact_ckd(params, lr, minus_momentum, kinetic_grad):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="exact"``.

    Momentum is given as its complement (``1 - minus_momentum``);
    ``kinetic_grad`` is forwarded by keyword.
    """
    momentum = 1 - minus_momentum
    options = {"method": "exact", "kinetic_grad": kinetic_grad}
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_good_approx_ckd(params, lr, minus_momentum, alpha_newton, hessian):
    """Return ``rgd.GoodApproxConjugateKineticDescent``.

    ``num_inner_loops`` is fixed at 20; ``alpha_newton`` is passed as
    ``alpha``. Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "num_inner_loops": 20,
        "alpha": alpha_newton,
        "hessian": hessian,
    }
    return rgd.GoodApproxConjugateKineticDescent(params, lr, momentum, **options)


def make_ckd(params, lr, minus_momentum, num_inner_loops):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="grad_approx"``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {"method": "grad_approx", "num_inner_loops": num_inner_loops}
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_ckd_hinit(params, lr, minus_momentum, num_inner_loops):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="grad_approx"``
    and ``initialization="heuristic"``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "grad_approx",
        "num_inner_loops": num_inner_loops,
        "initialization": "heuristic",
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_ckd_hess(params, lr, minus_momentum, num_inner_loops, alpha_newton, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="hess_approx"``.

    ``alpha_newton`` is passed as ``alpha``. Momentum is given as its
    complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "hess_approx",
        "num_inner_loops": num_inner_loops,
        "alpha": alpha_newton,
        "hessian": hessian,
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_ckd_hess_hinit(
    params, lr, minus_momentum, alpha_newton, hessian, num_inner_loops=None
):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="hess_approx"``
    and ``initialization="heuristic"``.

    ``alpha_newton`` is passed as ``alpha``. Momentum is given as its
    complement (``1 - minus_momentum``).

    ``num_inner_loops`` is optional. ``make_config`` lists it among the
    hyperparameters of ``"ckd_hess_hinit"``, so this factory must accept it;
    previously a config-driven call raised ``TypeError``. When it is ``None``
    the argument is omitted and the optimizer's own default applies.
    """
    extra = {}
    if num_inner_loops is not None:
        # Same keyword that make_ckd_hess forwards for the "hess_approx" method.
        extra["num_inner_loops"] = num_inner_loops
    return rgd.ConjugateKineticDescent(
        params,
        lr,
        1 - minus_momentum,
        method="hess_approx",
        alpha=alpha_newton,
        hessian=hessian,
        initialization="heuristic",
        **extra,
    )


def make_ckd_hess_pd_threshold(
    params,
    lr,
    minus_momentum,
    alpha_newton,
    hessian_eigenvalue_threshold,
    hessian,
):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="hess_approx"``
    and ``hess_approximation="threshold"``.

    ``alpha_newton`` is passed as ``alpha``. Momentum is given as its
    complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "hess_approx",
        "alpha": alpha_newton,
        "hessian": hessian,
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
        "hess_approximation": "threshold",
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_hhd_pd(
    params,
    lr,
    minus_momentum,
    hessian_eigenvalue_threshold,
    hessian,
    track_iterates=False,
):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="inv_hess"``
    and ``hess_approximation="threshold"``.

    Momentum is given as its complement (``1 - minus_momentum``);
    ``track_iterates`` is forwarded unchanged.
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "inv_hess",
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
        "hess_approximation": "threshold",
        "hessian": hessian,
        "track_iterates": track_iterates,
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_hhd(params, lr, minus_momentum, hessian, track_iterates=False):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="inv_hess"``.

    Momentum is given as its complement (``1 - minus_momentum``);
    ``track_iterates`` is forwarded unchanged.
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "inv_hess",
        "hessian": hessian,
        "track_iterates": track_iterates,
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_hhd_upper_bounded(params, lr, minus_momentum, conj_hessian_lipschitz, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="upper_bounded"``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "upper_bounded",
        "conj_hessian_lipschitz": conj_hessian_lipschitz,
        "hessian": hessian,
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_rgd_newton(params, lr, delta, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="rgd_hess"``
    and momentum fixed at 0.
    """
    options = {
        "momentum": 0,
        "method": "rgd_hess",
        "delta": delta,
        "hessian": hessian,
    }
    return rgd.ConjugateKineticDescent(params, lr, **options)


def make_rgd_nomo(params, lr, delta):
    """Return ``rgd.RGD`` with momentum fixed at 0."""
    optimizer = rgd.RGD(params, lr, momentum=0, delta=delta)
    return optimizer


def make_rgd_newton_pd(params, lr, delta, hessian_eigenvalue_threshold, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="rgd_hess"``,
    ``hess_approximation="threshold"`` and a third positional argument of 0.
    """
    options = {
        "method": "rgd_hess",
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
        "hess_approximation": "threshold",
        "delta": delta,
        "hessian": hessian,
    }
    # The literal 0 occupies the slot where sibling factories pass momentum.
    return rgd.ConjugateKineticDescent(params, lr, 0, **options)


def make_hhd_critically_damped(params, newton_lr, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="inv_hess"`` and
    ``initialization="heuristic"``.

    The third positional argument is not a free hyperparameter: it is
    derived from the step size as ``exp(-2 * newton_lr)``.
    """
    derived_momentum = math.exp(-2 * newton_lr)
    options = {
        "method": "inv_hess",
        "initialization": "heuristic",
        "hessian": hessian,
    }
    return rgd.ConjugateKineticDescent(params, newton_lr, derived_momentum, **options)


def make_hhd_abs_pd(params, lr, minus_momentum, hessian_eigenvalue_threshold, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="inv_hess"``
    and ``hess_approximation="threshold_abs_pd"``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "inv_hess",
        "hessian": hessian,
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
        "hess_approximation": "threshold_abs_pd",
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_hhd_diag(params, lr, minus_momentum, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="inv_hess"``
    and ``hess_approximation="diagonal"``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "inv_hess",
        "hess_approximation": "diagonal",
        "hessian": hessian,
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_hhd_diag_corrected(params, lr, minus_momentum, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="inv_hess"``
    and ``hess_approximation="diagonal_corrected"``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "inv_hess",
        "hess_approximation": "diagonal_corrected",
        "hessian": hessian,
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_hhd_diag_pd(params, lr, minus_momentum, hessian_eigenvalue_threshold, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="inv_hess"``
    and ``hess_approximation="diagonal_threshold"``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "inv_hess",
        "hess_approximation": "diagonal_threshold",
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
        "hessian": hessian,
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_hhd_diag_abs_pd(
    params, lr, minus_momentum, hessian_eigenvalue_threshold, hessian
):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="inv_hess"``
    and ``hess_approximation="diagonal_abs_pd"``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {
        "method": "inv_hess",
        "hess_approximation": "diagonal_abs_pd",
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
        "hessian": hessian,
    }
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_ckvd(params, lr, minus_momentum, hessian):
    """Return ``rgd.ConjugateKineticDescent`` with ``method="velocity"``.

    Momentum is given as its complement (``1 - minus_momentum``).
    """
    momentum = 1 - minus_momentum
    options = {"method": "velocity", "hessian": hessian}
    return rgd.ConjugateKineticDescent(params, lr, momentum, **options)


def make_newton(params, newton_lr, hessian):
    """Return ``rgd.Newton`` with ``lr=newton_lr`` and the given ``hessian``."""
    options = {"lr": newton_lr, "hessian": hessian}
    return rgd.Newton(params, **options)


def make_newton_2nd_momentum(
    params, newton_lr, minus_momentum, hessian_eigenvalue_threshold, hessian
):
    """Return ``rgd.Newton`` with ``hess_momentum`` and
    ``hess_approximation="threshold"``.

    ``hess_momentum`` is given as its complement (``1 - minus_momentum``).
    """
    options = {
        "lr": newton_lr,
        "hess_momentum": 1 - minus_momentum,
        "hess_approximation": "threshold",
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
        "hessian": hessian,
    }
    return rgd.Newton(params, **options)


def make_newton_step_avg(params, newton_lr, minus_beta1, hessian):
    """Return ``rgd.Newton`` with ``step_beta = 1 - minus_beta1``."""
    options = {
        "lr": newton_lr,
        "step_beta": 1 - minus_beta1,
        "hessian": hessian,
    }
    return rgd.Newton(params, **options)


def make_newton_pd(params, newton_lr, hessian_eigenvalue_threshold, hessian):
    """Return ``rgd.Newton`` with ``hess_approximation="threshold"``.

    ``newton_lr`` is forwarded as the second positional argument.
    """
    options = {
        "hessian": hessian,
        "hess_approximation": "threshold",
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
    }
    return rgd.Newton(params, newton_lr, **options)


def make_newton_abs_pd(params, newton_lr, hessian_eigenvalue_threshold, hessian):
    """Return ``rgd.Newton`` with ``hess_approximation="threshold_abs_pd"``.

    ``newton_lr`` is forwarded as the second positional argument.
    """
    options = {
        "hessian": hessian,
        "hess_approximation": "threshold_abs_pd",
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
    }
    return rgd.Newton(params, newton_lr, **options)


def make_lbfgs(params, newton_lr, max_iter, history_size=None):
    """Return ``torch.optim.LBFGS`` without a line search.

    Args:
        params: Parameters to optimize.
        newton_lr: Learning rate passed to LBFGS.
        max_iter: Maximum iterations per optimizer step; cast to ``int``.
        history_size: Update history size; cast to ``int``. ``None`` (the
            default) selects 3, the value this factory previously always
            used regardless of the argument.
    """
    # Hyperparameters may arrive as floats (e.g. parsed from an algo-name
    # suffix), while LBFGS expects integer counts.
    history = 3 if history_size is None else int(history_size)
    return torch.optim.LBFGS(params, newton_lr, int(max_iter), history_size=history)


def make_lbfgs_wolfe(params, newton_lr, max_iter, history_size=None):
    """Return ``torch.optim.LBFGS`` with ``line_search_fn="strong_wolfe"``.

    Args:
        params: Parameters to optimize.
        newton_lr: Learning rate passed to LBFGS.
        max_iter: Maximum iterations per optimizer step; cast to ``int``.
        history_size: Update history size; cast to ``int``. ``None`` (the
            default) selects 3, the value this factory previously always
            used regardless of the argument.
    """
    # Hyperparameters may arrive as floats (e.g. parsed from an algo-name
    # suffix), while LBFGS expects integer counts.
    history = 3 if history_size is None else int(history_size)
    return torch.optim.LBFGS(
        params,
        newton_lr,
        int(max_iter),
        history_size=history,
        line_search_fn="strong_wolfe",
    )


def make_adahessian(params, newton_lr):
    """Return ``rgd.Adahessian`` configured with only ``newton_lr``."""
    optimizer = rgd.Adahessian(params, newton_lr)
    return optimizer


def make_adahessian3(params, newton_lr, minus_beta1, minus_beta2):
    """Return ``rgd.Adahessian`` with explicit betas.

    Betas are given as complements: ``(1 - minus_beta1, 1 - minus_beta2)``.
    """
    beta_pair = (1 - minus_beta1, 1 - minus_beta2)
    return rgd.Adahessian(params, newton_lr, betas=beta_pair)


def make_adahessian4(params, newton_lr, minus_beta1, minus_beta2, eps):
    """Return ``rgd.Adahessian`` with betas and ``eps`` passed positionally.

    Betas are given as complements: ``(1 - minus_beta1, 1 - minus_beta2)``
    is the third positional argument and ``eps`` the fourth.
    """
    beta_pair = (1 - minus_beta1, 1 - minus_beta2)
    return rgd.Adahessian(params, newton_lr, beta_pair, eps)


def make_hhd_hutchinson(
    params, newton_lr, minus_momentum, hessian_eigenvalue_threshold
):
    """Return ``rgd.Adahessian`` with ``method="hhd"``.

    Betas are ``(1 - minus_momentum, 0.999)``; the second entry is fixed.
    """
    options = {
        "method": "hhd",
        "betas": (1 - minus_momentum, 0.999),
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
    }
    return rgd.Adahessian(params, newton_lr, **options)


def make_rgd_newton_hutchinson(params, newton_lr, delta, hessian_eigenvalue_threshold):
    """Return ``rgd.Adahessian`` with ``method="hhd"``, a ``delta`` and
    betas fixed at ``(0, 0.999)``.
    """
    options = {
        "betas": (0, 0.999),
        "method": "hhd",
        "delta": delta,
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
    }
    return rgd.Adahessian(params, newton_lr, **options)


def make_rgd_newton1_hutchinson(
    params, newton_lr, minus_beta1, delta, hessian_eigenvalue_threshold
):
    """Return ``rgd.Adahessian`` with ``method="hhd"`` and a ``delta``.

    Betas are ``(1 - minus_beta1, 0.999)``; the second entry is fixed.
    """
    options = {
        "betas": (1 - minus_beta1, 0.999),
        "method": "hhd",
        "delta": delta,
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
    }
    return rgd.Adahessian(params, newton_lr, **options)


def make_hhd_hutchinson_corrected(
    params, newton_lr, minus_momentum, hessian_eigenvalue_threshold
):
    """Return ``rgd.Adahessian`` with ``method="hhd"`` and
    ``diag_corrected=True``.

    Betas are ``(1 - minus_momentum, 0.999)``; the second entry is fixed.
    """
    options = {
        "method": "hhd",
        "diag_corrected": True,
        "betas": (1 - minus_momentum, 0.999),
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
    }
    return rgd.Adahessian(params, newton_lr, **options)


def make_hhd_hutchinson4(
    params, newton_lr, minus_momentum, minus_beta2, hessian_eigenvalue_threshold
):
    """Return ``rgd.Adahessian`` with ``method="hhd"`` and both betas tunable.

    Betas are given as complements:
    ``(1 - minus_momentum, 1 - minus_beta2)``.
    """
    options = {
        "method": "hhd",
        "betas": (1 - minus_momentum, 1 - minus_beta2),
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
    }
    return rgd.Adahessian(params, newton_lr, **options)


def make_newton_hutchinson(params, newton_lr, hessian_eigenvalue_threshold):
    """Return ``rgd.Adahessian`` with ``method="newton"``."""
    options = {
        "method": "newton",
        "hessian_eigenvalue_threshold": hessian_eigenvalue_threshold,
    }
    return rgd.Adahessian(params, newton_lr, **options)


def make_learned_kd(params, lr, minus_momentum, ke, name):
    """Return ``rgd.LearnedKineticDescent``.

    Arguments are forwarded positionally as
    ``(params, ke, lr, 1 - minus_momentum)``. ``name`` is accepted but not
    used by this factory.
    """
    momentum = 1 - minus_momentum
    return rgd.LearnedKineticDescent(params, ke, lr, momentum)


def make_md_exp(params, lr):
    """Return ``rgd.MirrorDescent`` with ``preconditioner="e_to_x_squared"``."""
    optimizer = rgd.MirrorDescent(params, lr, preconditioner="e_to_x_squared")
    return optimizer


def make_md_pk(params, lr, little_a, big_a):
    """Return ``rgd.MirrorDescent`` with ``little_a``/``big_a`` and
    ``preconditioner="e_to_x_squared"``.
    """
    # NOTE(review): the preconditioner string is the same one make_md_exp
    # uses; confirm this is intended for the "pk" variant.
    options = {
        "little_a": little_a,
        "big_a": big_a,
        "preconditioner": "e_to_x_squared",
    }
    return rgd.MirrorDescent(params, lr, **options)


def make_backtracking_gd(params, armijo_c, tau):
    """Return ``rgd.BacktrackingGD``, forwarding ``armijo_c`` and ``tau`` positionally."""
    optimizer = rgd.BacktrackingGD(params, armijo_c, tau)
    return optimizer


def parse_algo_name(algo):
    """Split an algorithm spec into its base name and pinned hyperparameters.

    A spec is either a bare name (``"gd"``) or a name followed by ``-`` and
    ``|``-separated ``name=value`` pairs
    (``"cm-lr=0.1|minus_momentum=0.5"``).

    Only the FIRST ``-`` separates the name from the parameters, so values
    may themselves contain ``-`` (``"adam4-eps=1e-8"``, ``"rgd-delta=-2"``).
    Likewise only the first ``=`` of each pair separates name from value.

    Args:
        algo: The algorithm spec string.

    Returns:
        ``(name, given_params)`` where ``given_params`` maps each pinned
        hyperparameter name to a ``float`` when the value parses as one, and
        to the raw string otherwise.

    Raises:
        ValueError: If a parameter entry does not contain ``=``.
    """
    name, dash, param_substr = algo.partition("-")
    given_params = {}
    if not dash:
        return name, given_params

    for pair in param_substr.split("|"):
        p_name, equals, p_value = pair.partition("=")
        if not equals:
            raise ValueError(
                f"Malformed hyperparameter {pair!r} in algo spec {algo!r}; "
                "expected 'name=value'."
            )
        # Try to parse the value as a float. If not possible, leave as string.
        try:
            p_value = float(p_value)
        except ValueError:
            pass
        given_params[p_name] = p_value

    return name, given_params


def make_optimizer(algo):
    """Return a factory for the optimizer named by ``algo``.

    ``algo`` may carry pinned hyperparameters in its suffix (see
    ``parse_algo_name``). The returned callable forwards its positional
    arguments to the matching ``make_*`` function; pinned hyperparameters
    take precedence over any same-named keyword argument supplied by the
    caller.

    Raises:
        KeyError: If the base algorithm name is not registered.
    """
    base_name, pinned = parse_algo_name(algo)

    factories = {
        "gd": make_sgd,
        "cm": make_cm,
        "nag": make_nag,
        "adam": make_adam,
        "adam3": make_adam3,
        "adam4": make_adam4,
        "rprop": make_rprop,
        "rmsprop": make_rmsprop,
        "adadelta": make_adadelta,
        "adam_centered": make_adam_centered,
        "amsgrad": make_amsgrad,
        "lbfgs": make_lbfgs,
        "lbfgs_wolfe": make_lbfgs_wolfe,
        "rgd_eu": make_rgd_eu,
        "rgd": make_rgd,
        "pk": make_power_kinetic,
        "piecewise_pk": make_piecewise_pk,
        "poly_kd": make_poly_kd,
        "exact_ckd": make_exact_ckd,
        "good_approx_ckd": make_good_approx_ckd,
        "ckd": make_ckd,
        "ckd_hinit": make_ckd_hinit,
        "ckd_hess": make_ckd_hess,
        "ckd_hess_hinit": make_ckd_hess_hinit,
        "ckd_hess_pd_threshold": make_ckd_hess_pd_threshold,
        "hhd": make_hhd,
        "hhd_upper_bounded": make_hhd_upper_bounded,
        "rgd_newton": make_rgd_newton,
        "rgd_nomo": make_rgd_nomo,
        "rgd_newton_pd": make_rgd_newton_pd,
        "hhd_abs_pd": make_hhd_abs_pd,
        "hhd_pd": make_hhd_pd,
        "hhd_diag": make_hhd_diag,
        "hhd_diag_corrected": make_hhd_diag_corrected,
        "hhd_diag_pd": make_hhd_diag_pd,
        "hhd_diag_abs_pd": make_hhd_diag_abs_pd,
        "hhd_critically_damped": make_hhd_critically_damped,
        "ckvd": make_ckvd,
        "newton": make_newton,
        "newton_2nd_momentum": make_newton_2nd_momentum,
        "newton_step_avg": make_newton_step_avg,
        "newton_pd": make_newton_pd,
        "newton_abs_pd": make_newton_abs_pd,
        "adahessian": make_adahessian,
        "adahessian3": make_adahessian3,
        "adahessian4": make_adahessian4,
        "hhd_hutchinson": make_hhd_hutchinson,
        "rgd_newton_hutchinson": make_rgd_newton_hutchinson,
        "rgd_newton1_hutchinson": make_rgd_newton1_hutchinson,
        "hhd_hutchinson_corrected": make_hhd_hutchinson_corrected,
        "hhd_hutchinson4": make_hhd_hutchinson4,
        "newton_hutchinson": make_newton_hutchinson,
        "learned_kd": make_learned_kd,
        "md_exp": make_md_exp,
        "md_pk": make_md_pk,
        "mcd": make_mcd,
        "backtracking_gd": make_backtracking_gd,
    }
    factory = factories[base_name]

    def opt(*args, **kwargs):
        # Pinned values come first, followed by the caller's remaining
        # keywords. The order matters for factories that consume **kwargs
        # positionally (make_poly_kd reads kwargs.values() in order).
        merged = dict(pinned)
        for key, value in kwargs.items():
            if key not in pinned:
                merged[key] = value
        return factory(*args, **merged)

    return opt


def make_config(hpspace, algo):
    """Select the hyperparameter search-space entries relevant to ``algo``.

    Args:
        hpspace: Mapping from hyperparameter name to its search-space entry.
        algo: Algorithm spec, optionally with pinned hyperparameters in its
            suffix (see ``parse_algo_name``).

    Returns:
        A dict ``{hp: hpspace[hp]}`` over the hyperparameters the algorithm
        tunes, excluding any that are pinned in the spec.

    Raises:
        ValueError: If the base algorithm name has no entry in the table.
        KeyError: If a required hyperparameter is missing from ``hpspace``.
    """
    algo, given_params = parse_algo_name(algo)

    threshold = "hessian_eigenvalue_threshold"
    # Each row: (algorithm names, hyperparameters tuned for those algorithms).
    hp_table = (
        (("gd", "adam", "md_exp"), ["lr"]),
        (
            (
                "cm",
                "nag",
                "exact_ckd",
                "hhd",
                "hhd_diag",
                "hhd_diag_corrected",
                "ckvd",
                "learned_kd",
                "mcd",
            ),
            ["lr", "minus_momentum"],
        ),
        (("adam3", "adam_centered", "amsgrad"), ["lr", "minus_beta1", "minus_beta2"]),
        (("adam4",), ["lr", "minus_beta1", "minus_beta2", "eps"]),
        (("rprop",), ["lr", "eta1", "eta2", "step_min", "step_max"]),
        (("rmsprop",), ["lr", "minus_beta2", "centered"]),
        (("adadelta",), ["lr", "minus_beta1"]),
        (("rgd_eu",), ["lr", "minus_momentum", "delta"]),
        (("rgd",), ["lr", "minus_momentum", "delta", "alpha"]),
        (("rgd_nomo", "rgd_newton"), ["lr", "delta"]),
        (("rgd_newton_pd",), ["lr", "delta", threshold]),
        (("hhd_upper_bounded",), ["lr", "minus_momentum", "conj_hessian_lipschitz"]),
        (("pk",), ["lr", "minus_momentum", "delta", "little_a", "big_a"]),
        (
            ("piecewise_pk",),
            ["lr", "minus_momentum", "delta", "little_a", "big_a", "piecewise_at"],
        ),
        (("good_approx_ckd",), ["lr", "minus_momentum", "alpha_newton"]),
        (("ckd", "ckd_hinit"), ["lr", "minus_momentum", "num_inner_loops"]),
        (
            ("ckd_hess", "ckd_hess_hinit"),
            ["lr", "minus_momentum", "alpha_newton", "num_inner_loops"],
        ),
        (
            ("ckd_hess_pd_threshold",),
            ["lr", "minus_momentum", "alpha_newton", threshold],
        ),
        (("lbfgs", "lbfgs_wolfe"), ["newton_lr", "max_iter"]),
        (
            ("hhd_pd", "hhd_diag_pd", "hhd_abs_pd", "hhd_diag_abs_pd"),
            ["lr", "minus_momentum", threshold],
        ),
        (("newton", "adahessian", "hhd_critically_damped"), ["newton_lr"]),
        (("newton_2nd_momentum",), ["newton_lr", "minus_momentum", threshold]),
        (("newton_step_avg",), ["newton_lr", "minus_beta1"]),
        (("newton_pd", "newton_abs_pd"), ["newton_lr", threshold]),
        (("adahessian3",), ["newton_lr", "minus_beta1", "minus_beta2"]),
        (("adahessian4",), ["newton_lr", "minus_beta1", "minus_beta2", "eps"]),
        (
            ("hhd_hutchinson", "hhd_hutchinson_corrected"),
            ["newton_lr", "minus_momentum", threshold],
        ),
        (("rgd_newton_hutchinson",), ["newton_lr", "delta", threshold]),
        (
            ("rgd_newton1_hutchinson",),
            ["newton_lr", "minus_beta1", "delta", threshold],
        ),
        (
            ("hhd_hutchinson4",),
            ["newton_lr", "minus_momentum", "minus_beta2", threshold],
        ),
        (("newton_hutchinson",), ["newton_lr", threshold]),
        (("poly_kd",), ["lr", "minus_momentum"] + [f"c{i}" for i in range(1, 5)]),
        (("md_pk",), ["lr", "little_a", "big_a"]),
        (("backtracking_gd",), ["armijo_c", "tau"]),
    )

    for names, hps in hp_table:
        if algo in names:
            # Pinned hyperparameters are not searched over, so drop them.
            return {hp: hpspace[hp] for hp in hps if hp not in given_params}

    raise ValueError(f"`algo` {algo} not valid.")


def set_cuda_available_devices(max_devices=None, maxLoad=0.1):
    """Restrict CUDA to the GPUs that GPUtil reports as available.

    Queries ``GPUtil.getAvailable`` with ``maxLoad`` used as the limit for
    both load and memory, keeps at most ``max_devices`` of the returned ids
    (all of them when ``None``), writes them to the ``CUDA_VISIBLE_DEVICES``
    environment variable, and returns the list of ids.
    """
    # NOTE(review): maxMemory reuses maxLoad; confirm a separate memory
    # threshold is not wanted.
    candidates = GPUtil.getAvailable(
        limit=float("inf"), maxLoad=maxLoad, maxMemory=maxLoad
    )
    devices = candidates[:max_devices]
    # NOTE(review): ids are joined with ", " (comma + space); verify the CUDA
    # runtime in use accepts whitespace in CUDA_VISIBLE_DEVICES.
    os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join(map(str, devices))
    return devices
