"""Implements NGN-MDv1 opitmizer."""

from typing import Callable, Iterable, Optional, Tuple, Union, Any, Dict
import torch


class NGN_MDv1(torch.optim.Optimizer):
    """NGN stepsize with a diagonal preconditioner and heavy-ball momentum.

    Each call to :meth:`step` performs, for every parameter ``x`` that has a
    gradient ``g``::

        v_t     = beta2 * v_{t-1} + (1 - beta2) * g**2
        D_t     = 1 / (eps + sqrt(v_t / (1 - beta2**t)))
        gamma_t = lr / (1 + lr * S_t / (2 * loss)),   S_t = sum(D_t * g**2)
        x_{t+1} = x_t - (1 - beta1) * gamma_t * D_t * g + beta1 * (x_t - x_{t-1})

    ``S_t`` is accumulated over all parameters of all groups, and ``t`` is a
    single global step counter stored in ``self.state["num_steps"]``. The
    momentum term is omitted on a parameter's first update, because no
    ``x_{t-1}`` is stored yet. The stepsize divides by ``loss``, so the loss is
    expected to be positive; a zero loss makes ``gamma_t`` ill-defined.
    """

    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
        lr: Union[float, torch.Tensor] = 1.0,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
    ):
        """
        Create the optimizer.

        Parameters
        ----------
        params : iterable
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr : float or torch.Tensor, optional
            The ``lr`` constant of the NGN stepsize. By default 1.0.
        betas : tuple of (float, float), optional
            ``beta1`` weights the momentum term and scales the gradient step by
            ``1 - beta1``. ``beta2`` is the EMA coefficient of the squared
            gradients used by the preconditioner. By default (0.9, 0.999).
        eps : float, optional
            Added to the preconditioner's denominator to avoid division by zero.
            By default 1e-8.

        Raises
        ------
        ValueError
            If ``lr`` or ``eps`` is negative, or either beta lies outside ``[0, 1)``.
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        beta1, beta2 = betas
        if not 0.0 <= beta1 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {beta1}")
        # beta2 == 1 would make the bias correction 1 - beta2**t zero (NaN/inf updates).
        if not 0.0 <= beta2 < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {beta2}")

        defaults = {"lr": lr, "betas": betas, "eps": eps}
        super().__init__(params, defaults)
        # Global step counter; kept in self.state so it travels with state_dict().
        self.state["num_steps"] = 0

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable[[], float]] = None,
        loss: Optional[torch.Tensor] = None,
    ) -> Optional[float]:
        """
        Performs a single optimization step.

        Notation: ``x^{t+1}`` is the new iterate, ``x^t`` the current parameter
        value, and ``x^{t-1}`` the value stored in ``state["p"]``.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model, calls ``backward()``, and returns the loss.
            Evaluated with gradients enabled.
        loss : torch.Tensor, optional
            The loss tensor. Use this when the backward step has already been
            performed. By default None.

        Exactly one of ``closure`` and ``loss`` must be given.

        Returns
        -------
        (Stochastic) Loss function value.
        """
        assert (closure is not None) or (loss is not None), "Either loss tensor or closure must be passed."
        assert (closure is None) or (loss is None), "Pass either the loss tensor or the closure, not both."
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        self.state["num_steps"] += 1
        grad_norm_sq = self.compute_sq_grad_terms()

        for group in self.param_groups:
            lr = group["lr"]
            beta1, _ = group["betas"]
            # NGN stepsize, scaled by (1 - beta1) to balance the momentum term.
            momngn = (1 - beta1) * lr / (1 + lr * grad_norm_sq / (2 * loss))

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                # Preconditioned gradient step; formed before p is modified in place.
                update = momngn * state["D"] * p.grad
                prev = state.get("p")

                if prev is None:
                    # First update of this parameter: no x^{t-1}, hence no momentum.
                    state["p"] = p.detach().clone()
                    p.sub_(update)
                else:
                    # Heavy-ball term beta1 * (x^t - x^{t-1}). The stored buffer is
                    # overwritten with x^t instead of re-allocating it every step.
                    velocity = p - prev
                    prev.copy_(p)
                    p.sub_(update).add_(velocity, alpha=beta1)

        return loss

    @torch.no_grad()
    def compute_sq_grad_terms(self):
        """Update the preconditioner ``D`` and return the squared gradient norm.

        The norm is ``sum(D * grad**2)``, preconditioned by ``D``. For every
        parameter with a gradient, this method:

        * updates the EMA ``state["v"]`` of the squared gradient with ``beta2``;
        * stores ``state["D"] = 1 / (eps + sqrt(v / (1 - beta2**t)))``, where
          ``t = self.state["num_steps"]``;
        * accumulates ``sum(D * grad**2)``.

        :meth:`step` calls this after incrementing the step counter. With
        ``t == 0`` the bias correction is zero.

        Returns
        -------
        torch.Tensor or float
            The accumulated ``sum(D * grad**2)``. This is a 0-dim tensor, or the
            Python float ``0.0`` if no parameter has a gradient.
        """
        grad_norm_sq = 0.0  # summing tensors assumes all parameters share a device
        num_steps = self.state["num_steps"]

        for group in self.param_groups:
            _, beta2 = group["betas"]
            eps = group["eps"]
            bias_correction = 1 - beta2 ** num_steps

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                grad_sq = p.grad.pow(2)

                if "v" not in state:
                    state["v"] = (1 - beta2) * grad_sq
                else:
                    state["v"].mul_(beta2).add_(grad_sq, alpha=(1 - beta2))

                D = 1 / (eps + torch.sqrt(state["v"] / bias_correction))
                grad_norm_sq += torch.sum(D * grad_sq)
                state["D"] = D

        return grad_norm_sq
