"""Implements NGN-M opitmizer."""

from typing import Any, Callable, Dict, Iterable, Optional, Union

import torch


class NGNM(torch.optim.Optimizer):
    """NGN-M: the NGN adaptive step size combined with heavy-ball momentum.

    Each call to :meth:`step` applies

        x^{t+1} = x^t - (1 - beta) * gamma_t * g_t + beta * (x^t - x^{t-1}),
        gamma_t = c / (1 + c * ||g_t||^2 / (2 * f_t)),

    where ``f_t`` is the loss handed to :meth:`step`, ``g_t`` is ``p.grad`` and
    ``||g_t||^2`` is summed over *all* parameter groups. On the first step
    there is no ``x^{t-1}``, so the momentum term is dropped. The previous
    iterate is kept in ``self.state[p]["p"]``.

    The step-size formula is only meaningful for a non-negative loss; a
    negative loss can make its denominator vanish or change sign.
    """

    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
        lr: Union[float, torch.Tensor] = 1.0,
        beta: float = 0.9,
    ):
        """
        Create an NGN-M optimizer.

        Parameters
        ----------
        params : iterable
            Iterable of parameters or parameter groups to optimize.
        lr : float or torch.Tensor, optional
            Normalizing constant 'c' for the step size. Must be >= 0.
            Default is 1.0.
        beta : float, optional
            Momentum coefficient in [0, 1). Default is 0.9.

        Raises
        ------
        ValueError
            If ``lr`` or ``beta`` is out of range.
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # beta == 1 would zero the (1 - beta) factor and ignore the gradient.
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta value: {beta}")
        defaults = {"lr": lr, "beta": beta}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        loss: Optional[Union[float, torch.Tensor]] = None,
        closure: Optional[Callable[[], torch.Tensor]] = None,
    ) -> Union[float, torch.Tensor]:
        """
        Perform a single optimization step.

        Parameters
        ----------
        loss : float or torch.Tensor, optional
            Current (stochastic) loss value ``f_t``. ``backward()`` must
            already have been called so that ``p.grad`` holds its gradient.
        closure : callable, optional
            Used only when ``loss`` is None: a function that re-evaluates the
            model, calls ``backward()`` and returns the loss (the standard
            ``torch.optim`` closure convention).

        Returns
        -------
        float or torch.Tensor
            The loss value used for this step.

        Raises
        ------
        ValueError
            If neither ``loss`` nor ``closure`` is given.
        """
        if loss is None:
            if closure is None:
                raise ValueError(
                    "NGNM.step() needs the current loss: pass `loss` or a "
                    "`closure` that re-evaluates the model and calls backward()."
                )
            # The closure must build a graph and backprop, which no_grad forbids.
            with torch.enable_grad():
                loss = closure()

        grad_norm_sq = self.compute_sq_grad_terms()
        if not torch.is_tensor(grad_norm_sq):
            # No parameter has a gradient: there is nothing to update.
            return loss

        # ||g||^2 / (2 f). Where ||g||^2 == 0 the gradient term of the update
        # vanishes anyway, so force the ratio to 0; otherwise a zero loss at a
        # stationary point gives 0/0 = NaN, which would be written into every
        # parameter. torch.where keeps this on-device (no host sync).
        ratio = grad_norm_sq / (2 * loss)
        ratio = torch.where(grad_norm_sq > 0, ratio, torch.zeros_like(ratio))

        for group in self.param_groups:
            c = group['lr']
            beta = group['beta']
            # Momentum-scaled NGN step size: (1 - beta) * gamma_t.
            momngn = (1 - beta) * c / (1 + c * ratio)

            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]

                if 'p' not in state:
                    # First step: plain NGN step; remember x^t for next time.
                    state['p'] = p.detach().clone()
                    p.sub_(momngn * p.grad)
                else:
                    prev = state['p']        # x^{t-1}
                    momentum = p - prev      # x^t - x^{t-1}
                    prev.copy_(p)            # reuse the buffer to store x^t
                    p.sub_(momngn * p.grad).add_(momentum.mul_(beta))

        return loss

    @torch.no_grad()
    def compute_sq_grad_terms(self) -> Union[float, torch.Tensor]:
        """
        Return the *squared* L2 norm of all gradients across every group.

        Returns the Python float ``0.`` when no parameter has a gradient,
        otherwise a 0-dim tensor. All gradients are assumed to live on the
        same device, because the per-parameter sums are added directly.
        """
        grad_norm_sq = 0.
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad_norm_sq += torch.sum(p.grad.pow(2))
        return grad_norm_sq
