import torch


class Adam:
    """
    Adam optimizer state for a single tensor (e.g. an image optimized directly).

    Unlike ``torch.optim.Adam`` this tracks exactly one tensor, with no parameter
    groups (there are no separate weight/bias parameters). It does not modify the
    parameter itself: ``update_grad`` returns the step and the caller applies it,
    e.g. ``param -= opt.update_grad(param.grad)``.

    Attributes:
        t: number of updates performed so far (used for bias correction).
        v: exponential moving average of the gradient (first moment).
        s: exponential moving average of the squared gradient (second moment).
    """

    def __init__(self, parameter, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        """
        Args:
            parameter: tensor whose shape/dtype/device the moment buffers copy.
            lr: step size.
            beta1: decay rate of the first-moment average; must be in [0, 1).
            beta2: decay rate of the second-moment average; must be in [0, 1).
            eps: term added to the denominator for numerical stability.

        Raises:
            ValueError: if beta1 or beta2 is outside [0, 1). A beta of 1 makes the
                bias-correction denominator ``1 - beta**t`` zero, so every step
                would be NaN.
        """
        if not 0.0 <= beta1 < 1.0:
            raise ValueError(f"beta1 must be in [0, 1), got {beta1}")
        if not 0.0 <= beta2 < 1.0:
            raise ValueError(f"beta2 must be in [0, 1), got {beta2}")
        self.t = 0
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.v = torch.zeros_like(parameter)
        self.s = torch.zeros_like(parameter)

    def update_grad(self, p_grad):
        """
        Consume one gradient and return the Adam step for it.

        Args:
            p_grad: gradient tensor for the parameter (must broadcast to the
                parameter's shape). It is detached and zeroed IN PLACE, so the
                caller's gradient buffer is cleared for the next backward pass.

        Returns:
            A new tensor ``lr * v_hat / (sqrt(s_hat) + eps)``, where ``v_hat`` and
            ``s_hat`` are the bias-corrected moments. Subtract it from the
            parameter to descend.
        """
        self.t += 1
        p_grad.detach_()
        with torch.no_grad():
            # In-place moment updates: unlike `self.v[:] = ...` these also work
            # for 0-dim (scalar) tensors and avoid allocating temporaries.
            self.v.mul_(self.beta1).add_(p_grad, alpha=1 - self.beta1)
            self.s.mul_(self.beta2).addcmul_(p_grad, p_grad, value=1 - self.beta2)
            # Bias correction compensates for the zero-initialised moments.
            v_bias_corr = self.v / (1 - self.beta1 ** self.t)
            s_bias_corr = self.s / (1 - self.beta2 ** self.t)
            step = self.lr * v_bias_corr / (torch.sqrt(s_bias_corr) + self.eps)
        # Clear the caller's gradient so it does not accumulate across steps.
        p_grad.zero_()
        # Created under no_grad, so the result is already detached from autograd.
        return step
