% \scriptsize
% Algorithm float (algorithm2e syntax): pseudocode for AdamW-PMA.
% Each iteration t computes a gradient and sets tau = t mod K.
%   * tau = 0  : "large" step, EMA update of m, v with the gradient scaled by 1/K,
%                full learning rate gamma.
%   * tau != 0 : "small" step, running-average update of m, v with weights
%                tau/(tau+1) and 1/(tau+1), learning rate gamma/sqrt(K).
% Both branches are followed by debiasing, decoupled weight decay and the
% parameter update.
\begin{algorithm}[t]
% \caption{Accumulate Gradient by Momentum Averaging (AGMA)}\label{alg:agma}
\caption{AdamW-PMA}\label{alg:agma}
\SetKwInOut{Input}{input}\SetKwInOut{Output}{output}
\scriptsize
\Input{$\gamma$ (lr), $\beta_1, \beta_2$ (betas), $\theta_0$ (params), $f(\theta)$ (objective), $\epsilon$ (epsilon), $\lambda$ (weight decay), $K$ (accumulate iterations)}
\KwData{$m_0\gets 0$, $v_0\gets 0$}
% \Output{$\theta_t$}
\BlankLine
% $\mathcal{S}=\varnothing$\;
\For{$t=1\to \ldots$}{
    $g_t\gets \nabla_{\theta}f_t(\theta_{t-1})$\;
    % Position of step t inside the current window of K steps.
    $\tau\gets t \bmod K$\;
    \If{$\tau = 0$ and $t>0$}{
        \tcp{For every $K$ steps, there is a large update step.}
        $\gamma_t\gets\gamma$\;
        $m_t\gets \beta_1 m_{t-1} + (1-\beta_1)g_t/K$\tcp*[c]{Divide gradient by $K$ for stability.}
        $v_t\gets \beta_2 v_{t-1} + (1-\beta_2)g_t^2/K$\;
    }
    \Else{
        \tcp{For every $K$ steps, there are $K-1$ small update steps.}
        $\gamma_t\gets\gamma/\sqrt{K}$\tcp*[c]{Shrink the learning rate by $1/\sqrt{K}$.}
        % The right-hand side uses the previous state (index t-1), matching the
        % large-step branch above.
        $m_t\gets \frac{\tau}{\tau+1} m_{t-1} + \frac{1-\beta_1}{\tau+1}g_t$\tcp*[c]{Moving average instead of EMA.}
        $v_t\gets \frac{\tau}{\tau+1} v_{t-1} + \frac{1-\beta_2}{\tau+1}g_t^2$\;
    }
    % NOTE(review): for t < K the exponent t//K is 0, so 1-\beta^{0} = 0 and the
    % debias below divides by zero; confirm the intended exponent.
    $\hat{m}_t\gets m_t/(1-\beta_1^{t//K})$\tcp*[c]{Debias. ``//'' denotes integer (floor) division.}
    $\sqrt{\hat{v}_t}\gets \sqrt{v_t/(1-\beta_2^{t//K})}+\epsilon$\;
    $\hat{\theta}_t\gets (1-\gamma_t\lambda)\theta_{t-1}$\tcp*[c]{Weight decay.}
    $\theta_t\gets \hat{\theta}_t - \gamma_t \hat{m}_t/\sqrt{\hat{v}_t}$\tcp*[c]{Parameter update.}
    \If{$\tau = 0$ and $t>0$}{
        % $m_t\gets \beta_1Km_t$\;
        % $v_t\gets \beta_2Kv_t$\;
        % NOTE(review): the hatted quantities are recomputed from m, v at the
        % next iteration, so this rescale does not appear to carry over; confirm
        % whether m_t, v_t were intended (cf. the commented-out lines above).
        $\hat{m}_t\gets K\hat{m}_t$\tcp*[c]{Rescale the momentum after large update step.}
        $\hat{v}_t\gets K\hat{v}_t$\;
    }
}
\Return $\theta_t$\;
\normalsize
\end{algorithm}
