"""Implements the NGN-MDv2 optimizer."""

from typing import Callable, Iterable, Optional, Tuple, Union, Any, Dict
import torch


class NGN_MDv2(torch.optim.Optimizer):
    """NGN-MDv2 optimizer.

    For every coordinate of every parameter the update is

        x^{t+1} = x^t - gamma_t * g_t + beta1 * (x^t - x^{t-1})

    with the diagonal stepsize

        gamma_t = (1 - beta1) * lr * D_t / (1 + lr * D_t * g_t**2 / (2 * loss))

    where ``D_t = 1 / (eps + sqrt(v_t / (1 - beta2**k)))``, ``v_t`` is an
    exponential moving average (coefficient ``beta2``) of ``g_t**2`` and ``k``
    is the number of times that parameter has received a gradient. The
    momentum term is omitted on a parameter's first update.
    """

    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
        lr: Union[float, torch.Tensor] = 1.0,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
    ):
        """
        Implements NGN-MDv2.

        Parameters
        ----------
        params : iterable
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr : float or torch.Tensor, optional
            The stepsize parameter of the NGN stepsize. By default 1.0.
        betas : tuple of float, optional
            Momentum coefficient of the direction (``beta1``) and moving-average
            coefficient of the preconditioner (``beta2``). Both must lie in
            [0, 1). By default (0.9, 0.999).
        eps : float, optional
            Non-negative constant added to the preconditioner denominator for
            stability (avoids division by zero). By default 1e-8.

        Raises
        ------
        ValueError
            If a beta lies outside [0, 1) or ``eps`` is negative.
        """
        beta1, beta2 = betas
        # beta2 == 1 would make the bias correction 1 - beta2**k == 0 (0 / 0).
        if not 0.0 <= beta1 < 1.0:
            raise ValueError(f"Invalid beta1 value: {beta1}")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError(f"Invalid beta2 value: {beta2}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid eps value: {eps}")

        defaults = {"lr": lr, "betas": betas, "eps": eps}
        super().__init__(params, defaults)
        # Global step counter, kept under a string key of ``self.state`` for
        # backward compatibility with existing callers and saved states.
        self.state["num_steps"] = 0

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable[[], float]] = None,
        loss: Optional[Union[torch.Tensor, float]] = None,
    ) -> Optional[Union[torch.Tensor, float]]:
        """
        Performs a single optimization step.

        Notation: ``new_p = x^{t+1}``, ``p = x^t``, ``old_p = x^{t-1}``.
        Exactly one of ``closure`` and ``loss`` must be given.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss; it is called with
            gradients enabled. By default None.
        loss : torch.Tensor or float, optional
            The loss value. Use this when the backward step has already been
            performed. By default None.

        Returns
        -------
        (Stochastic) Loss function value.

        Notes
        -----
        On coordinates whose gradient is exactly zero the NGN correction term is
        forced to 0, so a zero loss does not turn them into NaN through 0 / 0.
        """
        assert (closure is not None) or (loss is not None), "Either loss tensor or closure must be passed."
        assert (closure is None) or (loss is None), "Pass either the loss tensor or the closure, not both."
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.state["num_steps"] += 1
        self.compute_stepsizes_magnitudes()

        for group in self.param_groups:
            lr = group["lr"]
            beta1, _ = group["betas"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                D = state["D"]

                grad_sq = grad.pow(2)
                # NGN correction lr * D * g^2 / (2 * loss). With loss == 0 it is
                # 0 / 0 = NaN wherever g == 0; the gradient term vanishes there
                # anyway, so pin it to 0 to keep NaN out of p.
                correction = lr * D * grad_sq / (2 * loss)
                correction.masked_fill_(grad_sq == 0, 0.0)
                momngn = (1 - beta1) * lr * D / (1 + correction)

                prev_p = state.get("p")
                if prev_p is None:
                    # First update of this parameter: no momentum term yet.
                    state["p"] = p.detach().clone()
                    new_p = p - momngn * grad
                else:
                    new_p = p - momngn * grad + beta1 * (p - prev_p)
                    # Store x^t for the next step, reusing the existing buffer
                    # instead of allocating a fresh clone every step.
                    prev_p.copy_(p)

                p.copy_(new_p)

        return loss

    @torch.no_grad()
    def compute_stepsizes_magnitudes(self):
        """Updates the preconditioner ``D`` of every parameter that has a gradient.

        ``v`` is a moving average (coefficient ``beta2``) of the squared
        gradient and ``D = 1 / (eps + sqrt(v / (1 - beta2**k)))``, where ``k``
        counts how many times this parameter's ``v`` has been updated. Counting
        per parameter keeps the bias correction right for parameters that do
        not receive a gradient at every step.
        """
        ns = self.state["num_steps"]

        for group in self.param_groups:
            _, beta2 = group["betas"]
            eps = group["eps"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                grad_sq = p.grad.pow(2)

                if "v" not in state:
                    state["v"] = (1 - beta2) * grad_sq
                    state["step"] = 1
                else:
                    state["v"].mul_(beta2).add_(grad_sq, alpha=(1 - beta2))
                    # States saved without a per-parameter counter fall back to
                    # the global counter, reproducing the previous behaviour.
                    state["step"] = state.get("step", ns - 1) + 1

                bias_correction = 1 - beta2 ** state["step"]
                state["D"] = 1 / (eps + torch.sqrt(state["v"] / bias_correction))
