import torch.optim.optimizer
from ..tensor import ManifoldParameter, ManifoldTensor
from .mixin import OptimMixin
from ..utils import copy_or_set_

__all__ = ["RiemannianSGD"]


class RiemannianSGD(OptimMixin, torch.optim.Optimizer):
    r"""
    Riemannian Stochastic Gradient Descent with the same API as :class:`torch.optim.SGD`.

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining
        parameter groups
    lr : float
        learning rate
    momentum : float (optional)
        momentum factor (default: 0)
    dampening : float (optional)
        dampening for momentum (default: 0)
    weight_decay : float (optional)
        weight decay (L2 penalty) (default: 0)
    nesterov : bool (optional)
        enables Nesterov momentum (default: False)

    Other Parameters
    ----------------
    stabilize : int
        Stabilize parameters if they are off-manifold due to numerical
        reasons every ``stabilize`` steps (default: ``None`` -- no stabilize)
    """

    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        stabilize=None,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        super().__init__(params, defaults, stabilize=stabilize)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Parameters
        ----------
        closure : callable (optional)
            A closure that reevaluates the model and returns the loss.
            It is executed with gradient tracking enabled.

        Returns
        -------
        The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # The method itself runs under ``no_grad``; the closure usually
            # needs autograd to recompute the loss and call ``backward()``.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if "step" not in group:
                group["step"] = 0
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]
            learning_rate = group["lr"]
            # Count optimizer steps, not processed parameters, so that the
            # ``stabilize`` period does not depend on the group size.
            group["step"] += 1

            for point in group["params"]:
                grad = point.grad
                if grad is None:
                    continue
                state = self.state[point]

                # State initialization: the buffer starts as a copy of the
                # raw gradient (before weight decay / egrad2rgrad).
                if momentum > 0 and "momentum_buffer" not in state:
                    state["momentum_buffer"] = grad.clone()

                if isinstance(point, (ManifoldParameter, ManifoldTensor)):
                    manifold = point.manifold
                else:
                    # plain tensors fall back to the manifold provided by the mixin
                    manifold = self._default_manifold

                # L2 penalty is added to ``point.grad`` in place
                grad.add_(point, alpha=weight_decay)
                grad = manifold.egrad2rgrad(point, grad)

                if momentum > 0:
                    momentum_buffer = state["momentum_buffer"]
                    momentum_buffer.mul_(momentum).add_(grad, alpha=1 - dampening)
                    if nesterov:
                        grad = grad.add_(momentum_buffer, alpha=momentum)
                    else:
                        grad = momentum_buffer
                    # we have all the things projected
                    new_point, new_momentum_buffer = manifold.retr_transp(
                        point, -learning_rate * grad, momentum_buffer
                    )
                    momentum_buffer.set_(new_momentum_buffer)
                    # use copy only for user facing point
                    copy_or_set_(point, new_point)
                else:
                    new_point = manifold.retr(point, -learning_rate * grad)
                    copy_or_set_(point, new_point)

            if self._stabilize is not None and group["step"] % self._stabilize == 0:
                self.stabilize_group(group)
        return loss

    @torch.no_grad()
    def stabilize_group(self, group):
        """
        Re-project manifold parameters of ``group`` and their momentum buffers.

        Each :class:`ManifoldParameter` / :class:`ManifoldTensor` in the group
        is replaced by ``manifold.projx(p)``; when momentum is enabled and a
        momentum buffer exists, it is replaced by ``manifold.proju(p, buf)``.
        Other tensors in the group are left untouched.

        Parameters
        ----------
        group : dict
            A parameter group from ``self.param_groups``.
        """
        momentum = group["momentum"]
        for p in group["params"]:
            if not isinstance(p, (ManifoldParameter, ManifoldTensor)):
                continue
            manifold = p.manifold
            copy_or_set_(p, manifold.projx(p))
            if momentum > 0:
                param_state = self.state[p]
                if not param_state:  # due to None grads
                    continue
                if "momentum_buffer" in param_state:
                    buf = param_state["momentum_buffer"]
                    buf.set_(manifold.proju(p, buf))
