# -*- coding: utf-8 -*-
import torch
from torch.optim.optimizer import Optimizer

# Public API of this module: the names exported by a star-import.
__all__ = ["DSGDOptimizer"]


class DSGDOptimizer(Optimizer):
    """SGD-style optimizer with weight decay, momentum and a Nesterov option.

    For every parameter ``p`` with gradient ``g`` one step performs::

        g   = g + weight_decay * p                  # if weight_decay != 0
        buf = beta * buf + (1 - dampening) * g      # if beta != 0
        g   = g + beta * buf  (nesterov)  or  buf   # if beta != 0
        p   = p - lr * g

    On the first step of a parameter the momentum buffer is initialised to a
    copy of ``g``; ``dampening`` is not applied on that step.

    Args:
        params: Iterable of parameters, or of dicts defining parameter groups.
        lr: Step size applied to the final update direction.
        beta: Momentum coefficient. ``0.0`` disables the momentum buffer.
        weight_decay: Coefficient of the parameter term added to the gradient.
        dampening: Dampening applied to the gradient when it is accumulated
            into an existing momentum buffer.
        nesterov: If true, use ``g + beta * buf`` as the update direction
            instead of ``buf``.
    """

    def __init__(self, params, lr: float = 1e-5, beta: float = 0.9, weight_decay: float = 0.0,
                 dampening: float = 0.0, nesterov: bool = False):
        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay,
                        dampening=dampening, nesterov=nesterov)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradient tracking enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None

        if closure is not None:
            # step() itself runs under no_grad; the closure usually does a
            # forward/backward pass, so autograd has to be re-enabled for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta = group["beta"]
            weight_decay = group["weight_decay"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group['params']:
                if p.grad is None:
                    continue

                # apply weight decay
                grad = p.grad
                if weight_decay != 0.0:
                    if isinstance(weight_decay, torch.Tensor) and weight_decay.requires_grad:
                        # NOTE: addcmul_ is in-place, so p.grad itself is
                        # modified on this path (the other path allocates a
                        # new tensor and leaves p.grad untouched).
                        grad = grad.addcmul_(p.clone(), weight_decay)
                    else:
                        grad = grad.add(p, alpha=weight_decay)

                if beta != 0.0:
                    state = self.state[p]
                    if "momentum_buffer" in state:
                        buf = state["momentum_buffer"]
                        buf.mul_(beta).add_(grad, alpha=1.0 - dampening)
                    else:
                        # First step for this parameter: seed the buffer with
                        # the (decayed) gradient, without dampening.
                        buf = torch.clone(grad).detach()
                        state["momentum_buffer"] = buf

                    if nesterov:
                        grad = grad.add(buf, alpha=beta)
                    else:
                        grad = buf

                # A tensor lr that requires grad cannot be passed as `alpha`,
                # so that case goes through addcmul_ instead.
                if isinstance(lr, torch.Tensor) and lr.requires_grad:
                    p.addcmul_(grad, lr, value=-1)
                else:
                    p.add_(grad, alpha=-lr)

        return loss
