"""Main module defining various optimization methods."""
import copy
import math

import scipy.special
import torch
import torch.optim as optim
import torch.distributed as dist


class RGD(optim.Optimizer):
    """Relativistic gradient descent: momentum with a speed-limited position update.

    Every parameter keeps a velocity buffer ``self.state[p]["v_k"]``. Position
    updates are divided by ``sqrt(delta * ||v||^2 + 1)``, so the distance moved
    by one position update stays bounded however large the velocity grows.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: step size.
        momentum: velocity decay factor. ``0`` selects a plain normalized
            gradient step ``x -= lr * g / sqrt(delta * ||g||^2 + 1)`` and no
            integrator is run.
        delta: weight of the squared norm in the normalization.
        integrator: ``"symplectic_euler"`` or ``"leapfrog"``.
        alpha: leapfrog only; when not ``1`` the position is first replaced by
            ``alpha * x + (1 - alpha) * x_prev``, where ``x_prev`` is the
            position stored after the first half of the previous leapfrog
            update (the current position on the first step).
        weight_decay: L2 penalty coefficient added to the gradient.

    Raises:
        ValueError: if ``integrator`` is not a supported value.
    """

    def __init__(
        self,
        params,
        lr,
        momentum,
        delta,
        integrator="leapfrog",
        alpha=1,
        weight_decay=0,
    ):
        if integrator not in ["symplectic_euler", "leapfrog"]:
            raise ValueError(
                "`integrator` must be either 'symplectic_euler' or 'leapfrog'"
            )

        defaults = dict(
            lr=lr,
            momentum=momentum,
            delta=delta,
            integrator=integrator,
            alpha=alpha,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            delta = group["delta"]
            integrator = group["integrator"]
            alpha = group["alpha"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                g_k = p.grad

                if weight_decay != 0:
                    g_k = g_k.add(p, alpha=weight_decay)

                # Velocity from the previous step (zero on the first iteration).
                param_state = self.state[p]
                v_k = (
                    param_state["v_k"]
                    if "v_k" in param_state
                    else torch.zeros_like(g_k)
                )

                if momentum == 0:
                    # No momentum: a single normalized gradient step. This branch
                    # is exclusive of the integrators below; running both would
                    # move the parameters twice per call.
                    norm_factor = math.sqrt(delta * g_k.square().sum() + 1)
                    p.add_(g_k, alpha=-lr / norm_factor)

                elif integrator == "symplectic_euler":
                    # v_{k+1} = momentum * v_k - lr * g_k
                    v_k.mul_(momentum).add_(g_k, alpha=-lr)

                    # x_{k+1} = x_k + v_{k+1} / sqrt(delta * ||v_{k+1}||^2 + 1)
                    norm_factor = math.sqrt(delta * torch.square(v_k).sum() + 1)
                    p.add_(v_k, alpha=1 / norm_factor)

                else:  # integrator == "leapfrog"
                    # v_{k+1/2} = sqrt(momentum) * v_k - lr * g_k
                    v_k.mul_(math.sqrt(momentum)).add_(g_k, alpha=-lr)

                    # x_{k+1} = alpha * x_{k+1/2} + (1-alpha) * x_k + v_{k+1/2} / sqrt(delta *||v_{k+1/2}||^2 + 1)
                    if alpha != 1:
                        x_k = param_state["x_k"] if "x_k" in param_state else p.clone()
                        p.multiply_(alpha)
                        p.add_(x_k, alpha=1 - alpha)

                    norm_factor = math.sqrt(delta * torch.square(v_k).sum() + 1)
                    p.add_(v_k, alpha=1 / norm_factor)

                    if alpha != 1:
                        param_state["x_k"] = p.clone()

                    # v_{k+1} = sqrt(momentum) * v_{k+1/2}
                    v_k.mul_(math.sqrt(momentum))

                    # x_{k+3/2} = x_{k+1} + sqrt(momentum) *
                    #       v_{k+1} / sqrt(momentum * delta * ||v_{k+1}||^2 + 1)
                    norm_factor = math.sqrt(
                        momentum * delta * torch.square(v_k).sum() + 1
                    )
                    p.add_(v_k, alpha=math.sqrt(momentum) / norm_factor)

                param_state["v_k"] = v_k

        return loss


class PowerKinetic(optim.Optimizer):
    """Momentum method whose position update follows a power-law kinetic energy.

    The velocity follows ``v = momentum * v - lr * g`` (``v = -g`` when
    ``momentum == 0``). The parameters then move by ``lr * s(||v||) * v``,
    where the scalar ``s(r)`` is:

    * ``poly_coefficients = [c_1, ..., c_n]`` given:
      ``s(r) = sum_i i * c_i * r^(i - 2)``, so that ``s(||v||) * v`` is the
      gradient of ``sum_i c_i ||v||^i``.
    * else, ``piecewise_at is None``:
      ``s(r) = (delta * r^a + 1)^(A/a - 1) * r^(a - 2)`` with ``a = little_a``
      and ``A = big_a``.
    * otherwise a piecewise power law that switches form at ``r = piecewise_at``.

    A zero velocity produces no position update.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: step size (applied to both the velocity and the position update).
        momentum: velocity decay factor.
        delta: scale inside the power law.
        little_a: exponent ``a``.
        big_a: exponent ``A``.
        piecewise_at: optional switch radius for the piecewise form.
        poly_coefficients: optional polynomial coefficients; take precedence.
        weight_decay: L2 penalty coefficient added to the gradient.
    """

    def __init__(
        self,
        params,
        lr,
        momentum,
        delta,
        little_a,
        big_a,
        piecewise_at=None,
        poly_coefficients=None,
        weight_decay=0,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            delta=delta,
            little_a=little_a,
            big_a=big_a,
            piecewise_at=piecewise_at,
            poly_coefficients=poly_coefficients,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            delta = group["delta"]
            little_a = group["little_a"]
            big_a = group["big_a"]
            piecewise_at = group["piecewise_at"]
            poly_coefficients = group["poly_coefficients"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                g_k = p.grad

                if weight_decay != 0:
                    g_k = g_k.add(p, alpha=weight_decay)

                # Velocity from the previous step (zero on the first iteration).
                param_state = self.state[p]
                v_k = (
                    param_state["v_k"]
                    if "v_k" in param_state
                    else torch.zeros_like(g_k)
                )

                # v_{k+1} = momentum * v_k - lr * g_k  (v_{k+1} = -g_k without momentum)
                if momentum != 0:
                    v_k.mul_(momentum).add_(g_k, alpha=-lr)
                else:
                    # Overwrite the existing buffer in place instead of
                    # re-pointing its storage at a temporary with ``set_``.
                    v_k.copy_(g_k).neg_()
                param_state["v_k"] = v_k

                norm_v_k = torch.linalg.norm(v_k)
                if norm_v_k == 0:
                    # A zero velocity means a zero step. The scale factors below
                    # can raise the norm to a negative power (i - 2 for i = 1, or
                    # little_a - 2 with little_a < 2), and 0 * inf would write NaN
                    # into the parameters.
                    continue

                # x_{k+1} = x_k + lr * v_{k+1} * ||v_{k+1}||^(a-2) * (delta * ||v_{k+1}||^a + 1) ^ (A/a-1)
                if poly_coefficients is not None:
                    norm_factor = sum(
                        poly_coefficients[i - 1] * i * norm_v_k ** (i - 2)
                        for i in range(1, len(poly_coefficients) + 1)
                    )

                elif piecewise_at is None:
                    norm_factor = (delta * norm_v_k ** little_a + 1) ** (
                        big_a / little_a - 1
                    ) * norm_v_k ** (little_a - 2)

                elif norm_v_k <= piecewise_at:
                    norm_factor = (
                        big_a
                        * delta ** (big_a / little_a)
                        * piecewise_at ** (big_a - little_a)
                        * norm_v_k ** (little_a - 2)
                    )
                else:
                    norm_factor = (
                        big_a * delta ** (big_a / little_a) * norm_v_k ** (big_a - 2)
                    )

                p.add_(v_k, alpha=lr * norm_factor)

        return loss


def hessian_wrapper(hessian, eigenvalue_threshold, approx):
    """Wrap a Hessian callable so that it returns a modified (positive) matrix.

    Args:
        hessian: callable mapping a point ``x`` to a square Hessian matrix.
        eigenvalue_threshold: lower bound applied to eigenvalues or diagonal
            entries; for ``"additive"`` it is the amount the spectrum is lifted.
        approx: one of
            ``"threshold"`` - raise eigenvalues below the threshold to it;
            ``"threshold_abs_pd"`` - take absolute eigenvalues, then threshold;
            ``"additive"`` - add ``(threshold - min(lambda_min, 0)) * I``;
            ``"diagonal"`` - keep only the diagonal;
            ``"diagonal_corrected"`` - diagonal with ``sum_j H_ij^2 / H_ii``;
            ``"diagonal_threshold"`` - diagonal, thresholded from below;
            ``"diagonal_abs_pd"`` - absolute diagonal, thresholded from below.

    Returns:
        A callable ``f(x)`` returning the modified matrix.

    Raises:
        ValueError: if ``approx`` is not one of the names above.
    """

    def threshold(x):
        eigvals, eigvecs = torch.linalg.eigh(hessian(x))
        eigvals[eigvals <= eigenvalue_threshold] = eigenvalue_threshold
        return eigvecs @ torch.diag(eigvals) @ eigvecs.T

    def threshold_abs_pd(x):
        eigvals, eigvecs = torch.linalg.eigh(hessian(x))
        eigvals = eigvals.abs()
        eigvals[eigvals <= eigenvalue_threshold] = eigenvalue_threshold
        return eigvecs @ torch.diag(eigvals) @ eigvecs.T

    def additive(x):
        hess = hessian(x)
        eigvals = torch.linalg.eigvalsh(hess)
        # Lift the spectrum so the smallest eigenvalue is at least the threshold.
        add_factor = eigenvalue_threshold - eigvals.min().clamp(max=0)
        # Build the identity on the Hessian's device so CUDA inputs work.
        identity = torch.eye(hess.shape[0], dtype=hess.dtype, device=hess.device)
        return hess + add_factor * identity

    def diagonal(x):
        return hessian(x).diag().diag()

    def diagonal_corrected(x):
        hess = hessian(x)
        correction = (hess ** 2).sum(dim=1)
        return (correction / hess.diag()).diag()

    def diagonal_threshold(x):
        diag = hessian(x).diag()
        diag[diag <= eigenvalue_threshold] = eigenvalue_threshold
        return diag.diag()

    def diagonal_abs_pd(x):
        diag = hessian(x).diag().abs()
        diag[diag <= eigenvalue_threshold] = eigenvalue_threshold
        return diag.diag()

    approximations = {
        "threshold": threshold,
        "threshold_abs_pd": threshold_abs_pd,
        "additive": additive,
        "diagonal": diagonal,
        "diagonal_corrected": diagonal_corrected,
        "diagonal_threshold": diagonal_threshold,
        "diagonal_abs_pd": diagonal_abs_pd,
    }
    if approx not in approximations:
        raise ValueError(f"`approx` must be one of {list(approximations)}")
    return approximations[approx]


class ConjugateKineticDescent(optim.Optimizer):
    """Momentum descent that moves along (an approximation of) ``grad k(v)``.

    The velocity follows ``v_{k+1} = momentum * v_k - lr * g_k``; ``hess_approx``
    writes the same update in damped form with ``gamma = (1 - momentum) / lr``.
    The parameters move along the gradient of a kinetic energy ``k`` evaluated
    at the velocity. ``method`` selects how that gradient is obtained:

    * ``"exact"``: ``kinetic_grad(v)``.
    * ``"inv_hess"``: ``H(x)^{-1} v`` with ``H = hessian``.
    * ``"rgd_hess"``: the ``inv_hess`` step divided by
      ``sqrt(delta * |<v, H^{-1} v>| + 1)``.
    * ``"upper_bounded"``: the ``inv_hess`` step minus
      ``0.5 * conj_hessian_lipschitz * ||v - g|| * (v - g)``.
    * ``"velocity"``, ``"grad_approx"``, ``"hess_approx"``: multi-call schemes.
      One outer iteration spans several calls to :meth:`step` (3 for
      ``velocity``, ``1 + 2 * num_inner_loops`` otherwise). Between calls the
      parameters are moved to auxiliary points ``a_k`` / ``b_k``, and the next
      call reads ``p.grad`` evaluated there.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: step size.
        momentum: velocity decay factor.
        method: one of the names listed above.
        num_inner_loops: inner iterations per outer iteration for
            ``grad_approx`` / ``hess_approx``.
        kinetic_grad: callable ``v -> grad k(v)``; used by ``exact``.
        hessian: callable ``x -> H(x)``; used by every method except ``exact``
            and ``grad_approx``.
        alpha: coefficient used by ``hess_approx``.
        delta: normalization scale used by ``rgd_hess``.
        conj_hessian_lipschitz: correction scale used by ``upper_bounded``.
        initialization: ``"zero"`` (all buffers zero) or ``"heuristic"``
            (``v = -g``, ``a = b = x``).
        hessian_eigenvalue_threshold: if set, ``hessian`` is wrapped with
            ``hessian_wrapper`` using ``hess_approximation``.
        hess_approximation: ``approx`` name passed to ``hessian_wrapper``.
        track_iterates: if True, a detached copy of every state buffer is
            appended to ``self.iterates`` after each call.

    Raises:
        ValueError: for an unknown ``method`` or ``initialization``.
    """

    _INITIALIZATIONS = ("zero", "heuristic")

    def __init__(
        self,
        params,
        lr,
        momentum,
        method,
        num_inner_loops=1,
        kinetic_grad=None,
        hessian=None,
        alpha=None,
        delta=None,
        conj_hessian_lipschitz=None,
        initialization="zero",
        hessian_eigenvalue_threshold=None,
        hess_approximation="threshold",
        track_iterates=False,
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            alpha=alpha,
            delta=delta,
            conj_hessian_lipschitz=conj_hessian_lipschitz,
            initialization=initialization,
        )
        super().__init__(params, defaults)
        valid_methods = [
            "exact",
            "grad_approx",
            "hess_approx",
            "inv_hess",
            "rgd_hess",
            "upper_bounded",
            "velocity",
        ]
        if method not in valid_methods:
            raise ValueError(f"`method` must be one of {valid_methods}")
        if initialization not in self._INITIALIZATIONS:
            raise ValueError(
                f"`initialization` must be one of {list(self._INITIALIZATIONS)}"
            )
        self.method = method
        self.num_inner_loops = num_inner_loops
        self.kinetic_grad = kinetic_grad

        if hessian_eigenvalue_threshold is not None:
            hessian = hessian_wrapper(
                hessian, hessian_eigenvalue_threshold, hess_approximation
            )
        self.hessian = hessian

        self.iterations = 0
        self.last_loss = None
        self.iterates = [] if track_iterates else None

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one call of the selected scheme.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            ``self.last_loss``, the loss recorded by the most recent call that
            began an iteration of the scheme (``None`` if no closure was given).
            Returns ``tensor(inf)`` instead if a Hessian solve raised a
            ``RuntimeError`` (e.g. a singular matrix).
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        detached_loss = None if loss is None else loss.detach()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            alpha = group["alpha"]
            delta = group["delta"]
            conj_hessian_lipschitz = group["conj_hessian_lipschitz"]
            initialization = group["initialization"]
            # Per-group overrides bypass the check in __init__.
            if initialization not in self._INITIALIZATIONS:
                raise ValueError(
                    f"`initialization` must be one of {list(self._INITIALIZATIONS)}"
                )
            heuristic = initialization == "heuristic"

            for p in group["params"]:
                # Parameters that received no gradient are left untouched.
                if p.grad is None:
                    continue
                param_state = self.state[p]

                # Lazily create the per-parameter buffers on first use.
                if "v_k" not in param_state:
                    param_state["v_k"] = (
                        -p.grad if heuristic else torch.zeros_like(p.grad)
                    )
                if "a_k" not in param_state:
                    param_state["a_k"] = (
                        p.clone() if heuristic else torch.zeros_like(p.grad)
                    )
                if "b_k" not in param_state:
                    param_state["b_k"] = (
                        p.clone() if heuristic else torch.zeros_like(p.grad)
                    )
                if "x_k" not in param_state:
                    param_state["x_k"] = p.clone()
                v_k = param_state["v_k"]
                a_k = param_state["a_k"]
                b_k = param_state["b_k"]
                x_k = param_state["x_k"]

                if self.method in [
                    "exact",
                    "inv_hess",
                    "rgd_hess",
                    "upper_bounded",
                ]:
                    # v_{k+1} = momentum * v_k - lr * g_k
                    v_k.mul_(momentum).add_(p.grad, alpha=-lr)

                    # y_{k+1} = grad k(v_{k+1})
                    if self.method == "exact":
                        grad_k = self.kinetic_grad(v_k)
                    else:
                        try:
                            grad_k = torch.linalg.solve(
                                self.hessian(p), v_k.unsqueeze(1)
                            )
                        except RuntimeError:
                            # Numerical failure (torch.linalg.LinAlgError is a
                            # RuntimeError): report divergence to the caller.
                            return torch.tensor(float("inf"))
                        grad_k.squeeze_()

                        if self.method == "rgd_hess":
                            grad_k = (
                                grad_k
                                / (delta * torch.dot(v_k, grad_k).abs() + 1).sqrt()
                            )

                        elif self.method == "upper_bounded":
                            grad_k = (
                                grad_k
                                - 0.5
                                * conj_hessian_lipschitz
                                * torch.linalg.norm(v_k - p.grad)
                                * (v_k - p.grad)
                            )

                    p.add_(grad_k, alpha=lr)
                    self.last_loss = detached_loss

                elif self.method == "velocity":
                    phase = self.iterations % 3
                    if phase == 0:
                        param_state["g_k"] = p.grad.clone()
                        param_state["x_k"] = p.clone()
                        p.copy_(a_k)
                        self.last_loss = loss

                    else:
                        damping = (1 - momentum) / lr
                        c = -param_state["g_k"] - damping * p.grad
                        try:
                            hess_inv_product = torch.linalg.solve(
                                self.hessian(p), c.unsqueeze(1)
                            )
                        except RuntimeError:
                            return torch.tensor(float("inf"))

                        if phase == 1:
                            a_k.add_(hess_inv_product.squeeze(-1), alpha=lr)
                            param_state["a_k"] = a_k.clone()
                            p.copy_(b_k)
                        else:
                            b_k.add_(hess_inv_product.squeeze(-1), alpha=-lr)
                            param_state["b_k"] = b_k.clone()
                            p.copy_(param_state["x_k"])
                            p.add_(a_k, alpha=0.5 * lr).add_(b_k, alpha=0.5 * lr)

                elif self.method == "grad_approx":
                    phase = self.iterations % (1 + 2 * self.num_inner_loops)
                    if phase == 0:
                        v_k.mul_(momentum).add_(p.grad, alpha=-lr)
                        x_k.copy_(p)
                        p.copy_(a_k)
                        self.last_loss = detached_loss

                    elif phase <= self.num_inner_loops:
                        a_k.add_(v_k, alpha=lr).add_(p.grad, alpha=-lr)
                        # The next gradient must be taken at the updated a_k;
                        # after the last inner iteration move on to b_k.
                        p.copy_(b_k if phase == self.num_inner_loops else a_k)

                    else:
                        b_k.add_(v_k, alpha=-lr).add_(p.grad, alpha=-lr)
                        if phase == 2 * self.num_inner_loops:
                            x_k.add_(a_k, alpha=0.5).add_(b_k, alpha=-0.5)
                            p.copy_(x_k)
                        else:
                            # Next gradient at the updated b_k.
                            p.copy_(b_k)

                elif self.method == "hess_approx":
                    phase = self.iterations % (1 + 2 * self.num_inner_loops)
                    if phase == 0:
                        # damping gamma = (1 - momentum) / lr
                        param_state["v_dot"] = -(1 - momentum) / lr * v_k - p.grad
                        v_k.add_(param_state["v_dot"], alpha=lr)
                        x_k.copy_(p)
                        p.copy_(a_k)
                        self.last_loss = detached_loss

                    elif phase <= self.num_inner_loops:
                        c = alpha * (v_k - p.grad) + param_state["v_dot"]
                        try:
                            hess_inv_product = torch.linalg.solve(
                                self.hessian(a_k), c.unsqueeze(1)
                            )
                        except RuntimeError:
                            return torch.tensor(float("inf"))
                        a_k.add_(hess_inv_product.squeeze(-1), alpha=lr)
                        # The next gradient must be taken at the updated a_k;
                        # after the last inner iteration move on to b_k.
                        p.copy_(b_k if phase == self.num_inner_loops else a_k)

                    else:
                        c = alpha * (v_k + p.grad) + param_state["v_dot"]
                        try:
                            hess_inv_product = torch.linalg.solve(
                                self.hessian(b_k), c.unsqueeze(1)
                            )
                        except RuntimeError:
                            return torch.tensor(float("inf"))
                        b_k.add_(hess_inv_product.squeeze(-1), alpha=-lr)
                        if phase == 2 * self.num_inner_loops:
                            x_k.add_(a_k, alpha=0.5).add_(b_k, alpha=-0.5)
                            p.copy_(x_k)
                        else:
                            # Next gradient at the updated b_k.
                            p.copy_(b_k)

        self.iterations += 1
        if self.iterates is not None:
            self.iterates.append(
                [
                    {k: v.detach().clone() for k, v in p_v.items()}
                    for p_v in self.state.values()
                ]
            )

        return self.last_loss


class CustomAdam(optim.Optimizer):
    """Adam with an optional centered second-moment estimate.

    With ``centered=False`` this is standard Adam: the second-moment buffer is
    an exponential moving average of ``grad ** 2``. With ``centered=True`` it
    averages ``(grad - m) ** 2`` instead, where ``m`` is the already-updated
    first-moment average.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate.
        betas: decay rates of the first- and second-moment averages.
        eps: term added to the denominator for numerical stability.
        centered: use the centered second moment described above.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, centered=False):
        defaults = dict(lr=lr, betas=betas, eps=eps, centered=centered)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

                state["step"] += 1
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # Decay the first and second moment running averages.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # Standard Adam accumulates grad**2; the centered variant uses the
                # deviation from the running mean. (A plain ``0`` here is not a
                # tensor and makes ``addcmul_`` raise.)
                diff = grad - exp_avg if group["centered"] else grad
                exp_avg_sq.mul_(beta2).addcmul_(diff, diff, value=1 - beta2)
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                    group["eps"]
                )

                step_size = group["lr"] / bias_correction1

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss


class LearnedKineticDescent(optim.Optimizer):
    """Momentum method whose position step is the gradient of a learned kinetic energy.

    The velocities of all parameters that have a gradient are concatenated into
    one flat vector and passed to ``kinetic_energy.grad``. The result is split
    back onto the same parameters, so the kinetic energy may couple different
    parameters.

    Args:
        params: iterable of parameters or parameter-group dicts.
        kinetic_energy: object with a ``grad(v)`` method that maps the flat
            velocity vector to a flat tensor with one entry per velocity entry.
        lr: step size.
        momentum: velocity decay factor.

    Attributes:
        p_history: list of the detached flat velocity vectors, one per step.
    """

    def __init__(self, params, kinetic_energy, lr=1e-2, momentum=0.8):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        self.kinetic_energy = kinetic_energy
        self.p_history = []

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # (param, lr) pairs in exactly the order their velocities are flattened,
        # so the kinetic-energy gradient is split back onto the same parameters.
        # Parameters without a gradient are excluded from both passes.
        updated = []
        flat_velocities = []
        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                param_state = self.state[p]
                v_k = (
                    param_state["v_k"]
                    if "v_k" in param_state
                    else torch.zeros_like(p.grad)
                )

                v_k = momentum * v_k - lr * p.grad
                param_state["v_k"] = v_k
                flat_velocities.append(v_k.flatten())
                updated.append((p, lr))

        if not updated:
            # Nothing to update (and torch.cat would fail on an empty list).
            return loss

        v_k_flat = torch.cat(flat_velocities)
        ke_grad_flat = self.kinetic_energy.grad(v_k_flat)
        self.p_history.append(v_k_flat.detach())

        offset = 0
        for p, lr in updated:
            numel = p.numel()
            p.add_(ke_grad_flat[offset : offset + numel].reshape(p.shape), alpha=lr)
            offset += numel

        return loss


class GoodApproxConjugateKineticDescent(optim.Optimizer):
    """Conjugate-kinetic descent with Hessian-preconditioned inner loops in one call.

    Each call to :meth:`step` does the following:

    1. Evaluates ``closure`` at the current point and updates the velocity.
    2. Runs ``num_inner_loops`` rounds. Each round re-evaluates ``closure`` at
       the auxiliary points ``a_k`` and ``b_k`` and moves each of them by
       ``+/- lr * H^{-1} c``.
    3. Sets the parameters to ``x + (a_k - b_k) / 2``.

    Only the first parameter of the first group is optimized, and ``closure``
    is required.

    If a ``RuntimeError`` occurs inside the inner loops (for example
    ``torch.linalg.solve`` on a singular Hessian), the current step is finished
    with the buffers as they are. Every later call then returns ``tensor(inf)``.

    Args:
        params: parameters; only ``param_groups[0]["params"][0]`` is used.
        lr: step size.
        momentum: momentum; the damping is ``(1 - momentum) / lr``.
        method: stored as ``self.method``; not read by :meth:`step`.
        hessian: callable ``x -> H(x)``.
        alpha: coefficient of ``v -/+ grad`` in the inner-loop right-hand side.
        num_inner_loops: number of inner rounds per call.
        hessian_eigenvalue_threshold: if set, ``hessian`` is wrapped with
            ``hessian_wrapper`` using ``hess_pd_approximation``.
        hess_pd_approximation: ``approx`` name passed to ``hessian_wrapper``.
    """

    def __init__(
        self,
        params,
        lr,
        momentum,
        method="hess",
        hessian=None,
        alpha=None,
        num_inner_loops=1,
        hessian_eigenvalue_threshold=None,
        hess_pd_approximation="threshold",
    ):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            alpha=alpha,
        )
        super().__init__(params, defaults)
        if hessian_eigenvalue_threshold is not None:
            hessian = hessian_wrapper(
                hessian, hessian_eigenvalue_threshold, hess_pd_approximation
            )
        self.method = method
        self.hessian = hessian
        self.num_inner_loops = num_inner_loops
        self.break_next = False

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one full outer iteration.

        Args:
            closure: callable that zeroes/recomputes ``p.grad`` at the current
                parameter value and returns the loss (required).

        Returns:
            The loss at the starting point, or ``tensor(inf)`` if an earlier
            call hit a numerical failure.
        """
        if self.break_next:
            return torch.tensor(float("inf"))

        group = self.param_groups[0]
        lr = group["lr"]
        momentum = group["momentum"]
        alpha = group["alpha"]

        p = group["params"][0]
        param_state = self.state[p]

        # Lazily create the buffers: velocity and auxiliary points start at
        # zero, the anchor x_k at the current parameter value.
        if "v_k" not in param_state:
            param_state["v_k"] = torch.zeros_like(p)
        if "a_k" not in param_state:
            param_state["a_k"] = torch.zeros_like(p)
        if "b_k" not in param_state:
            param_state["b_k"] = torch.zeros_like(p)
        if "x_k" not in param_state:
            param_state["x_k"] = p.clone()
        v_k = param_state["v_k"]
        a_k = param_state["a_k"]
        b_k = param_state["b_k"]
        x_k = param_state["x_k"]

        # damping gamma = (1 - momentum) / lr
        with torch.enable_grad():
            loss = closure()
        param_state["v_dot"] = -(1 - momentum) / lr * v_k - p.grad
        v_k.add_(param_state["v_dot"], alpha=lr)
        x_k.copy_(p)

        try:
            for _ in range(self.num_inner_loops):
                p.copy_(a_k)
                with torch.enable_grad():
                    closure()
                c = alpha * (v_k - p.grad) + param_state["v_dot"]
                hess_inv_product = torch.linalg.solve(self.hessian(a_k), c.unsqueeze(1))
                a_k.add_(hess_inv_product.squeeze(-1), alpha=lr)

                p.copy_(b_k)
                with torch.enable_grad():
                    closure()
                c = alpha * (v_k + p.grad) + param_state["v_dot"]
                hess_inv_product = torch.linalg.solve(self.hessian(b_k), c.unsqueeze(1))
                b_k.add_(hess_inv_product.squeeze(-1), alpha=-lr)
        except RuntimeError:
            # Numerical failure (torch.linalg.LinAlgError is a RuntimeError).
            # Misconfiguration such as alpha=None (TypeError) is not swallowed.
            self.break_next = True

        x_k.add_(a_k, alpha=0.5).add_(b_k, alpha=-0.5)
        p.copy_(x_k)

        return loss


class Newton(optim.Optimizer):
    """Newton's method with a step size, applied to a single parameter tensor.

    Each call computes ``step = H^{-1} grad`` and sets ``x -= lr * step``.
    Optional exponential averages can replace the Hessian and/or the step.
    Only the first parameter of the first group is updated, and
    ``p.grad.unsqueeze(1)`` in the solve presumably expects a 1-D parameter
    (TODO confirm).

    Args:
        params: parameters; only ``param_groups[0]["params"][0]`` is used.
        lr: step size.
        hess_momentum: if set, ``H_avg = hess_momentum * H_avg + H`` is used in
            the solve (no ``1 - hess_momentum`` factor on ``H``).
        step_beta: if set, the step is replaced by the average
            ``s_avg = step_beta * s_avg + (1 - step_beta) * step``.
        hessian: callable ``x -> H(x)``.
        hess_approximation: if set, ``hessian`` is wrapped with
            ``hessian_wrapper`` using this ``approx`` name.
        hessian_eigenvalue_threshold: threshold passed to ``hessian_wrapper``.
    """

    def __init__(
        self,
        params,
        lr=1,
        hess_momentum=None,
        step_beta=None,
        hessian=None,
        hess_approximation=None,
        hessian_eigenvalue_threshold=None,
    ):
        defaults = dict(
            lr=lr,
            hess_momentum=hess_momentum,
            step_beta=step_beta,
        )
        super().__init__(params, defaults)
        if hess_approximation is not None:
            hessian = hessian_wrapper(
                hessian, hessian_eigenvalue_threshold, hess_approximation
            )
        self.hessian = hessian
        # NOTE(review): checked in step() but never set to True in this class.
        self.break_next = False

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one Newton step.

        Args:
            closure: callable that recomputes ``p.grad`` and returns the loss
                (required); it is run with gradients enabled.

        Returns:
            The loss from ``closure``, or ``tensor(inf)`` if the linear solve
            raised a ``RuntimeError`` (e.g. a singular Hessian).
        """
        if self.break_next:
            return torch.tensor(float("inf"))

        group = self.param_groups[0]
        lr = group["lr"]
        hess_momentum = group["hess_momentum"]
        step_beta = group["step_beta"]
        p = group["params"][0]

        with torch.enable_grad():
            loss = closure()

        hess = self.hessian(p)
        if hess_momentum is not None:
            if "hessian_avg" not in self.state[p]:
                self.state[p]["hessian_avg"] = torch.zeros_like(hess)

            hessian_avg = self.state[p]["hessian_avg"]
            # Reuse the Hessian computed above instead of evaluating it again.
            hessian_avg.mul_(hess_momentum).add_(hess)
            hess = hessian_avg

        try:
            step = torch.linalg.solve(hess, p.grad.unsqueeze(1)).squeeze(-1)
        except RuntimeError:
            # torch.linalg.LinAlgError (singular matrix) is a RuntimeError.
            return torch.tensor(float("inf"))

        if step_beta is not None:
            if "step_avg" not in self.state[p]:
                self.state[p]["step_avg"] = torch.zeros_like(step)
            step_avg = self.state[p]["step_avg"]
            # In-place add_: the out-of-place add would discard the new step and
            # leave the average decaying towards zero.
            step_avg.mul_(step_beta).add_(step, alpha=1 - step_beta)
            step = step_avg

        p.add_(step, alpha=-lr)

        return loss


class Adahessian(optim.Optimizer):
    """Implements the Adahessian algorithm plus two Hessian-diagonal variants.

    Adahessian was proposed in `ADAHESSIAN: An Adaptive Second Order Optimizer
    for Machine Learning`. The Hessian diagonal is estimated with Hutchinson's
    method (see :meth:`get_trace`), so the gradients must be produced with
    ``loss.backward(create_graph=True)`` before calling :meth:`step`.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 0.15)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-4)
        weight_decay (float, optional): weight decay (L2 penalty); only applied
            when ``method`` is None (default: 0)
        hessian_power (float, optional): Hessian power in [0, 1] (default: 1).
            You can also try 0.5; for some tasks this performed better.
        single_gpu (bool, optional): set to ``False`` for single-node multi-GPU
            training with ``torch.nn.parallel.DistributedDataParallel``. The
            random vectors and the estimates are then synchronized across
            processes (default: True)
        method (str or None, optional): ``None`` for the Adahessian update;
            ``"hhd"`` divides a momentum buffer by the Hessian diagonal;
            ``"newton"`` divides the negative gradient by it (default: None)
        delta (float or None, optional): for ``"hhd"``/``"newton"``, if set, the
            step is divided by ``sqrt(delta * <vec, step> + 1)``. For ``"hhd"``
            it also makes the momentum buffer accumulate ``-grad`` instead of
            ``-lr * grad`` (default: None)
        diag_corrected (bool, optional): for ``"hhd"``/``"newton"``, use
            ``|EMA(d)| / EMA(d'^2)``, clamped above at
            ``1 / hessian_eigenvalue_threshold``, as the inverse Hessian
            diagonal. Here ``d`` and ``d'`` are two independent Hutchinson
            draws (default: False)
        hessian_eigenvalue_threshold (float or None, optional): lower bound on
            the Hessian-diagonal magnitude used by ``"hhd"``/``"newton"``; those
            methods need it to be set (default: None)
        track_iterates (bool, optional): if True, a deep copy of ``self.state``
            is appended to ``self.iterates`` after every step (default: False)
    """

    def __init__(
        self,
        params,
        lr=0.15,
        betas=(0.9, 0.999),
        eps=1e-4,
        weight_decay=0,
        hessian_power=1,
        single_gpu=True,
        method=None,
        delta=None,
        diag_corrected=False,
        hessian_eigenvalue_threshold=None,
        track_iterates=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError("Invalid Hessian power value: {}".format(hessian_power))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            hessian_power=hessian_power,
            hessian_eigenvalue_threshold=hessian_eigenvalue_threshold,
            delta=delta,
        )
        self.single_gpu = single_gpu
        if method not in [None, "hhd", "newton"]:
            raise ValueError("'method' must be one of None, 'hhd', or 'newton'")
        super().__init__(params, defaults)
        self.method = method
        self.diag_corrected = diag_corrected
        self.iterates = [] if track_iterates else None

    def get_trace(self, params, grads):
        """Estimate each parameter's Hessian diagonal with Hutchinson's method.

        Draws a Rademacher vector ``v`` (entries +-1) and computes the
        Hessian-vector product ``H v`` by differentiating ``grads`` with respect
        to ``params``. ``|H v|`` is used as the per-entry estimate; for 4-D
        (convolution) parameters it is averaged over dimensions 2 and 3.

        Args:
            params: list of parameters.
            grads: matching list of gradients, which must carry a graph
                (``loss.backward(create_graph=True)``).

        Returns:
            A list of tensors, one per parameter, broadcastable to its shape.

        Raises:
            RuntimeError: if a gradient has no ``grad_fn``.
        """

        # Check backward was called with create_graph set to True
        for i, grad in enumerate(grads):
            if grad.grad_fn is None:
                raise RuntimeError(
                    "Gradient tensor {:} does not have grad_fn. When calling\n".format(
                        i
                    )
                    + "\t\t\t  loss.backward(), make sure the option create_graph is\n"
                    + "\t\t\t  set to True."
                )

        v = [2 * torch.randint_like(p, high=2) - 1 for p in params]

        # Distributed setting with a single node and multiple GPUs (multi-node is
        # not supported): sum the vectors across processes and take signs, so
        # every process ends up with the same +-1 vector.
        if not self.single_gpu:
            for v_i in v:
                dist.all_reduce(v_i)
                v_i[v_i < 0.0] = -1.0
                v_i[v_i >= 0.0] = 1.0

        hvs = torch.autograd.grad(
            grads, params, grad_outputs=v, only_inputs=True, retain_graph=True
        )

        hutchinson_trace = []
        for hv in hvs:
            if hv.dim() == 4:  # Conv kernel
                # Average over the kernel's spatial dims 2 and 3.
                # |H v * v| == |H v| because v has entries +-1.
                tmp_output = torch.mean(hv.abs(), dim=[2, 3], keepdim=True)
            else:
                # Element-wise estimate (block size 1) for every other rank.
                # Without this branch, ranks 3 and 5+ would leave tmp_output
                # unset or reuse the previous parameter's value.
                tmp_output = hv.abs()
            hutchinson_trace.append(tmp_output)

        # Average the estimates across the processes on this node. all_reduce
        # works in place, so divide the stored tensor first and reduce it
        # directly rather than reducing a temporary.
        if not self.single_gpu:
            world_size = torch.cuda.device_count()
            for output in hutchinson_trace:
                output.div_(world_size)
                dist.all_reduce(output)

        return hutchinson_trace

    def step(self, closure=None):
        """Performs a single optimization step.

        The gradients stored in ``p.grad`` must have been produced with
        ``create_graph=True``; this call detaches them in place.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            loss = closure()

        params = []
        groups = []
        grads = []

        # Flatten groups into lists, so that
        #  hut_traces can be called with lists of parameters
        #  and grads
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    params.append(p)
                    groups.append(group)
                    grads.append(p.grad)

        # get the Hessian diagonal
        hut_traces = self.get_trace(params, grads)
        if self.diag_corrected:
            # Second, independent estimate used for the squared-diagonal average.
            hut_diags_squared = [d * d for d in self.get_trace(params, grads)]
        else:
            hut_diags_squared = [None] * len(params)
        for (p, group, grad, hut_trace, hut_diag_squared) in zip(
            params, groups, grads, hut_traces, hut_diags_squared
        ):

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p.data)
                # Exponential moving average of Hessian diagonal values
                state["exp_hessian_diag"] = torch.zeros_like(p.data)
                state["exp_hessian_diag_sq"] = torch.zeros_like(p.data)

            exp_avg, exp_hessian_diag, exp_hessian_diag_sq = (
                state["exp_avg"],
                state["exp_hessian_diag"],
                state["exp_hessian_diag_sq"],
            )

            beta1, beta2 = group["betas"]

            state["step"] += 1

            if self.method is None:
                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad.detach_(), alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(
                    hut_trace, hut_trace, value=1 - beta2
                )

                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]

                # make the square root, and the Hessian power
                k = group["hessian_power"]
                denom = (
                    (exp_hessian_diag_sq.sqrt() ** k) / math.sqrt(bias_correction2) ** k
                ).add_(group["eps"])

                # make update
                p.data = p.data - group["lr"] * (
                    exp_avg / bias_correction1 / denom + group["weight_decay"] * p.data
                )

            elif self.method in ["hhd", "newton"]:
                if self.method == "hhd":
                    momentum_lr = group["lr"] if group["delta"] is None else 1
                    exp_avg.mul_(beta1).add_(grad.detach_(), alpha=-momentum_lr)
                    vec = exp_avg
                else:
                    vec = -grad.detach_()

                exp_hessian_diag.mul_(beta2).add_(hut_trace, alpha=1 - beta2)

                if self.diag_corrected:
                    exp_hessian_diag_sq.mul_(beta2).add_(
                        hut_diag_squared, alpha=1 - beta2
                    )
                    hess_inv_approx = (
                        exp_hessian_diag.abs()
                        .div_(exp_hessian_diag_sq)
                        .clamp_(max=1 / group["hessian_eigenvalue_threshold"])
                    )
                    step = hess_inv_approx * vec
                else:
                    # Apply bias correction to Hessian diagonal estimator
                    bias_correction = 1 / (1 - beta2 ** state["step"])

                    # Threshold Hessian diagonal approximation to be positive-definite
                    pd_diag = (
                        exp_hessian_diag.mul(bias_correction)
                        .abs_()
                        .clamp_(min=group["hessian_eigenvalue_threshold"])
                    )
                    step = vec / pd_diag

                if group["delta"] is not None:
                    step = step / math.sqrt(group["delta"] * (vec * step).sum() + 1)

                # Make update
                p.data = p.data + group["lr"] * step

        if self.iterates is not None:
            self.iterates.append(copy.deepcopy(self.state))

        return loss


class MirrorDescent(optim.Optimizer):
    """Mirror descent with a radially symmetric mirror map.

    Each step maps a parameter into the dual space with :meth:`grad_h`, takes
    a gradient step there and maps back with :meth:`grad_h_conjugate`::

        p <- grad_h_conjugate(grad_h(p) - lr * (grad + weight_decay * p))

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: step size applied in the dual space.
        preconditioner: name of the mirror map, ``"e_to_x_squared"`` or ``"pk"``.
        little_a: exponent ``a`` of the ``"pk"`` map (unused by ``"e_to_x_squared"``).
        big_a: exponent ``A`` of the ``"pk"`` map (unused by ``"e_to_x_squared"``).
        weight_decay: L2 coefficient added to the gradient (0 disables it).

    Raises:
        ValueError: if ``preconditioner`` is not one of the supported names.
    """

    def __init__(
        self,
        params,
        lr,
        preconditioner,
        little_a,
        big_a,
        weight_decay=0,
    ):
        valid_preconditioners = ["e_to_x_squared", "pk"]
        if preconditioner not in valid_preconditioners:
            raise ValueError(f"'preconditioner' must be one of {valid_preconditioners}")

        defaults = dict(
            lr=lr,
            little_a=little_a,
            big_a=big_a,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)
        self.preconditioner = preconditioner

    def grad_h(self, x, group):
        """Map a primal point ``x`` to the dual space.

        * ``"e_to_x_squared"``: ``x * exp(||x||^2 / 2)``
        * ``"pk"``: ``x * (||x||^a + 1)^(A/a - 1) * ||x||^(a - 2)`` with
          ``a = group["little_a"]`` and ``A = group["big_a"]``.

        At ``x == 0`` zeros are returned, the continuous limit of both maps
        (for ``"pk"`` assuming ``little_a > 1``). Evaluating the formula there
        would give ``0 * inf = nan`` when ``little_a < 2``.
        """
        norm = x.norm()
        if norm == 0:
            return torch.zeros_like(x)

        if self.preconditioner == "e_to_x_squared":
            return x * (norm ** 2 / 2).exp()

        if self.preconditioner == "pk":
            little_a, big_a = group["little_a"], group["big_a"]
            return (
                x
                * (norm ** little_a + 1) ** (big_a / little_a - 1)
                * norm ** (little_a - 2)
            )

        raise ValueError(f"Unknown preconditioner {self.preconditioner!r}")

    def grad_h_conjugate(self, p, group):
        """Map a dual point ``p`` back to the primal space.

        * ``"e_to_x_squared"``: solves ``r * exp(r^2 / 2) = ||p||`` for the
          radius ``r``. Squaring gives ``r^2 * exp(r^2) = ||p||^2``, so
          ``r = sqrt(W(||p||^2))`` with ``W`` the principal Lambert W function.
          The direction ``p / ||p||`` is kept.
        * ``"pk"``: applies the same functional form as :meth:`grad_h` with
          the conjugate exponents ``b = a / (a - 1)`` and ``B = A / (A - 1)``.
          NOTE(review): this is the exact inverse of ``grad_h`` when ``a == A``.
          Exactness for ``a != A`` is not checked here.

        At ``p == 0`` zeros are returned, the continuous limit. The formulas
        would otherwise divide by ``||p|| = 0`` and produce NaN, e.g. for a
        zero-initialised parameter with a zero gradient.
        """
        norm = p.norm()
        if norm == 0:
            return torch.zeros_like(p)

        if self.preconditioner == "e_to_x_squared":
            # Pass a Python float to scipy so that non-CPU tensors also work.
            radius = math.sqrt(scipy.special.lambertw(norm.item() ** 2).real)
            return p / norm * radius

        if self.preconditioner == "pk":
            little_a, big_a = group["little_a"], group["big_a"]
            little_b = little_a / (little_a - 1)
            big_b = big_a / (big_a - 1)
            return (
                p
                * (norm ** little_b + 1) ** (big_b / little_b - 1)
                * norm ** (little_b - 2)
            )

        raise ValueError(f"Unknown preconditioner {self.preconditioner!r}")

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one mirror-descent step on every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                g_k = p.grad

                if weight_decay != 0:
                    g_k = g_k.add(p, alpha=weight_decay)

                # Gradient step in the dual space, then map back to the primal space.
                dual_coord = self.grad_h(p, group) - lr * g_k
                p.copy_(self.grad_h_conjugate(dual_coord, group))

        return loss


class ModelConjugateDescent(optim.Optimizer):
    """Momentum method whose step length comes from a model's convex conjugate.

    For every parameter ``p`` with gradient ``g`` the update is::

        v <- momentum * v + g
        p <- p - lr * (g*)'(||v||) * v / ||v||

    Here ``(g*)'`` is the derivative of the convex conjugate of a scalar model
    ``g``. It is evaluated as the root ``x`` of ``model_der(x) = ||v||``, found
    numerically with :func:`scipy.optimize.newton`.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: step size.
        momentum: decay factor of the velocity buffer ``v``.
        model_der: callable ``float -> float``, the derivative of the model.
            NOTE(review): it is assumed invertible on the relevant range.
            This is not checked.
    """

    def __init__(self, params, lr, momentum, model_der):
        defaults = dict(
            lr=lr,
            momentum=momentum,
        )
        super().__init__(params, defaults)
        self.model_der = model_der

    def _convex_conjugate_grad(self, p):
        """Return ``x`` with ``model_der(x) == p``, solved from ``x0 = 0``.

        No derivative is supplied, so scipy's Newton uses the secant method.
        """
        # `import scipy.special` does not guarantee scipy.optimize is loaded.
        import scipy.optimize

        target = float(p)
        return float(
            scipy.optimize.newton(lambda x: self.model_der(x) - target, x0=0, fprime=None)
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one update on every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                # Velocity buffer; zeros on the first step for this parameter.
                # Updated in place below, so the stored buffer stays current.
                param_state = self.state[p]
                v_k = param_state.get("v_k")
                if v_k is None:
                    v_k = torch.zeros_like(p.grad)
                    param_state["v_k"] = v_k

                # v_{k+1} = momentum * v_k + g_k
                v_k.mul_(momentum).add_(p.grad)

                norm_v_k = v_k.norm().item()
                if norm_v_k == 0:
                    # The direction v / ||v|| is undefined. Skip the step
                    # rather than dividing by zero and writing NaN into p.
                    continue

                # x_{k+1} = x_k - lr * (g*)'(||v_{k+1}||) * v_{k+1} / ||v_{k+1}||
                grad_g_star = self._convex_conjugate_grad(norm_v_k)
                p.add_(v_k, alpha=-lr * grad_g_star / norm_v_k)

        return loss


class BacktrackingGD(optim.Optimizer):
    """Gradient descent with an Armijo backtracking line search spread over steps.

    The line search is stateful across calls. Each call to :meth:`step`
    evaluates ``closure`` at the current trial point
    ``x_base - t * g_base``.

    * If this is the first call, or the Armijo sufficient-decrease condition
      ``f(trial) - f(x_base) <= -armijo_c * t * ||g_base||^2`` holds, the
      trial point becomes the new base point and ``t`` is reset to ``lr0``.
    * Otherwise ``t`` is multiplied by ``tau``.

    Finally the parameters are moved to the next trial point
    ``x_base - t * g_base``.

    All parameters of all groups are treated as one concatenated vector.
    ``armijo_c`` and ``tau`` are read from the first parameter group.

    Args:
        params: iterable of parameters or parameter-group dicts.
        armijo_c: sufficient-decrease constant of the Armijo condition.
        tau: factor by which the step size shrinks after a rejected trial point.
        lr0: initial step size, restored after every accepted trial point.
    """

    def __init__(self, params, armijo_c=0.1, tau=0.5, lr0=1):
        defaults = dict(armijo_c=armijo_c, tau=tau)
        super().__init__(params, defaults)
        self.lr0 = lr0
        self.current_stepsize = lr0
        # State of the last accepted (base) point of the line search.
        self.last_loss = None
        self.last_params = None
        self.last_grad = None
        self.last_grad_sq_norm = None

    @torch.no_grad()
    def step(self, closure=None):
        """Evaluate the current trial point and move to the next one.

        Args:
            closure: callable that re-evaluates the model, calls ``backward``
                and returns the loss. It is required.

        Returns:
            The loss at the trial point evaluated in this call.

        Raises:
            ValueError: if ``closure`` is ``None``.
        """
        if closure is None:
            raise ValueError("BacktrackingGD requires a closure that re-evaluates the loss")
        with torch.enable_grad():
            loss = closure()
        # Compare as a float so the stored base loss does not keep the autograd graph alive.
        loss_value = float(loss)

        group = self.param_groups[0]
        armijo_c = group["armijo_c"]
        tau = group["tau"]
        params = [p for grp in self.param_groups for p in grp["params"]]

        if self.last_loss is None or (
            loss_value - self.last_loss
            <= -armijo_c * self.current_stepsize * self.last_grad_sq_norm
        ):  # satisfied Armijo condition: accept trial point as the new base point
            self.last_loss = loss_value
            self.last_params = [p.detach().clone() for p in params]
            self.last_grad = [
                torch.zeros_like(p) if p.grad is None else p.grad.detach().clone()
                for p in params
            ]
            self.last_grad_sq_norm = sum(
                float(g.square().sum()) for g in self.last_grad
            )
            self.current_stepsize = self.lr0
        else:  # decrease stepsize
            self.current_stepsize *= tau

        # Move to the next trial point along the base gradient.
        for p, x_base, g_base in zip(params, self.last_params, self.last_grad):
            p.copy_(x_base - self.current_stepsize * g_base)

        return loss
