import torch
from torch.optim.optimizer import Optimizer

class SGD(Optimizer):
    """Stochastic gradient descent with decoupled weight decay and momentum.

    Per parameter ``p`` with gradient ``g`` the update is::

        buf = momentum * buf + (1 - momentum) * g      # only if momentum > 0
        d   = g + momentum * buf   if nesterov else buf
        p   = p * (1 - lr * wd)                        # only if wd > 0
        p   = p - lr * d

    The momentum buffer is therefore an exponential moving average of the
    gradients, and weight decay is applied directly to the parameter rather
    than being added to the gradient.

    Args:
        params: Iterable of parameters, or of dicts defining parameter groups.
        lr: Learning rate (must be non-negative).
        wd: Weight-decay coefficient (must be non-negative; 0 disables it).
        momentum: Momentum coefficient (must be non-negative; 0 disables it).
        nesterov: Use the Nesterov-style look-ahead direction. Only has an
            effect when ``momentum > 0``.

    Raises:
        ValueError: If ``lr``, ``wd`` or ``momentum`` is negative.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        wd=0.01,
        momentum=0.0,
        nesterov=True,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if wd < 0.0:
            raise ValueError(f"Invalid weight decay: {wd}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")

        # Only hyperparameters belong in `defaults`. `params` is handed to the
        # base class separately; storing it here would keep a stray reference
        # to the caller's iterable (possibly an exhausted generator).
        defaults = dict(
            lr=lr,
            wd=wd,
            momentum=momentum,
            nesterov=nesterov,
        )
        if momentum > 0.0:
            print("Enable momentum")

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["wd"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            use_momentum = momentum > 0.0

            for p in group["params"]:
                g = p.grad
                if g is None:
                    continue

                if use_momentum:
                    # Touch per-parameter state only when it is needed so that
                    # momentum-free runs do not accumulate empty state entries.
                    state = self.state[p]
                    buf = state.get("momentum_buffer")
                    if buf is None:
                        buf = state["momentum_buffer"] = torch.zeros_like(g)
                    buf.mul_(momentum).add_(g, alpha=1.0 - momentum)

                    # The Nesterov branch allocates a new tensor, so neither
                    # the gradient nor the buffer is modified by it.
                    g_update = g.add(buf, alpha=momentum) if nesterov else buf
                else:
                    g_update = g

                # We are already under no_grad, so in-place ops on `p` itself
                # are safe; avoid `.data`, which bypasses version tracking.
                if wd > 0.0:
                    p.mul_(1 - lr * wd)
                p.add_(g_update, alpha=-lr)
        return loss

