\documentclass[accepted]{uai2022} 
\usepackage[american]{babel}
\usepackage{natbib}
    \bibliographystyle{abbrvnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{tikz} % nice language for creating drawings and diagrams

\newcommand{\swap}[3][-]{#3#1#2}


%%%%%%%%%%%%%%%
\usepackage{times}  % DO NOT CHANGE THIS
\usepackage{helvet}  % DO NOT CHANGE THIS
\usepackage{courier}  % DO NOT CHANGE THIS
\usepackage{caption} % DO NOT CHANGE THIS AND DO NOT ADD ANY OPTIONS TO IT
%\DeclareCaptionStyle{ruled}{labelfont=normalfont,labelsep=colon,strut=off} % DO NOT CHANGE THIS
%\frenchspacing  % DO NOT CHANGE THIS
%\setlength{\pdfpagewidth}{8.5in}  % DO NOT CHANGE THIS
%\setlength{\pdfpageheight}{11in}  % DO NOT CHANGE THIS
%
% These are recommended to typeset algorithms but not required. See the subsubsection on algorithms. Remove them if you don't have algorithms in your paper.
\usepackage{algorithm}
\usepackage[utf8]{inputenc} % allow utf-8 input
\usepackage[T1]{fontenc}    % use 8-bit T1 fonts
\usepackage{hyperref}       % hyperlinks
\usepackage{url}            % simple URL typesetting
\usepackage{booktabs}       % professional-quality tables
\usepackage{amsfonts}       % blackboard math symbols
\usepackage{nicefrac}       % compact symbols for 1/2, etc.
\usepackage{microtype}      % microtypography
\usepackage{xcolor}         % colors

\usepackage{multirow}
\usepackage{graphicx}
\usepackage{pifont}
%%%%%%%%%%%%%%%%%%% customized by SL
\usepackage[utf8]{inputenc} % allow utf-8 input
\usepackage[T1]{fontenc}    % use 8-bit T1 fonts
\usepackage{url}            % simple URL typesetting
\usepackage{booktabs}       % professional-quality tables
\usepackage{amsfonts}       % blackboard math symbols
\usepackage{nicefrac}       % compact symbols for 1/2, etc.
\usepackage{microtype}      % microtypography
\usepackage{hyperref}
\usepackage{amsfonts}
\usepackage{amssymb,color}%,amsthm}
\usepackage{bbm}
\usepackage{mathrsfs}
\usepackage{amsmath}
\usepackage{amssymb,amsthm}
\usepackage{xr}
% \addFileDependency{<file>}: writes "(<file>)" to the log/terminal, adds the
% file to the kernel's file list via \@addtofilelist, and prints a notice when
% the file does not exist. Presumably this lets build tools track the file as
% a dependency -- confirm against the build setup.
\makeatletter
\newcommand*{\addFileDependency}[1]{% argument=file name and extension
  \typeout{(#1)}
  \@addtofilelist{#1}
  \IfFileExists{#1}{}{\typeout{No file #1.}}
}
\makeatother
% \myexternaldocument{<base>}: imports the labels of <base> with xr's
% \externaldocument and registers <base>.tex and <base>.aux as dependencies.
\newcommand*{\myexternaldocument}[1]{%
    \externaldocument{#1}%
    \addFileDependency{#1.tex}%
    \addFileDependency{#1.aux}%
}
% NOTE(review): labels such as `eq: prob_DAT' are referenced below but not
% defined in this part of the file; confirm that the document named here is
% the one that defines them (and is not this supplement itself).
\myexternaldocument{zhang_207-supp}
%%%%% NEW MATH DEFINITIONS %%%%%

\usepackage{amsmath,amsfonts,bm}

% Mark sections of captions for referring to divisions of figures
\newcommand{\figleft}{{\em (Left)}}
\newcommand{\figcenter}{{\em (Center)}}
\newcommand{\figright}{{\em (Right)}}
\newcommand{\figtop}{{\em (Top)}}
\newcommand{\figbottom}{{\em (Bottom)}}
\newcommand{\captiona}{{\em (a)}}
\newcommand{\captionb}{{\em (b)}}
\newcommand{\captionc}{{\em (c)}}
\newcommand{\captiond}{{\em (d)}}

% Highlight a newly defined term
\newcommand{\newterm}[1]{{\bf #1}}


% Figure reference, lower-case.
\def\figref#1{figure~\ref{#1}}
% Figure reference, capital. For start of sentence
\def\Figref#1{Figure~\ref{#1}}
\def\twofigref#1#2{figures \ref{#1} and \ref{#2}}
\def\quadfigref#1#2#3#4{figures \ref{#1}, \ref{#2}, \ref{#3} and \ref{#4}}
% Section reference, lower-case.
\def\secref#1{section~\ref{#1}}
% Section reference, capital.
\def\Secref#1{Section~\ref{#1}}
% Reference to two sections.
\def\twosecrefs#1#2{sections \ref{#1} and \ref{#2}}
% Reference to three sections.
\def\secrefs#1#2#3{sections \ref{#1}, \ref{#2} and \ref{#3}}
% Reference to an equation, lower-case.
\def\eqref#1{(\ref{#1})}
% Reference to an equation, upper case
\def\Eqref#1{Equation~\ref{#1}}
% A raw reference to an equation---avoid using if possible
\def\plaineqref#1{\ref{#1}}
% Reference to a chapter, lower-case.
\def\chapref#1{chapter~\ref{#1}}
% Reference to an equation, upper case.
\def\Chapref#1{Chapter~\ref{#1}}
% Reference to a range of chapters, e.g. "chapters 2--4".
% The tie keeps "chapters" and the first number together on one line
% (the original had no space at all between the word and the number).
\def\rangechapref#1#2{chapters~\ref{#1}--\ref{#2}}
% Reference to an algorithm, lower-case.
\def\algref#1{algorithm~\ref{#1}}
% Reference to an algorithm, upper case.
\def\Algref#1{Algorithm~\ref{#1}}
\def\twoalgref#1#2{algorithms \ref{#1} and \ref{#2}}
\def\Twoalgref#1#2{Algorithms \ref{#1} and \ref{#2}}
% Reference to a part, lower case
\def\partref#1{part~\ref{#1}}
% Reference to a part, upper case
\def\Partref#1{Part~\ref{#1}}
\def\twopartref#1#2{parts \ref{#1} and \ref{#2}}

\def\ceil#1{\lceil #1 \rceil}
\def\floor#1{\lfloor #1 \rfloor}
\def\1{\bm{1}}
% Data-set symbols: training, validation and test sets.
% The subscript is attached outside \mathcal{...} so that only the letter D is
% calligraphic and a following superscript stacks above the subscript.
\newcommand{\train}{\mathcal{D}}
\newcommand{\valid}{\mathcal{D}_{\mathrm{valid}}}
\newcommand{\test}{\mathcal{D}_{\mathrm{test}}}

\def\eps{{\epsilon}}


% Random variables
\def\reta{{\textnormal{$\eta$}}}
\def\ra{{\textnormal{a}}}
\def\rb{{\textnormal{b}}}
\def\rc{{\textnormal{c}}}
\def\rd{{\textnormal{d}}}
\def\re{{\textnormal{e}}}
\def\rf{{\textnormal{f}}}
\def\rg{{\textnormal{g}}}
\def\rh{{\textnormal{h}}}
\def\ri{{\textnormal{i}}}
\def\rj{{\textnormal{j}}}
\def\rk{{\textnormal{k}}}
\def\rl{{\textnormal{l}}}
% rm is already a command, just don't name any random variables m
\def\rn{{\textnormal{n}}}
\def\ro{{\textnormal{o}}}
\def\rp{{\textnormal{p}}}
\def\rq{{\textnormal{q}}}
\def\rr{{\textnormal{r}}}
\def\rs{{\textnormal{s}}}
\def\rt{{\textnormal{t}}}
\def\ru{{\textnormal{u}}}
\def\rv{{\textnormal{v}}}
\def\rw{{\textnormal{w}}}
\def\rx{{\textnormal{x}}}
\def\ry{{\textnormal{y}}}
\def\rz{{\textnormal{z}}}

% Random vectors: \rv<letter> typesets the letter in upright bold.
% (\rvi was previously mis-named \rvu, which left \rvi undefined and defined
% \rvu twice; \rvu keeps its bold-u meaning below.)
\def\rvepsilon{{\mathbf{\epsilon}}}
\def\rvtheta{{\mathbf{\theta}}}
\def\rva{{\mathbf{a}}}
\def\rvb{{\mathbf{b}}}
\def\rvc{{\mathbf{c}}}
\def\rvd{{\mathbf{d}}}
\def\rve{{\mathbf{e}}}
\def\rvf{{\mathbf{f}}}
\def\rvg{{\mathbf{g}}}
\def\rvh{{\mathbf{h}}}
\def\rvi{{\mathbf{i}}}
\def\rvj{{\mathbf{j}}}
\def\rvk{{\mathbf{k}}}
\def\rvl{{\mathbf{l}}}
\def\rvm{{\mathbf{m}}}
\def\rvn{{\mathbf{n}}}
\def\rvo{{\mathbf{o}}}
\def\rvp{{\mathbf{p}}}
\def\rvq{{\mathbf{q}}}
\def\rvr{{\mathbf{r}}}
\def\rvs{{\mathbf{s}}}
\def\rvt{{\mathbf{t}}}
\def\rvu{{\mathbf{u}}}
\def\rvv{{\mathbf{v}}}
\def\rvw{{\mathbf{w}}}
\def\rvx{{\mathbf{x}}}
\def\rvy{{\mathbf{y}}}
\def\rvz{{\mathbf{z}}}

% Elements of random vectors
\def\erva{{\textnormal{a}}}
\def\ervb{{\textnormal{b}}}
\def\ervc{{\textnormal{c}}}
\def\ervd{{\textnormal{d}}}
\def\erve{{\textnormal{e}}}
\def\ervf{{\textnormal{f}}}
\def\ervg{{\textnormal{g}}}
\def\ervh{{\textnormal{h}}}
\def\ervi{{\textnormal{i}}}
\def\ervj{{\textnormal{j}}}
\def\ervk{{\textnormal{k}}}
\def\ervl{{\textnormal{l}}}
\def\ervm{{\textnormal{m}}}
\def\ervn{{\textnormal{n}}}
\def\ervo{{\textnormal{o}}}
\def\ervp{{\textnormal{p}}}
\def\ervq{{\textnormal{q}}}
\def\ervr{{\textnormal{r}}}
\def\ervs{{\textnormal{s}}}
\def\ervt{{\textnormal{t}}}
\def\ervu{{\textnormal{u}}}
\def\ervv{{\textnormal{v}}}
\def\ervw{{\textnormal{w}}}
\def\ervx{{\textnormal{x}}}
\def\ervy{{\textnormal{y}}}
\def\ervz{{\textnormal{z}}}

% Random matrices
\def\rmA{{\mathbf{A}}}
\def\rmB{{\mathbf{B}}}
\def\rmC{{\mathbf{C}}}
\def\rmD{{\mathbf{D}}}
\def\rmE{{\mathbf{E}}}
\def\rmF{{\mathbf{F}}}
\def\rmG{{\mathbf{G}}}
\def\rmH{{\mathbf{H}}}
\def\rmI{{\mathbf{I}}}
\def\rmJ{{\mathbf{J}}}
\def\rmK{{\mathbf{K}}}
\def\rmL{{\mathbf{L}}}
\def\rmM{{\mathbf{M}}}
\def\rmN{{\mathbf{N}}}
\def\rmO{{\mathbf{O}}}
\def\rmP{{\mathbf{P}}}
\def\rmQ{{\mathbf{Q}}}
\def\rmR{{\mathbf{R}}}
\def\rmS{{\mathbf{S}}}
\def\rmT{{\mathbf{T}}}
\def\rmU{{\mathbf{U}}}
\def\rmV{{\mathbf{V}}}
\def\rmW{{\mathbf{W}}}
\def\rmX{{\mathbf{X}}}
\def\rmY{{\mathbf{Y}}}
\def\rmZ{{\mathbf{Z}}}

% Elements of random matrices
\def\ermA{{\textnormal{A}}}
\def\ermB{{\textnormal{B}}}
\def\ermC{{\textnormal{C}}}
\def\ermD{{\textnormal{D}}}
\def\ermE{{\textnormal{E}}}
\def\ermF{{\textnormal{F}}}
\def\ermG{{\textnormal{G}}}
\def\ermH{{\textnormal{H}}}
\def\ermI{{\textnormal{I}}}
\def\ermJ{{\textnormal{J}}}
\def\ermK{{\textnormal{K}}}
\def\ermL{{\textnormal{L}}}
\def\ermM{{\textnormal{M}}}
\def\ermN{{\textnormal{N}}}
\def\ermO{{\textnormal{O}}}
\def\ermP{{\textnormal{P}}}
\def\ermQ{{\textnormal{Q}}}
\def\ermR{{\textnormal{R}}}
\def\ermS{{\textnormal{S}}}
\def\ermT{{\textnormal{T}}}
\def\ermU{{\textnormal{U}}}
\def\ermV{{\textnormal{V}}}
\def\ermW{{\textnormal{W}}}
\def\ermX{{\textnormal{X}}}
\def\ermY{{\textnormal{Y}}}
\def\ermZ{{\textnormal{Z}}}

% Vectors
\def\vzero{{\bm{0}}}
\def\vone{{\bm{1}}}
\def\vmu{{\bm{\mu}}}
\def\vtheta{{\bm{\theta}}}
\def\va{{\bm{a}}}
\def\vb{{\bm{b}}}
\def\vc{{\bm{c}}}
\def\vd{{\bm{d}}}
\def\ve{{\bm{e}}}
\def\vf{{\bm{f}}}
\def\vg{{\bm{g}}}
\def\vh{{\bm{h}}}
\def\vi{{\bm{i}}}
\def\vj{{\bm{j}}}
\def\vk{{\bm{k}}}
\def\vl{{\bm{l}}}
\def\vm{{\bm{m}}}
\def\vn{{\bm{n}}}
\def\vo{{\bm{o}}}
\def\vp{{\bm{p}}}
\def\vq{{\bm{q}}}
\def\vr{{\bm{r}}}
\def\vs{{\bm{s}}}
\def\vt{{\bm{t}}}
\def\vu{{\bm{u}}}
\def\vv{{\bm{v}}}
\def\vw{{\bm{w}}}
\def\vx{{\bm{x}}}
\def\vy{{\bm{y}}}
\def\vz{{\bm{z}}}

% Elements of vectors
\def\evalpha{{\alpha}}
\def\evbeta{{\beta}}
\def\evepsilon{{\epsilon}}
\def\evlambda{{\lambda}}
\def\evomega{{\omega}}
\def\evmu{{\mu}}
\def\evpsi{{\psi}}
\def\evsigma{{\sigma}}
\def\evtheta{{\theta}}
\def\eva{{a}}
\def\evb{{b}}
\def\evc{{c}}
\def\evd{{d}}
\def\eve{{e}}
\def\evf{{f}}
\def\evg{{g}}
\def\evh{{h}}
\def\evi{{i}}
\def\evj{{j}}
\def\evk{{k}}
\def\evl{{l}}
\def\evm{{m}}
\def\evn{{n}}
\def\evo{{o}}
\def\evp{{p}}
\def\evq{{q}}
\def\evr{{r}}
\def\evs{{s}}
\def\evt{{t}}
\def\evu{{u}}
\def\evv{{v}}
\def\evw{{w}}
\def\evx{{x}}
\def\evy{{y}}
\def\evz{{z}}

% Matrix
\def\mA{{\bm{A}}}
\def\mB{{\bm{B}}}
\def\mC{{\bm{C}}}
\def\mD{{\bm{D}}}
\def\mE{{\bm{E}}}
\def\mF{{\bm{F}}}
\def\mG{{\bm{G}}}
\def\mH{{\bm{H}}}
\def\mI{{\bm{I}}}
\def\mJ{{\bm{J}}}
\def\mK{{\bm{K}}}
\def\mL{{\bm{L}}}
\def\mM{{\bm{M}}}
\def\mN{{\bm{N}}}
\def\mO{{\bm{O}}}
\def\mP{{\bm{P}}}
\def\mQ{{\bm{Q}}}
\def\mR{{\bm{R}}}
\def\mS{{\bm{S}}}
\def\mT{{\bm{T}}}
\def\mU{{\bm{U}}}
\def\mV{{\bm{V}}}
\def\mW{{\bm{W}}}
\def\mX{{\bm{X}}}
\def\mY{{\bm{Y}}}
\def\mZ{{\bm{Z}}}
\def\mBeta{{\bm{\beta}}}
\def\mPhi{{\bm{\Phi}}}
\def\mLambda{{\bm{\Lambda}}}
\def\mSigma{{\bm{\Sigma}}}

% Tensor
\DeclareMathAlphabet{\mathsfit}{\encodingdefault}{\sfdefault}{m}{sl}
\SetMathAlphabet{\mathsfit}{bold}{\encodingdefault}{\sfdefault}{bx}{n}
\newcommand{\tens}[1]{\bm{\mathsfit{#1}}}
\def\tA{{\tens{A}}}
\def\tB{{\tens{B}}}
\def\tC{{\tens{C}}}
\def\tD{{\tens{D}}}
\def\tE{{\tens{E}}}
\def\tF{{\tens{F}}}
\def\tG{{\tens{G}}}
\def\tH{{\tens{H}}}
\def\tI{{\tens{I}}}
\def\tJ{{\tens{J}}}
\def\tK{{\tens{K}}}
\def\tL{{\tens{L}}}
\def\tM{{\tens{M}}}
\def\tN{{\tens{N}}}
\def\tO{{\tens{O}}}
\def\tP{{\tens{P}}}
\def\tQ{{\tens{Q}}}
\def\tR{{\tens{R}}}
\def\tS{{\tens{S}}}
\def\tT{{\tens{T}}}
\def\tU{{\tens{U}}}
\def\tV{{\tens{V}}}
\def\tW{{\tens{W}}}
\def\tX{{\tens{X}}}
\def\tY{{\tens{Y}}}
\def\tZ{{\tens{Z}}}


% Graph
\def\gA{{\mathcal{A}}}
\def\gB{{\mathcal{B}}}
\def\gC{{\mathcal{C}}}
\def\gD{{\mathcal{D}}}
\def\gE{{\mathcal{E}}}
\def\gF{{\mathcal{F}}}
\def\gG{{\mathcal{G}}}
\def\gH{{\mathcal{H}}}
\def\gI{{\mathcal{I}}}
\def\gJ{{\mathcal{J}}}
\def\gK{{\mathcal{K}}}
\def\gL{{\mathcal{L}}}
\def\gM{{\mathcal{M}}}
\def\gN{{\mathcal{N}}}
\def\gO{{\mathcal{O}}}
\def\gP{{\mathcal{P}}}
\def\gQ{{\mathcal{Q}}}
\def\gR{{\mathcal{R}}}
\def\gS{{\mathcal{S}}}
\def\gT{{\mathcal{T}}}
\def\gU{{\mathcal{U}}}
\def\gV{{\mathcal{V}}}
\def\gW{{\mathcal{W}}}
\def\gX{{\mathcal{X}}}
\def\gY{{\mathcal{Y}}}
\def\gZ{{\mathcal{Z}}}

% Sets
\def\sA{{\mathbb{A}}}
\def\sB{{\mathbb{B}}}
\def\sC{{\mathbb{C}}}
\def\sD{{\mathbb{D}}}
% Don't use a set called E, because this would be the same as our symbol
% for expectation.
\def\sF{{\mathbb{F}}}
\def\sG{{\mathbb{G}}}
\def\sH{{\mathbb{H}}}
\def\sI{{\mathbb{I}}}
\def\sJ{{\mathbb{J}}}
\def\sK{{\mathbb{K}}}
\def\sL{{\mathbb{L}}}
\def\sM{{\mathbb{M}}}
\def\sN{{\mathbb{N}}}
\def\sO{{\mathbb{O}}}
\def\sP{{\mathbb{P}}}
\def\sQ{{\mathbb{Q}}}
\def\sR{{\mathbb{R}}}
\def\sS{{\mathbb{S}}}
\def\sT{{\mathbb{T}}}
\def\sU{{\mathbb{U}}}
\def\sV{{\mathbb{V}}}
\def\sW{{\mathbb{W}}}
\def\sX{{\mathbb{X}}}
\def\sY{{\mathbb{Y}}}
\def\sZ{{\mathbb{Z}}}

% Entries of a matrix
\def\emLambda{{\Lambda}}
\def\emA{{A}}
\def\emB{{B}}
\def\emC{{C}}
\def\emD{{D}}
\def\emE{{E}}
\def\emF{{F}}
\def\emG{{G}}
\def\emH{{H}}
\def\emI{{I}}
\def\emJ{{J}}
\def\emK{{K}}
\def\emL{{L}}
\def\emM{{M}}
\def\emN{{N}}
\def\emO{{O}}
\def\emP{{P}}
\def\emQ{{Q}}
\def\emR{{R}}
\def\emS{{S}}
\def\emT{{T}}
\def\emU{{U}}
\def\emV{{V}}
\def\emW{{W}}
\def\emX{{X}}
\def\emY{{Y}}
\def\emZ{{Z}}
\def\emSigma{{\Sigma}}

% entries of a tensor
% Same font as tensor, without \bm wrapper
\newcommand{\etens}[1]{\mathsfit{#1}}
\def\etLambda{{\etens{\Lambda}}}
\def\etA{{\etens{A}}}
\def\etB{{\etens{B}}}
\def\etC{{\etens{C}}}
\def\etD{{\etens{D}}}
\def\etE{{\etens{E}}}
\def\etF{{\etens{F}}}
\def\etG{{\etens{G}}}
\def\etH{{\etens{H}}}
\def\etI{{\etens{I}}}
\def\etJ{{\etens{J}}}
\def\etK{{\etens{K}}}
\def\etL{{\etens{L}}}
\def\etM{{\etens{M}}}
\def\etN{{\etens{N}}}
\def\etO{{\etens{O}}}
\def\etP{{\etens{P}}}
\def\etQ{{\etens{Q}}}
\def\etR{{\etens{R}}}
\def\etS{{\etens{S}}}
\def\etT{{\etens{T}}}
\def\etU{{\etens{U}}}
\def\etV{{\etens{V}}}
\def\etW{{\etens{W}}}
\def\etX{{\etens{X}}}
\def\etY{{\etens{Y}}}
\def\etZ{{\etens{Z}}}

% The true underlying data generating distribution
\newcommand{\pdata}{p_{\rm{data}}}
% The empirical distribution defined by the training set
\newcommand{\ptrain}{\hat{p}_{\rm{data}}}
\newcommand{\Ptrain}{\hat{P}_{\rm{data}}}
% The model distribution
\newcommand{\pmodel}{p_{\rm{model}}}
\newcommand{\Pmodel}{P_{\rm{model}}}
\newcommand{\ptildemodel}{\tilde{p}_{\rm{model}}}
% Stochastic autoencoder distributions
\newcommand{\pencode}{p_{\rm{encoder}}}
\newcommand{\pdecode}{p_{\rm{decoder}}}
\newcommand{\precons}{p_{\rm{reconstruct}}}

\newcommand{\laplace}{\mathrm{Laplace}} % Laplace distribution

\newcommand{\E}{\mathbb{E}}
\newcommand{\Ls}{\mathcal{L}}
\newcommand{\R}{\mathbb{R}}
\newcommand{\emp}{\tilde{p}}
\newcommand{\lr}{\alpha}
\newcommand{\reg}{\lambda}
\newcommand{\rect}{\mathrm{rectifier}}
\newcommand{\softmax}{\mathrm{softmax}}
\newcommand{\sigmoid}{\sigma}
\newcommand{\softplus}{\zeta}
\newcommand{\KL}{D_{\mathrm{KL}}}
\newcommand{\Var}{\mathrm{Var}}
\newcommand{\standarderror}{\mathrm{SE}}
\newcommand{\Cov}{\mathrm{Cov}}
% Wolfram Mathworld says $L^2$ is for function spaces and $\ell^2$ is for vectors
% But then they seem to use $L^2$ for vectors throughout the site, and so does
% wikipedia.
\newcommand{\normlzero}{L^0}
\newcommand{\normlone}{L^1}
\newcommand{\normltwo}{L^2}
\newcommand{\normlp}{L^p}
\newcommand{\normmax}{L^\infty}

\newcommand{\parents}{Pa} % See usage in notation.tex. Chosen to match Daphne's book.

\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}

\DeclareMathOperator{\sign}{sign}
\DeclareMathOperator{\Tr}{Tr}
\let\ab\allowbreak

\usepackage{adjustbox}
\usepackage{lipsum}
\usepackage{wrapfig}
\usepackage{booktabs}
\usepackage{multirow,mathtools} 
\usepackage{algorithm,algpseudocode}
\usepackage{threeparttable}


 \algnewcommand{\algorithmicforeach}{\textbf{for each}}
\algdef{SE}[FOR]{ForEach}{EndForEach}[1]
  {\algorithmicforeach\ #1\ \algorithmicdo}% \ForEach{#1}
  {\algorithmicend\ \algorithmicforeach}% \EndForEach

\usepackage{times}

%\newtheorem{proof}{\bf{Proof}}
\newtheorem{myprop}{\bf{Proposition}}
\newtheorem{mycor}{\bf{Corollary}}
\newtheorem{mythr}{\bf{Theorem}}
\newtheorem{mylemma}{\bf{Lemma}}
\newtheorem{myremark}{\bf{Remark}}
\DeclareMathOperator{\tr}{tr}
\DeclareMathOperator{\card}{card}
\DeclareMathOperator{\cov}{cov}
\DeclareMathOperator{\diag}{diag}
\DeclareMathOperator*{\minimize}{\text{minimize}}
\DeclareMathOperator*{\maximize}{\text{maximize}}
\DeclareMathOperator*{\st}{\text{subject to}}
\DeclareMathAlphabet\mathbfcal{OMS}{cmsy}{b}{n}
\newcommand{\Def}[0]{\mathrel{\mathop:}=}
\newcommand{\Deff}[0]{=\mathrel{\mathop:}}

 % Colour-coded in-text annotations: \SL and \Gaoyuan prefix the note with the
 % author's name, \revision only colours its argument.
 % `salmon' is not one of xcolor's base colour names, so a fallback is
 % provided; \providecolor leaves any existing definition untouched.
 \providecolor{salmon}{rgb}{0.98,0.50,0.45}
 \newcommand{\SL}[1]{\textcolor{red}{SL: #1}}
  \newcommand{\Gaoyuan}[1]{\textcolor{salmon}{Gaoyuan: #1}}
 \newcommand{\revision}[1]{\textcolor{blue}{#1}}


% Counter behind the hand-rolled \remark heading below.
\newcounter{remark}
% \remark: steps the remark counter, makes the new number the target of a
% following \label, and prints "Remark <n>. " in emphasised type.
% \@currentlabel contains `@', so the definition must be read between
% \makeatletter and \makeatother; otherwise TeX sees the control symbol \@
% followed by the letters "currentlabel" and calling \remark redefines \@.
\makeatletter
\def\remark{\addtocounter{remark}{1}\def\@currentlabel{\theremark}%
\emph{Remark~\theremark}. }
\makeatother
% \ubar{x}: bar placed underneath x.
% NOTE(review): \underaccent comes from the `accents' package, which is not
% loaded in the visible preamble -- confirm before using \ubar.
\newcommand{\ubar}[1]{\underaccent{\bar}{#1}}
% \overbar{x}: overline slightly shorter than x, with the surrounding space restored.
\newcommand{\overbar}[1]{\mkern 1.5mu\overline{\mkern-1.5mu#1\mkern-1.5mu}\mkern 1.5mu}
% Calligraphic L.
% NOTE(review): this replaces the LaTeX kernel's text command \L; any later
% use of the original \L (e.g. in bibliography entries) gets this math-only
% definition instead.
\def\L{{\cal L}}
% Bold digit one, written \b1.
% NOTE(review): TeX reads this as a definition of \b with the delimiter `1',
% so it replaces the kernel accent command \b and only the exact sequence
% \b1 is valid afterwards.
\def\b1{{\boldsymbol 1}}
\def\ba{\mathbf a}
\def\abx{\overbar{\bx}}
\def\aby{\overbar{\by}}
\def\azeta{\overbar{\zeta}}
\def\axi{\overbar{\xi}}
\def\abg{\overbar{g}}
\newcommand{\CT}[1]{\boldsymbol{\mathscr{\MakeUppercase{#1}}}}
\newcommand{\TU}{\CT{U}}
\def\balpha{\boldsymbol{\alpha}}
\def\btheta{\boldsymbol{\theta}}
\def\bdelta{\boldsymbol{\delta}}
\def\bbeta{\boldsymbol{\beta}}
\def\bseta{\boldsymbol{\eta}}
\def\bdelta{\boldsymbol{\delta}}
\def\bH{\boldsymbol{\mathcal{H}}}

\def\hbx{\widehat{\bx}}

\def\hbx{\widehat{\mathbf x}}
\def\hbv{\widehat{\mathbf v}}
\def\byt{\widehat{\mathbf y}}
\def\wde{\widehat{\delta}}
\def\hcA{\widehat{\mathcal A}}
\def\bYt{\widehat{\mathbf Y}}
\def\hphi{\widehat{\phi}}
\def\htheta{\widehat{\theta}}
\def\bgo{\mathbf{g}_t}

\def\tbh{\widetilde{\mathbf h}}
\def\wbx{\widetilde{\mathbf x}}
\def\wx{\widetilde{x}}
\def\wbu{\widetilde{\mathbf u}}
\def\wsigma{\widetilde{\sigma}}
\def\tbw{\widetilde{\mathbf w}}
\def\tby{\widetilde{\mathbf y}}
\def\wbv{\widetilde{\mathbf v}}
\def\wzeta{\widetilde{\zeta}}
\def\tDelta{\widetilde{\Delta}}
\def\tbH{\widetilde{\boldsymbol{\mathcal{H}}}}

% Short-hands for bold lower-case vectors: \b<letter> is \mathbf <letter>,
% except \bg (hatted bold g with subscript t) and \bd (d with an arrow).
\def\bg{\hat {\mathbf g}_t}
\def\bd{\vec d}
\def\be{\mathbf e}
\def\bx{\mathbf x}
\def\bh{\mathbf h}
% NOTE(review): the next line overwrites \bm from the `bm' package loaded
% above. From this point every macro in this preamble whose body calls
% \bm{...} (\vzero, \va..., \mA..., \tens, \1, ...) expands to a bold "m"
% followed by its unbolded argument. Confirm which meaning the document
% needs and rename one of the two.
\def\bm{\mathbf m}
\def\bn{\mathbf n}
\def\bw{\mathbf w}
\def\by{\mathbf y}
\def\bu{\mathbf u}
\def\bv{\mathbf v}
\def\bp{\mathbf p}
\def\bq{\mathbf q}
\def\bz{\mathbf z}
\def\br{\mathbf r}
\def\bA{\mathbf A}
\def\bB{\mathbf B}
\def\bC{\mathbf C}
\def\bD{\mathbf D}
\def\bE{\mathbf E}
\def\bF{\mathbf F}
\def\bG{\mathbf G}
%\def\bH{\mathbf H}
\def\bI{\mathbf I}
\def\bK{\mathbf K}
\def\bL{\mathbf L}
\def\bM{\mathbf M}
\def\bN{\mathbf N}
\def\bP{\mathbf P}
\def\bQ{\mathbf Q}
\def\bS{\mathbf S}
\def\bT{\mathbf T}
\def\bU{\mathbf U}
\def\bV{\mathbf V}
\def\bW{\mathbf W}
\def\bX{\mathbf X}
\def\bY{\mathbf Y}
\def\bLambda{\mathbf \Lambda}
\def\bOmega{\mathbf \Omega}
\def\bZ{\mathbf Z}
\def\cA{\mathcal A}
\def\cB{\mathcal B}
\def\cC{\mathcal C}



\newcommand{\T}{\scriptscriptstyle T}
\def\adag{\alpha^\dag}
\def\rhomax{\rho_\text{max}}
\def\maximize{\mathop{\text{maximize}}}
\def\minimize{\mathop{\text{minimize}}}
\def\deref#1{Definition~\ref{#1}}
\def\secref#1{Section~\ref{#1}}
\def\leref#1{Lemma~\ref{#1}}
\def\conref#1{Condition~\ref{#1}}
\def\thref#1{Theorem~\ref{#1}}
\def\remref#1{Remark~\ref{#1}}
\def\coref#1{Corollary~\ref{#1}}
\def\figref#1{Figure~\ref{#1}}
\def\figtab#1{Table~\ref{#1}}
\def\algref#1{Algorithm~\ref{#1}}
\def\asref#1{Assumption~\ref{#1}}
\def\bydef{\triangleq}
\newtheorem{lemma}{Lemma}
\newtheorem{condition}{Condition}
\newtheorem{theorem}{Theorem}
\newtheorem{corollary}{Corollary}
\newtheorem{definition}{Definition}
\newtheorem{assumption}{Assumption}


%% This part goes in preamble
% \dummyfig{<content>}: centred, framed placeholder box (half the text width,
% one fifth of the text height) with <content> centred inside it.
% Line ends are protected with % so that no spurious spaces end up inside or
% around the frame.
\newcommand{\dummyfig}[1]{%
  \centering
  \fbox{%
    \begin{minipage}[c][0.2\textheight][c]{0.5\textwidth}%
      \centering{#1}%
    \end{minipage}%
  }%
}

\usepackage{color, colortbl}
\definecolor{Gray}{gray}{0.9}
\definecolor{Orange}{rgb}{1,0.5,0}
%\rowcolor{Gray}

% Colour-coded in-text annotations, one macro per author; each prints the
% author's initials followed by the note.
% `NavyBlue' is only predefined when xcolor is loaded with the dvipsnames
% option, which this preamble does not request, so a fallback with the
% dvipsnames value is provided; \providecolor keeps any existing definition.
\providecolor{NavyBlue}{cmyk}{0.94,0.54,0,0}
\newcommand{\PY}[1]{\textcolor{green}{PY: #1}}
\newcommand{\GY}[1]{\textcolor{Orange}{GY: #1}}
\newcommand{\LH}[1]{\textcolor{NavyBlue}{LH: #1}}
\newcommand{\MH}[1]{\textcolor{red}{MH: #1}}
\newcommand{\XC}[1]{\textcolor{magenta}{XC: #1}}

\makeatletter
\newcommand*{\rom}[1]{\expandafter\@slowromancap\romannumeral #1@}
\makeatother
\newcommand{\mycomment}[1]{}

%%%%%%%%% notations %%%%%%%%%%%%%%
\newcommand{\layernum}{h}
\newcommand{\convacc}{\xi}
\newcommand{\layerscale}{\tau}
\newcommand{\modeldim}{d}
\newcommand{\increase}[1]{\textcolor{red}{{#1}}}
\newcommand{\decrease}[1]{\textcolor{blue}{{#1}}}
\newcommand{\high}[1]{\textcolor{purple}{{#1}}}

\newcommand{\DAT}{{\texttt{DAT}}}
\newcommand{\AT}{{\texttt{AT}}}



\title{Distributed Adversarial Training to Robustify Deep Neural Networks at Scale (Supplementary Material)}


\author[1,*]{Gaoyuan Zhang}
\author[1,*]{Songtao Lu}
\author[2]{Yihua Zhang}
\author[3]{Xiangyi Chen}
\author[1]{Pin-Yu Chen}
\author[1]{Quanfu Fan}
\author[1]{Lee Martie}
\author[1]{Lior Horesh}
\author[3]{Mingyi Hong}
\author[1,2]{Sijia Liu}
% Add affiliations after the authors
\affil[1]{%
    IBM Research
    % Yorktown Heights, NY 10598
}
\affil[2]{%
    Michigan State University
    % East Lansing, MI 48824
}
\affil[3]{%
    University of Minnesota
    % Minneapolis, MN 55455
  }
  
\affil[*]{%
    Equal Contribution
  }

\begin{document}
\onecolumn

% Restart the numbering for the supplement. Sections start again from zero;
% figures, tables, lemmas, algorithms and equations are reset and their
% printed numbers are prefixed with "A".
\setcounter{section}{0}
\setcounter{figure}{0}
\renewcommand{\thefigure}{A\arabic{figure}}
\setcounter{table}{0}
\renewcommand{\thetable}{A\arabic{table}}
\setcounter{mylemma}{0}
\renewcommand{\themylemma}{A\arabic{mylemma}}
\setcounter{algorithm}{0}
\renewcommand{\thealgorithm}{A\arabic{algorithm}}
\setcounter{equation}{0}
\renewcommand{\theequation}{A\arabic{equation}}

\maketitle

\section{DAT Algorithm Framework}
\label{app: alg_DAT}

\begin{minipage}{0.95\textwidth}
\centering
\begin{algorithm}[H]
\caption{Distributed adversarial training (DAT) for solving problem \eqref{eq: prob_DAT}}
\label{alg: DAT}
\begin{algorithmic}[1]
  \State Initialize $\boldsymbol{\theta}_1$, dataset $\mathcal D^{(i)}$ for each of $M$ workers, and the number of iterations $T$
\For{Iteration $t =  1,2,\ldots, T$}
\For{Worker $i = 1,2, \ldots, M$}   \Comment{\textcolor{blue}{Worker}}
\State  Draw a finite-size data batch  $\mathcal B_{t}^{(i)} \subseteq \mathcal D^{(i)} $
\State  For each data sample  $\mathbf x \in \mathcal B_{t}^{(i)}$, call for an \textit{inner maximization oracle}:
{\small\begin{align}\label{eq: inner_max_alg}
\boldsymbol{\delta}_t^{(i)}(\mathbf x) \Def \argmax_{ \| \boldsymbol{\delta} \|_\infty \leq \epsilon }  ~  \phi(\boldsymbol{\theta}_{t}, \boldsymbol{\delta}; \mathbf  x),
\end{align}}%
\hspace*{0.4in} where we omit the label or possible pseudo-label $y$ of $\mathbf x$ for brevity
\State Compute the local gradient of $f_i$ in \eqref{eq: prob_DAT} with respect to $\boldsymbol \theta$ given the perturbed samples:
{\small\begin{align}\label{eq: stoch_grad_batch}
    \mathbf g_t^{(i)} =  \lambda \mathbb E_{\mathbf x \in \mathcal B_t^{(i)}}  [ \nabla_{\boldsymbol \theta}\ell(\boldsymbol \theta_{t}; \mathbf x)  ] + \mathbb E_{\mathbf x \in \mathcal B_t^{(i)}}  [ \nabla_{\boldsymbol \theta} \phi(\boldsymbol \theta_{t}; \mathbf x + \boldsymbol \delta_t^{(i)}(\mathbf x) )  ]
\end{align}}
\State (\textit{Optional}) Call for \textit{gradient quantizer} $Q(\cdot)$ and transmit
  $ Q(\mathbf g_t^{(i)})$ to   server
\EndFor
  \State Gradient aggregation at  server:  \hfill \Comment{\textcolor{red}{Server}}
%   $\hat {\mathbf g}_t = \frac{1}{M} \sum_{i=1}^M Q(\mathbf g_t^{(i)})$
 {\small \begin{align}\label{eq: grad_agg}
  \hat {\mathbf g}_t = \textstyle \frac{1}{M} \sum_{i=1}^M Q(\mathbf g_t^{(i)})
  \end{align}}%
\State (\textit{Optional}) Call for \textit{gradient quantizer}   $\hat{\mathbf g}_t \leftarrow Q(\hat{\mathbf g}_t) $,
  and transmit  
  $ \hat {\mathbf g}_t$ to workers: 
%   }}}%
  
%   \hspace*{-0.598in}\vbox{\colorbox{non-photoblue}{\vbox{
\For{Worker $i = 1,2, \ldots, M$}
 \hfill \Comment{\textcolor{blue}{Worker}}
  \State Call for an \textit{outer minimization oracle} $\mathcal A(\cdot)$ to update $\boldsymbol \theta$:  
 { \small \begin{align}\label{eq: outer_min}
      \boldsymbol \theta_{t+1} = \mathcal A(\boldsymbol \theta_{t}, \hat {\mathbf g}_t, \eta_t), \quad \quad  \text{$\eta_t$ is learning rate}
  \end{align}}%
  %\hspace*{0.4in}
  %\MH{[shall we indicating the dependency on local data?]} 
 % where $\eta_t$ denotes a   constant learning rate at $t$
\EndFor
 % }}}%
  \EndFor
  \end{algorithmic}
% ---------------
% Now insert all to comments that we desire on specific line numbers:
% \AddNote[blue]{4}{7}{Worker}
% \AddNote[blue]{9}{10}{Server}
% ---------------
\end{algorithm}
\end{minipage}

\paragraph{Additional details on gradient quantization}
Let $b$ denote the number of bits ($b \leq 32$), and thus there exist $s = 2^b$ quantization levels. We specify the gradient quantization operation $Q(\cdot)$ in Algorithm\,\ref{alg: DAT} as the   \textit{randomized quantizer}   \citep{alistarh2017qsgd,yu2019double}.  Formally,
the  quantization operation at the $i$th coordinate of a vector $\mathbf g$ is given by \citep{alistarh2017qsgd}
{\small\begin{align}\label{eq: rand_q}
    Q( g_i) = \| \mathbf g \|_2 \cdot \mathrm{sign}(g_i) \cdot \xi_i(g_i,s),  \quad \forall i \in \{ 1,2, \ldots, \modeldim \}.
\end{align}}%
In \eqref{eq: rand_q},  $\xi_i(g_i,s)$ is a random number drawn as follows. Given $|g_i|/\| \mathbf g \|_2 \in [l/s, (l+1)/s]$ for some integer $l$ with $0 \leq l < s$, we then have
{\small\begin{align}\label{eq: xi}
\xi_i(g_i,s) = \left \{ 
    \begin{array}{ll}
      l/s   & \text{with probability $1 - (s |g_i|/\| \mathbf g \|_2 - l)$}  \\
      (l+1)/s   &  \text{with probability $ (s |g_i|/\| \mathbf g \|_2 - l)$},
    \end{array}
    \right.
\end{align}}%
where $|a|$ denotes the absolute value of a scalar $a$, and  $\| \mathbf a \|_2$ denotes the $\ell_2$ norm of a vector $\mathbf a$.
The rationale behind using \eqref{eq: rand_q} is that  $Q(g_i)$ is an \textit{unbiased} estimate of $g_i$, namely,
$
\mathbb E_{\xi_i(g_i, s)}[Q(g_i)] =  g_i
$, with bounded variance. Moreover, we need at most
 $(32 + \modeldim + b \modeldim  )$ bits to transmit the quantized  $Q(\mathbf g)$, where $32$ bits are used for $\| \mathbf g \|_2$, and $1$ bit for the sign of $g_i$ and $b$ bits for $\xi_i(g_i,s)$ at each of the $\modeldim$ coordinates,
whereas $32\modeldim$ bits are needed for a single-precision %\Gaoyuan{single-precision} \LH{(double)} 
$\mathbf g$. Clearly, a small $b$ reduces the communication cost.
% We will show in Sec.\,\ref{sec: exp} that   DAT, combined with gradient quantization, still leads to a competitive  performance. For example, the robust accuracy of ResNet-50 trained by $8$-bit DAT
% (performing quantization at Step\,7 of Algorithm\,\ref{alg: DAT})
% for ImageNet is just $0.55\%$  lower than the   robust accuracy achieved by the $32$-bit DAT. 
% %however, the communication cost is reduced by \SL{xxx} times.  
% Lastly, 
We note that  if every worker performs as a server in DAT, then the quantization operation at Step\,10 of Algorithm\,\ref{alg: DAT} is no longer needed. In this case, the communication network becomes fully connected. With synchronized communication, this is favored  for   training DNNs under the All-reduce operation.

\section{Theoretical Results}\label{app: thr_1}
%\SL{@Songtao, I moved this to appendix from the main section. Could you please combine it with the next theoretical section, e.g., adding assumptions subsection?}

%{\color{red}To Sijia: Please change notation. 

%i) I used $\epsilon$ for accuracy, please change another one for perturbation. 
%\SL{Please use $\convacc$.}

%ii) I used $d$ for problem dimension of vector $\btheta$ and $h$ for the total number of layers; $d_i$ is the problem dimensions for the $i$th layer. (we also need to change the $\sigma$ function). 

%\SL{Please use $\layernum$ for number of layers, and $\layerscale(\cdot)$ for layer-wise scaling function}.


%iii) Your equation \eqref{eq: outer_min} is not correct. Please see my equation in \eqref{eq.qeq}
%} 
% iv) emphasize $\mathcal{B}$ is discrete. Otherwise our proof will be problematic.} %%


In this section, we will quantify the convergence behaviour of the proposed DAT algorithm. First, we define the following notations:
\begin{equation}\label{eq.defPhi}
\Phi_i(\btheta;\bx)=\max_{\|\bdelta^{(i)}\|_{\infty}\le\epsilon}\phi(\btheta,\bdelta^{(i)};\bx),\quad \textrm{and}\quad
    \Phi_i(\btheta)=\mathbb{E}_{\bx\in\mathcal{D}^{(i)}}\Phi_i(\btheta ; \bx).
\end{equation}
We also define
\begin{equation}
    l_i(\btheta)=\mathbb{E}_{\bx\in\mathcal{D}^{(i)}} l(\btheta ; \bx),
\end{equation}
where the label $y$ of $\mathbf x$ is omitted for labeled data. 
Then, the objective function of  problem \eqref{eq: prob_DAT} can be expressed in the  compact way
\begin{equation}
    \Psi(\btheta)=\frac{1}{M}\sum^M_{i=1}\left(\lambda l_i(\btheta)+\Phi_i(\btheta)\right),
\end{equation}
and the optimization problem is then given by $\min_{\btheta}\Psi(\btheta)$. 

Therefore, it is clear that if a point $\btheta^{\star}$ satisfies 
\begin{equation}
    \|\nabla_{\btheta} \Psi(\btheta^{\star})\|\le\convacc,
\end{equation} then we say $\btheta^{\star}$ is a $\convacc$-approximate first-order stationary point (FOSP) of problem  \eqref{eq: prob_DAT}.

Prior to delving into the convergence analysis of DAT, we   make the following assumptions.

\subsection{Assumptions}\label{app: assumption}
%\MH{since this is in appendix, make is less dense, by separating two equations in one line into 2.}

A1. Assume that the objective function has layer-wise Lipschitz continuous gradients with constant $L_i$ for each layer, i.e.,
\begin{align}
\|\nabla_i \Psi(\btheta_{\cdot,i})-\nabla_i \Psi(\btheta'_{\cdot,i})\|\le L_i\|\btheta_{\cdot,i}-\btheta'_{\cdot,i}\|,\quad \forall i\in[h],
\end{align}
where $\nabla_i\Psi(\btheta_{\cdot,i})$ denotes the gradient w.r.t. the variables at the $i$th layer. Also, we assume that $\Psi(\btheta)$ is lower bounded, i.e., $\Psi^{\star}:=\min_{\btheta} \Psi(\btheta)>-\infty$, and that the gradient estimate is bounded, i.e., $\|\bg^{(i)}\|\le G$.
%\begin{align}
%\|\nabla_{\bdelta} \phi(\btheta,\bdelta, \bx)-\nabla_{\bdelta} \phi(\btheta,\bdelta',\bx)\|\le L\|\bdelta-\bdelta'\|
%\end{align}


A2. Assume that $\phi(\btheta,\bdelta;\bx)$ is strongly concave with respect to $\bdelta$ with parameter $\mu$ and has the following gradient Lipschitz continuity with constant $L_{\phi}$:
\begin{equation}
   \|\nabla_{\btheta}\phi(\btheta,\bdelta;\bx)-\nabla_{\btheta}\phi(\btheta,\bdelta';\bx)\|\le L_{\phi}\|\bdelta-\bdelta'\|.
\end{equation}

A3. Assume that the gradient estimate is unbiased and has bounded variance, i.e.,
\begin{align}
    \mathbb{E}_{\bx\in\mathcal{B}^{(i)}} [\nabla_{\btheta} l(\btheta;\bx)] &= \nabla_{\btheta} l(\btheta), \quad \forall i,
\\
    \mathbb{E}_{\bx\in\mathcal{B}^{(i)}} [\nabla_{\btheta} \Phi(\btheta;\bx)] &= \nabla_{\btheta} \Phi(\btheta), \quad \forall i,
\end{align}
where recall that $\mathcal{B}^{(i)}$ denotes a data batch used at worker $i$, %\SL{[This is a little bit confusing. $i$ was used for layer index, but here used for worker right?, and in Eq. 23-24, it becomes layer index again? Do you want to use $j$ as worker here?]},
$\nabla_{\btheta} l(\btheta):=\frac{1}{M}\sum^M_{i=1}\nabla_{\btheta} l_i(\btheta)$ and $\nabla_{\btheta} \Phi(\btheta):=\frac{1}{M}\sum^M_{i=1}\nabla_{\btheta} \Phi_i(\btheta)$; and \begin{align}
    \mathbb{E}_{\bx\in\mathcal{B}^{(i)}}&\|\nabla_{\btheta} l(\btheta;\bx)-\nabla_{\btheta} l(\btheta)\|^2\le\sigma^2, \forall i \\ \mathbb{E}_{\bx\in\mathcal{B}^{(i)}}&\|\nabla_{\btheta} \Phi(\btheta;\bx)-\nabla_{\btheta} \Phi(\btheta)\|^2\le\sigma^2,\forall i.
\end{align}
Further, we define a component-wise bounded variance of the gradient estimate
\begin{align}
    \mathbb{E}_{\bx\in\mathcal{B}^{(i)}}&\|[\nabla_{\btheta} l(\btheta;\bx)]_{jk}-[\nabla_{\btheta} l(\btheta)]_{jk}\|^2\le\sigma^2_{jk}, \forall i,
    \\
    \mathbb{E}_{\bx\in\mathcal{B}^{(i)}}&\|[\nabla_{\btheta} \Phi(\btheta;\bx)]_{jk}-[\nabla_{\btheta} \Phi(\btheta)]_{jk}\|^2\le\sigma'^2_{jk},\forall i,
\end{align}
%\SL{In (23), equal to or less and equal to? Confusing on $i$, $\mathcal B^{(i)}$ should mean data batch at worker $i$, right? I did not follow $i$ and $j$ in what follows.}
where  $j$ denotes the index of the layer, and $k$ denotes the index of the entry at each layer. Under A3, we have $\sum^h_{j=1}\sum^{d_j}_{k=1}\max\{\sigma^2_{jk},\sigma'^2_{jk}\}\le\sigma^2$.

A4. Assume that the component-wise compression error has bounded variance, i.e.,
\begin{equation}
\mathbb{E}[(Q([\mathbf{g}^{(i)}(\btheta)]_{jk})-[\mathbf{g}^{(i)}(\btheta)]_{jk})^2]\le \delta^2_{jk},\forall i.
%,j,k. 
\end{equation}
%\SL{Do you really need the index $k$ to represent worker $k$ here? You can say that where $\mathbf g(\boldsymbol \theta)$ represents the gradient estimate used at each worker?}
The assumption A4 is satisfied as the randomized quantization is used \citep[Lemma\,3.1]{alistarh2017qsgd}.

% \SL{You do not need this assumption, right? Since randomized quantization \citep{alistarh2017qsgd} is used in DAT.}




% \SL{@Songtao, let us do not specify $b$ as a function of $T$. The final rate should contain $b$ or $s$.}
% \MH{agreed, show all the constants and then make the discussion after theorem.} 



\subsection{Oracle of maximization}
%Let
%\begin{equation}
%c_l\le\sigma(v)\le c_u,\forall v,\quad\textrm{and}\quad d_i=\frac{d}{h}
%\end{equation}
%where $c_l,c_u>0$.

% \SL{@Songtao, please be consistent to use $(\btheta, \mathbf x)$ or $(\btheta; \bx)$. I used the latter and changed when I saw it before. Please double check.} {\color{red} should be no ($\btheta,\bx$) in this paper}
In practice, $\Phi_i(\btheta;\bx),\forall i$ may not be obtained exactly, since the inner loop would need an infinite number of iterations to reach the exact maximizer. Therefore, we allow some numerical error in the maximization step \eqref{eq: inner_max_alg}. This consideration makes the convergence analysis more realistic. 

First, we have the following criterion to measure the closeness of the approximate maximizer to the optimal one.

\begin{definition}
Under A2, if point $\bdelta(\bx)$ satisfies
\begin{equation}\label{eq.conde}
    \max_{\|\bdelta\|\le\epsilon}\left\langle \bdelta-\bdelta(\bx),\nabla_{\bdelta} \phi(\btheta,\bdelta(\bx);\bx)\right\rangle\le\varepsilon
\end{equation}
then, it is an $\varepsilon$-approximate solution to $ \bdelta^*(\bx)$, where
\begin{equation}\label{eq.optdelta}
    \bdelta^*(\bx):=\argmax_{\|\bdelta\|\le\epsilon}\phi(\btheta,\bdelta;\bx),
\end{equation}
and $\bx$ denotes the sampled data.

%\MH{use ``;" before $\bx$?}
\end{definition}
Condition \eqref{eq.conde} is standard for defining approximate solutions of an optimization problem over a compact feasible set and has been widely studied in \citep{wang2019convergence,lu2019snap}.

In the following, we can show that when the inner maximization problem is solved accurately enough, the gradients of function $\phi(\btheta,\bdelta(\bx);\bx)$ at  $\bdelta(\bx)$ and $\bdelta^*(\bx)$ are also close. A similar claim of this fact has been shown in \citep[Lemma 2]{wang2019convergence}. For completeness of the analysis, we provide the specific statement for our problem here and give the detailed proof as well.
\begin{lemma}\label{le.vererror}
Let $\bdelta^{(k)}_t$ be the $(\mu\varepsilon)/L^2_{\phi}$  approximate solution of the inner maximization problem for worker $k$, i.e., $\max_{\bdelta^{(k)}}\phi(\btheta,\bdelta^{(k)};\bx_t)$, where $\bx_t$ denotes the sampled data at the $t$th iteration of DAT. Under A2, we have
\begin{equation}
\left\|\nabla_{\btheta} \phi\left(\btheta_t,\bdelta^{(k)}_t(\bx_t);\bx_t\right)-\nabla_{\btheta} \phi\left(\btheta_t,(\bdelta^*)^{(k)}_t(\bx_t);\bx_t\right)\right\|^2\le \varepsilon. \label{eq.maxorl0}
\end{equation}
\end{lemma}


Throughout the convergence analysis, we assume that $\bdelta^{(k)}_t(\bx_t),\forall k,t$ are all the $(\mu\varepsilon)/L^2_{\phi}$ approximate solutions of the inner maximization problem. Let us define
\begin{equation}
\left\|[\nabla \phi(\btheta_t,\bdelta^{(k)}_t(\bx_t);\bx_t)]_{ij}-[\nabla \phi(\btheta_t,(\bdelta^*)^{(k)}_t(\bx_t);\bx_t)]_{ij}\right\|^2=\varepsilon_{ij}.
\end{equation}
From \leref{le.vererror}, we know that when $\bdelta^{(k)}_t(\bx_t)$ is a $(\mu\varepsilon)/L^2_{\phi}$  approximate solution, then %\MH{note sure; is your $\epsilon_{i,j}$ defined to be $\le\epsilon$? why the last inequality holds?}
\begin{equation}
\sum^h_{i=1}\sum^{d_i}_{j=1}\varepsilon_{ij} =\sum^h_{i=1}\sum^{d_i}_{j=1}\left\|[\nabla \phi(\btheta_t,\bdelta^{(k)}_t(\bx_t);\bx_t)]_{ij}-[\nabla \phi(\btheta_t,(\bdelta^*)^{(k)}_t(\bx_t);\bx_t)]_{ij}\right\|^2\le\varepsilon. \label{eq.maxorl}
\end{equation}

\subsection{{Formal statements of} convergence rate guarantees}


{In what follows, we provide the formal statement of convergence rate of DAT.}
In our analysis, we 
focus on the 1-sided quantization, namely, Step\,10 of Algorithm\,\ref{alg: DAT} is omitted, and
specify the outer minimization oracle  by LAMB  \citep{you2019large}, see Algorithm\,\ref{alg:p1}. The  addition and multiplication operations in LAMB are component-wise. 

% %In section, we give the convergence analysis for DAT, where 
% the outer minimization oracle uses LAMB (a layer-wise stochastic gradient descent method   \citep{you2019large} summarized in Algorithm\,\ref{alg:p1}). The  addition and multiplication operations in LAMB are component-wise. 
% % Next, we provide the formal technical convergence rate results of DAT as follows. \MH{[a formal version of T1.]}

\begin{theorem}\label{th:main}
Under A1-A4, suppose that  $\{\btheta_t\}$ is generated by DAT for a total number of $T$ iterations, 
% the inner maximizer \eqref{eq: inner_max_alg} provides $\varepsilon$-approximate solution 
% {(i.e., the $\ell_2$ norm of inner gradient is upper bounded by $ \varepsilon$)}, 
and let the problem dimension at each layer be $d_i=d/h$. Then the convergence rate of DAT is given by
% where the outer minimization oracle uses LAMB. \MH{Be more specific? What is LAM? Which iteration? We have space here.} 
% When $c_l\le\layerscale(v)\le c_u,\forall v$, $\kappa=c_u/c_l$, and the problem dimension at each layer is $d_i=d/h$,
\begin{align}
\frac{1}{T}\sum^{T}_{t=1}\mathbb{E}\|\nabla_{\btheta} \Psi(\btheta_t)\|^2\le&\frac{\Delta_{\Psi}}{\eta_t c_l CT}+2\left(\varepsilon+\frac{(1+\lambda)\sigma^2}{MB}\right)+4\delta^2 +
\frac{\kappa\sqrt{3}}{C}\|\boldsymbol{\chi}\|_1+\frac{\eta_t c_u\kappa \|L\|_1}{2C}.
\end{align}
where $\Delta_{\Psi}:=\mathbb{E}[\Psi(\btheta_1)]-\Psi^{\star}$, 
{$\eta_t$ is the learning rate, $\kappa=c_u/c_l$, $c_l$ and $c_u$ are constants used in LALR \eqref{eq: ada_learn},}
$\boldsymbol{\chi}$ is an error term  with the $(ih+j)$th entry being $\sqrt{\frac{(1+\lambda)\sigma^2_{ij}}{MB}+\varepsilon_{ij}+\delta^2_{ij}}$,
{$\varepsilon$ and $\varepsilon_{ij}$ were given in \eqref{eq.maxorl},}
$L=[L_1,\ldots,L_h]^T$, $C=\frac{1}{4}\sqrt{\frac{h(1-\beta_2)}{G^2d}}$, $0<\beta_2<1$ is given  in LAMB, $B=\min\{|\mathcal{B}^{(i)}|,\forall i\}$, and {$G$ is given in A1.}
%$\varepsilon$ is the predefined error threshold that the inner maximization oracle needs to achieve. \SL{$G$ is not defined? }
\end{theorem}

% \remark Note that our convergence results is general in the sense that we allow some numerical error caused by solving the inner maximization problem. For DAT shown in \algref{alg: DAT}, it only show the case where $\varepsilon=0$. In practice, when the number of step in the inner loop is large, $\varepsilon$ is $\mathcal{O}(\convacc)$.

% \MH{does $\Psi(\theta)$ has a lower bound? the current $\Delta_{\Psi}$ definition does not make sense, it is not clear if this is a positive and constant number.}

% \MH{"When the batch size is large, i.e., $\mathcal{O}(\sqrt{T})$" what does it mean? should say when batch size = $O(\sqrt{T}$?)} \MH{Should compare this effect with LAMB analysis, and say the similarity/differences.}

\remark When the batch size is large, i.e., $B\sim\sqrt{T}$, then the gradient estimate error will be $\mathcal{O}(\sigma^2/\sqrt{T})$. Further, it is worth noting that, different from the convergence results of LAMB, there is a linear speedup in decreasing the gradient estimate error in DAT with respect to $M$, i.e., $\mathcal{O}(\sigma^2/(M\sqrt{T}))$, which is the advantage of using multiple computing nodes. %resources or distributed processing.

\remark Note that A4 implies $\mathbb{E}[\|Q(\mathbf{g}^{(k)}(\btheta))-\mathbf{g}^{(k)}(\btheta)\|^2]\le \sum^h_{i=1}\sum^{d_i}_{j=1}\delta^2_{ij}:=\delta^2$.
%\SL{[Notation on $i$ and $j$ are confusing.]}.
From \citep[Lemma 3.1]{alistarh2017qsgd}, we know that $\delta^2\le \min\{ d/s^2, \sqrt{d}/s \} G^2$. Recall that $s = 2^b$, where $b$ is the number of quantization bits.
%so when $b\sim\Omega(\log T)$ then $\delta^2\sim\mathcal{O}(1/T)$.

Therefore, with a  proper choice of the parameters, we can have the following convergence result that has been shown in \thref{th:main_simplify}.


% \SL{[We need to make the following Corollary  consistent with \thref{th:main_simplify}.]}
\begin{corollary}
%Under A1-A4, suppose that sequence $\{\btheta_t\}$ is generated by DAT for a total number of $T$ iterations, where $\btheta_t$ is updated by LAMB. When $c_l\le\layerscale(v)\le c_u,\forall v$, $\kappa=c_u/c_l$, $d_i=d/h$, and
Under the same conditions of Theorem\,\ref{th:main}, if we choose
\begin{equation}
    \eta_t\sim\mathcal{O}(1/\sqrt{T}),\quad  \varepsilon\sim\mathcal{O}(\convacc^2),
\end{equation}
 we then have 
 \begin{align}
\frac{1}{T}\sum^{T}_{t=1}\mathbb{E}\|\nabla_{\btheta} \Psi(\btheta_t)\|^2\le\frac{\Delta_{\Psi}}{c_l C\sqrt{T}}+\frac{(1+\lambda)\sigma^2}{MB}
+\frac{ c_u\kappa \|L\|_1}{2C\sqrt{T}}+\mathcal{O}\left(\convacc,\frac{\sigma}{\sqrt{MT}},\min\left\{ \frac{d}{4^b}, \frac{\sqrt{d}}{2^b} \right\}\right).
\end{align}
\end{corollary}

In summary, when the batch size is large enough, DAT converges to a first-order stationary point of problem \eqref{eq: prob_DAT} and there is a linear speed-up in terms of $M$ with respect to $\sigma^2$.  Next, we provide the details of the proof.%This result have been given in \thref{th:main_simplify}.
%\section{Related Work}

\section{Proof Details}
\label{app: analysis}

\subsection{Preliminaries}
In the proof, we use the following inequality and notations.

1. Young’s inequality with parameter $\epsilon$ is
\begin{equation}
\langle \bx,\by\rangle\le\frac{1}{2\epsilon}\|\bx\|^2 + \frac{\epsilon}{2}\|\by\|^2,
\end{equation}
where $\bx,\by$ are two vectors.

2. Define the historical trajectory of the iterates as $\mathcal{F}_t=\{\btheta_{t-1},\ldots,\btheta_1\}$.


3. We denote vector $[\bx]_i$ as the parameters at the $i$th layer of the neural net and $[\bx]_{ij}$ represents the $j$th entry of the parameter at the $i$th layer.

4. We define
\begin{equation}\label{eq.qeq}
    \bgo:= \frac{1}{M}\sum^M_{i=1} \mathbb{E}_{\bx_t\in\mathcal{B}^{(i)}}\left(\lambda\nabla l(\btheta_t;\bx_t)+\nabla_{\btheta}\phi(\btheta_t,\bdelta^{(i)}_t(\bx_t);\bx_t)\right)=\frac{1}{M} \sum_{i=1}^M \mathbf g_t^{(i)}.
\end{equation}


\subsection{Details of LAMB algorithm}
\begin{algorithm}[H]
\caption{{LAMB \citep{you2019large}}}% $(\bW^{(0)},L_{\max},L,\rho,\beta,\delta,\Delta f)$}
\label{alg:p1}
\begin{algorithmic}
\State Input: learning rate $\eta_t$, $0<\beta_1,\beta_2<1$, scaling function $\layerscale(\cdot)$, $\zeta>0$
\For{$t=1,\ldots$}
\State $\bm_t=\beta_1\bm_{t-1}+(1-\beta_1)\bg$, {where $\hat{\mathbf g}_t$ is given by \eqref{eq: grad_agg}}
\State $\bv_t=\beta_2\bv_{t-1}+(1-\beta_2)\bg^2$
\State $\bm_t=\bm_t/(1-\beta^t_1)$
\State $\bv_t=\bv_t/(1-\beta^t_2)$
\State Compute ratio $\bu_t=\frac{\bm_t}{\sqrt{\bv_t}+\zeta}$
\EndFor
\State Update
\begin{equation}
\btheta_{t+1,i}=\btheta_{t,i}-\frac{\eta_t\layerscale(\|\btheta_{t,i}\|)}{\|\bu_{t,i}\|}\bu_{t,i}.\label{eq.upth}
\end{equation}
\end{algorithmic}
\end{algorithm}



%\MH{put this before the theorem statement?}
\subsection{Proof of \leref{le.vererror}}
\begin{proof}
From A2, we have
\begin{equation}
    \left\|\nabla \phi\left(\btheta_t,\bdelta^{(i)}_t(\bx_t);\bx_t\right)-\nabla \phi\left(\btheta_t,(\bdelta^*)^{(i)}_t(\bx_t);\bx_t\right)\right\|\le L_{\phi}\|\bdelta^{(i)}_t(\bx_t)-(\bdelta^*)^{(i)}_t(\bx_t)\|.\label{eq.lipphi}
\end{equation}

Also, we know that function $\phi(\btheta,\bdelta;\bx)$ is strongly concave with respect to $\bdelta$, so we have
\begin{multline}
\mu\|\bdelta^{(i)}_t(\bx_t)-(\bdelta^*)^{(i)}_t(\bx_t)\|^2
\\
\le\left\langle\nabla_{\bdelta}\phi(\btheta_t,(\bdelta^*)^{(i)}_t(\bx_t);\bx_t)-\nabla_{\bdelta}\phi(\btheta_t,\bdelta^{(i)}_t(\bx_t);\bx_t),\bdelta^{(i)}_t(\bx_t)- (\bdelta^*)^{(i)}_t(\bx_t)\right\rangle.\label{eq.stonrc}
\end{multline}

Next, we have two conditions about the qualities of solutions $\bdelta^{(i)}_t(\bx_t)$ and $(\bdelta^*)^{(i)}_t(\bx_t)$.
First, we know that $\bdelta^{(i)}_t(\bx_t)$ is an $\varepsilon$-approximate solution to $(\bdelta^*)^{(i)}_t(\bx_t)$, so we have
\begin{equation}
    \left\langle(\bdelta^*)^{(i)}_t(\bx_t)-\bdelta^{(i)}_t(\bx_t),\nabla_{\bdelta}\phi(\btheta_t,\bdelta^{(i)}_t(\bx_t);\bx_t)\right\rangle\le\varepsilon.
\end{equation}
Second, since $(\bdelta^*)^{(i)}_t(\bx_t)$ is the optimal solution, it satisfies
\begin{equation}
    \left\langle\bdelta^{(i)}_t(\bx_t)-(\bdelta^*)^{(i)}_t(\bx_t),\nabla_{\bdelta}\phi(\btheta_t,(\bdelta^*)^{(i)}_t(\bx_t);\bx_t)\right\rangle\le0.
\end{equation}
Adding them together, we can obtain
\begin{equation}
    \left\langle\bdelta^{(i)}_t(\bx_t)-(\bdelta^*)^{(i)}_t(\bx_t), \nabla_{\bdelta}\phi(\btheta_t,(\bdelta^*)^{(i)}_t(\bx_t);\bx_t)-\nabla_{\bdelta}\phi(\btheta_t,\bdelta^{(i)}_t(\bx_t);\bx_t)\right\rangle\le\varepsilon.\label{eq.optc}
\end{equation}
Substituting \eqref{eq.optc} into \eqref{eq.stonrc}, we can get
\begin{equation}
    \mu\|\bdelta^{(i)}_t(\bx_t)-(\bdelta^*)^{(i)}_t(\bx_t)\|^2\le\varepsilon.
\end{equation}

Combining this with \eqref{eq.lipphi}, we have
\begin{equation}
    \left\|\nabla \phi(\btheta_t,\bdelta^{(i)}_t(\bx_t);\bx_t)-\nabla \phi(\btheta_t,(\bdelta^*)^{(i)}_t(\bx_t);\bx_t)\right\|^2\le L^2_{\phi}\frac{\varepsilon}{\mu}.
\end{equation}
\end{proof}

\subsection{Descent of quantized LAMB}
First, we provide the following lemma as a stepping stone for the subsequent analysis.
\begin{lemma}\label{le.desc}
Under A1--A3,  suppose that sequence $\{\btheta_t\}$ is generated by DAT. Then, we have
\begin{equation}
\mathbb{E}[-\langle\nabla \Psi(\btheta_t),\bg\rangle] \le-\frac{\mathbb{E}\|\nabla\Psi(\btheta_t)\|^2}{2}+\varepsilon+\frac{(1+\lambda)\sigma^2}{MB}\label{eq.deskey}.
\end{equation}
\end{lemma}
\begin{proof}
From \eqref{eq.optdelta}, \eqref{eq.defPhi} and A2, we know that 
\begin{equation}
    \nabla_{\btheta} \Phi_i(\btheta;\bx) = \nabla_{\btheta} \phi(\btheta,(\bdelta^*)^{(i)}(\bx);\bx),
\end{equation}
so we can get
\begin{align}
  \nabla_{\btheta}\Psi(\btheta)=&\frac{1}{M}\sum^M_{i=1}\lambda \nabla_{\btheta}l_i(\btheta)+\nabla_{\btheta}\Phi_i(\btheta) 
  \\
  =& \lambda\nabla_{\btheta} l(\btheta)+\frac{1}{M}\sum^M_{i=1} \mathbb{E}_{\bx\in\mathcal{D}^{(i)}}\nabla_{\btheta}\phi(\btheta,(\bdelta^*)^{(i)}(\bx);\bx)
  \\
  :=& \bar{\mathbf{g}}(\btheta).\label{eq.unb}
\end{align}
%\XC{missing $\lambda$ ?}

Then, we have
\begin{align}
    \mathbb{E}\langle \nabla \Psi(\btheta_t),\bgo\rangle=&\mathbb{E}\langle \nabla \Psi(\btheta_t),\bar{\mathbf{g}}_t\rangle+\mathbb{E}\langle \nabla \Psi(\btheta_t), \bgo-\bar{\mathbf{g}}_t\rangle
    \\
    =&\mathbb{E}_{\mathcal{F}_t}\mathbb{E}_{\bx_t|\mathcal{F}_t}\langle \nabla \Psi(\btheta_t),\bar{\mathbf{g}}_t\rangle+\mathbb{E}\langle \nabla \Psi(\btheta_t), \bgo-\bar{\mathbf{g}}_t\rangle
    \\
    \mathop{=}\limits^{\eqref{eq.unb}}&\mathbb{E}\|\nabla \Psi(\btheta_t)\|^2+\mathbb{E}\langle \nabla \Psi(\btheta_t), \bgo-\bar{\mathbf{g}}_t\rangle
     \\
    =&\mathbb{E}\|\nabla \Psi(\btheta_t)\|^2+\mathbb{E}\langle \nabla \Psi(\btheta_t),\bgo-\bgo^*\rangle+\mathbb{E}\langle \nabla \Psi(\btheta_t), \bgo^*-\bar{\mathbf{g}}_t\rangle
\end{align}
where
\begin{equation}
    \bar{\mathbf{g}}_t:=\frac{1}{M}\sum^M_{i=1} \mathbb{E}_{\bx_t\in\mathcal{D}^{(i)}}\left(\lambda\nabla l(\btheta_t;\bx_t)+\nabla_{\btheta}\phi(\btheta_t,(\bdelta^*)^{(i)}_t(\bx_t);\bx_t)\right)=\lambda\nabla l(\btheta_t)+\nabla \Phi(\btheta_t),\label{eq.defbgt}
\end{equation}
%\XC{$\Phi$ instead of $\Psi$ at the end of equation?}
and 
\begin{equation}
    \bgo^*:=\frac{1}{M}\sum^M_{i=1} \mathbb{E}_{\bx_t\in \mathcal{B}^{(i)}}\left(\lambda\nabla l(\btheta_t;\bx_t)+\nabla_{\btheta}\phi(\btheta_t,(\bdelta^*)^{(i)}_t(\bx_t);\bx_t)\right).\label{eq.defbgtp}
\end{equation}

Next, we can quantify the difference between $\bgo$ and $\bgo^*$ by gradient Lipschitz continuity of function $\phi$ as follows
\begin{equation}
\mathbb{E}\|\bgo-\bgo^*\|^2
\mathop{\le}\limits^{(a)}  \frac{1}{M}\sum^M_{i=1}\mathbb{E}_{\mathcal{F}_t}\mathbb{E}_{\bx_t|\mathcal{F}_t}\left[\|\nabla_{\btheta} \phi(\btheta_t,(\bdelta^*)^{(i)}(\bx_t);\bx_t)- \nabla_{\btheta} \phi(\btheta_t,\bdelta^{(i)}(\bx_t);\bx_t)\|^2\right]
\mathop{\le}\limits^{\eqref{eq.maxorl}} \varepsilon\label{eq.ue}
\end{equation}
where in $(a)$ we use Jensen's inequality.

And the difference between $\bar{\mathbf{g}}_t$ and $\bgo^*$ can be upper bounded by 
\begin{align}\notag
    \mathbb{E}\|\bar{\mathbf{g}}_t-\bgo^*\|^2=&\mathbb{E}_{\mathcal{F}_t}\left\|\frac{1}{M}\sum^M_{i=1}\mathbb{E}_{\bx_t|\mathcal{F}_t}\nabla_{\btheta} \phi(\btheta_t,(\bdelta^*)^{(i)}(\bx_t);\bx_t)- \nabla_{\btheta} \Phi(\btheta_t)\right\|^2
\\
&+\lambda \mathbb{E}_{\mathcal{F}_t}\left\|\frac{1}{M}\sum^M_{i=1}\mathbb{E}_{\bx_t|\mathcal{F}_t}\nabla l(\btheta_t;\bx_t)-\nabla l(\btheta_t)\right\|^2
\\
\mathop{=}\limits^{A3} &\frac{(1+\lambda)\sigma^2}{MB}.\label{eq.bdggp}
\end{align}

Applying Young’s inequality with parameter 2, we have
\begin{align}
\mathbb{E}[-\langle\nabla \Psi(\btheta_t),\bgo\rangle]\le & -\mathbb{E}\|\nabla\Psi(\btheta_t)\|^2+\frac{\mathbb{E}\|\nabla\Psi(\btheta_t)\|^2}{2}+\mathbb{E}\|\bar{\mathbf{g}}_t-\bgo^*\|^2+\mathbb{E}\|\bgo^*-\bgo\|^2
\\
\mathop{\le}\limits^{\eqref{eq.ue}}&-\frac{\mathbb{E}\|\nabla\Psi(\btheta_t)\|^2}{2}+\varepsilon+\frac{(1+\lambda)\sigma^2}{MB}.
\end{align}


\end{proof}

%{\bf Assumption 2}
%We assume that
%\begin{equation}
%\mathbb{E}[\|Q(\widehat{g}(\btheta))-\widehat{g}(\btheta)\|^2]\le \delta^2
%\end{equation}
%or

\subsection{Proof of \thref{th:main}}
\begin{proof}
%Define a virtual sequence:
%\begin{equation}
%r_t=\frac{\widehat{g}_t}{\sqrt{v_t}+\epsilon}
%\end{equation}
We set $\beta_1=0$ in LAMB for simplicity. From gradient Lipschitz continuity, we have
\begin{align}
\Psi(\btheta_{t+1})\mathop{\le}\limits^{A1} &\Psi(\btheta_t)+\sum^h_{i=1}\langle[\nabla_{\btheta} \Psi(\btheta_{t})]_i,\btheta_{t+1,i}-\btheta_{t,i}\rangle+\sum^h_{i=1}\frac{L_i}{2}\|\btheta_{t+1,i}-\btheta_{t,i}\|^2
\\
\mathop{\le}\limits^{(a)} &\Psi(\btheta_t)\underbrace{-\eta_t\sum^h_{i=1}\sum^{d_i}_{j=1}\layerscale(\|\btheta_{t,i}\|)\left\langle[\nabla \Psi(\btheta_{t})]_{ij},\frac{[\bu_{t}]_{ij}}{\|\bu_{t,i}\|}\right\rangle}_{:=\boldsymbol{\mathcal{R}}}+\sum^h_{i=1}\frac{\eta^2_t c^2_u L_i}{2},\label{eq.Phibd}
\end{align}
where in $(a)$ we use \eqref{eq.upth}, and the upper bound of $\layerscale(\|\btheta_{t,i}\|)$.


Next, we split term $\boldsymbol{\mathcal{R}}$ into two parts by leveraging $\textrm{sign}([\nabla\Psi(\btheta_t)]_{ij})$ and $\textrm{sign}([\bu_{t}]_{ij})$ as follows.
\begin{align}\notag
\boldsymbol{\mathcal{R}}=&-\eta_t\sum^h_{i=1}\sum^{d_i}_{j=1}\layerscale(\|\btheta_{t,i}\|)[\nabla \Psi(\btheta_t)]_{ij}\frac{[\bu_{t}]_{ij}}{\|\bu_{t,i}\|}\mathbbm{1}\left(\textrm{sign}([\nabla \Psi(\btheta_t)]_{ij})=\textrm{sign}([\bu_{t}]_{ij})\right)
\\
&-\eta_t\sum^h_{i=1}\sum^{d_i}_{j=1}\layerscale(\|\btheta_{t,i}\|)[\nabla \Psi(\btheta_t)]_{ij}\frac{[\bu_{t}]_{ij}}{\|\bu_{t,i}\|}\mathbbm{1}\left(\textrm{sign}([\nabla \Psi(\btheta_t)]_{ij})\neq\textrm{sign}([\bu_{t}]_{ij})\right)
\\\notag
\mathop{\le}\limits^{(a)}&-\eta_t c_l\sum^h_{i=1}\sum^{d_i}_{j=1}\sqrt{\frac{1-\beta_2}{G^2d_i}}[\nabla \Psi(\btheta_t)]_{ij} [\bg]_{ij}\mathbbm{1}\left(\textrm{sign}([\nabla \Psi(\btheta_t)]_{ij})=\textrm{sign}([\bg]_{ij})\right)
\\
&-\eta_t\sum^h_{i=1}\sum^{d_i}_{j=1}\layerscale(\|\btheta_{t,i}\|)[\nabla \Psi(\btheta_t)]_{ij}\frac{[\bu_{t}]_{ij}}{\|\bu_{t,i}\|}\mathbbm{1}\left(\textrm{sign}([\nabla \Psi(\btheta_t)]_{ij})\neq\textrm{sign}([\bu_t]_{ij})\right)
\\\notag
\mathop{\le}\limits^{(b)}&-\eta_tc_l\sum^h_{i=1}\sum^{d_i}_{j=1}\sqrt{\frac{1-\beta_2}{G^2d_i}}[\nabla \Psi(\btheta_t)]_{ij}  [\bg]_{ij}
\\
&-\eta_t\sum^h_{i=1}\sum^{d_i}_{j=1}\layerscale(\|\btheta_{t,i}\|)[\nabla \Psi(\btheta_t)]_{ij}\frac{[\bu_{t}]_{ij}}{\|\bu_{t,i}\|}\mathbbm{1}\left(\textrm{sign}([\nabla \Psi(\btheta_t)]_{ij})\neq\textrm{sign}([\bu_{t}]_{ij})\right).\label{eq.t1bd}
\end{align}
where in $(a)$ we use the fact that $\|\bu_{t,i}\|\le \sqrt{\frac{d_i}{1-\beta_2}}$ and $\sqrt{\bv_t}\le G$, and in $(b)$ we add 
\begin{equation}
-\eta_t c_l\sum^h_{i=1}\sum^{d_i}_{j=1}\sqrt{\frac{1-\beta_2}{G^2d_i}}[\nabla \Psi(\btheta_t)]_{ij}[\bg]_{ij}\mathbbm{1}\left(\textrm{sign}([\nabla \Psi(\btheta_t)]_{ij})\neq\textrm{sign}([\bg]_{ij})\right)\ge0.
\end{equation}

Taking expectation on both sides of \eqref{eq.t1bd}, we have the following:
\begin{align}\notag
\mathbb{E}[\boldsymbol{\mathcal{R}}]\le&\underbrace{-\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\sum^h_{i=1}\sum^{d_i}_{j=1}\mathbb{E}[[\nabla \Psi(\btheta_t)]_{ij}[\bg]_{ij}]}_{:=\boldsymbol{\mathcal{U}}}
\\
&+\underbrace{\eta_t c_u\sum^h_{i=1}\sum^{d_i}_{j=1}\mathbb{E}\left[[\nabla \Psi(\btheta_t)]_{ij}\mathbbm{1}\left(\textrm{sign}([\nabla \Psi(\btheta_t)]_{ij})\neq\textrm{sign}([\bu_t]_{ij})\right)\right]}_{:=\boldsymbol{\mathcal{V}}}.
\end{align}

Next, we will get the upper bounds of $\boldsymbol{\mathcal{U}}$ and $\boldsymbol{\mathcal{V}}$ separately as follows. First, we write the inner product between $[\nabla \Psi(\btheta_t)]_{ij}$ and $[\bg]_{ij}$ more compactly,
\begin{align}
\boldsymbol{\mathcal{U}}\le&-\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\sum^h_{i=1}\mathbb{E}\left\langle [\nabla \Psi(\btheta)]_i, [\bg]_i\right\rangle
\\
\le&-\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\sum^h_{i=1}\mathbb{E}\left\langle [\nabla \Psi(\btheta_t)]_i,[\bg]_i-[\bgo]_i+[\bgo]_i\right\rangle
\\
\le&-\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}
\left(\mathbb{E}\left\langle\nabla\Psi(\btheta),\bgo\right\rangle+\sum^h_{i=1}\mathbb{E}\left\langle [\nabla \Psi(\btheta_t)]_i,[\bg]_i-[\bgo]_i\right\rangle\right).
\end{align}

Applying \leref{le.desc}, we can get
\begin{align}\notag
\boldsymbol{\mathcal{U}}\mathop{\le}\limits^{\eqref{eq.deskey}} &-\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\frac{1}{2}\mathbb{E}\|\nabla \Psi(\btheta_t)\|^2
+\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\left(\varepsilon+\frac{(1+\lambda)\sigma^2}{MB}\right)
\\
&-\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\sum^h_{i=1}\mathbb{E}\left\langle [\nabla \Psi(\btheta_t)]_i,[\bg]_i -[\bgo]_i\right\rangle
\\\notag
\mathop{\le}\limits^{(a)}&-\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\frac{1}{2}\mathbb{E}\|\nabla \Psi(\btheta_t)\|^2
+\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\left(\varepsilon+\frac{(1+\lambda)\sigma^2}{MB}\right)
\\
&+\frac{\eta_t c_l}{4}\sqrt{\frac{h(1-\beta_2)}{G^2d}}\mathbb{E}\| \nabla \Psi(\btheta_t)\|^2
+c_l\eta_t\sqrt{\frac{h(1-\beta_2)}{G^2d}}\mathbb{E}\|\bg-\bgo\|^2
\\\notag
\mathop{\le}\limits^{(b)} &-\frac{\eta_t c_l}{4}\sqrt{\frac{h(1-\beta_2)}{G^2d}}\mathbb{E}\|\nabla \Psi(\btheta_t)\|^2+\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\left(\varepsilon+\frac{(1+\lambda)\sigma^2}{MB}\right)
\\
&+\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\delta^2\label{eq.ubd}
\end{align}
where in $(a)$ we use Young's inequality (with parameter $2$), and in $(b)$ we have
\begin{equation}
    \mathbb{E}\|\bg-\bgo\|^2=\mathbb{E}\left\|\frac{1}{M}\sum^M_{i=1}Q(\bgo^{(i)})-\bgo^{(i)}\right\|^2\mathop{\le}\limits^{A4}\delta^2.\label{eq.bdq}
\end{equation}

Second, we give the upper bound of $\boldsymbol{\mathcal{V}}$:
\begin{align}\label{eq.vbd}
\boldsymbol{\mathcal{V}}\le&\eta_t c_u\sum^h_{i=1}\sum^{d_i}_{j=1}[\nabla \Psi(\btheta_t)]_{ij}\underbrace{\mathbbm{P}\left(\textrm{sign}([\nabla \Psi(\btheta_t)]_{ij})\neq\textrm{sign}([\bg]_{ij})\right)}_{:=\boldsymbol{\mathcal{W}}}
\end{align}
where the upper bound of $\boldsymbol{\mathcal{W}}$ can be quantified by using Markov’s inequality followed by Jensen’s inequality as the following:
\begin{align}
\notag
\boldsymbol{\mathcal{W}}=&\mathbbm{P}\left(\textrm{sign}([\nabla \Psi(\btheta_t)]_{ij})\neq\textrm{sign}([\bg]_{ij})\right)
\\
\le&\mathbbm{P}[|[\nabla \Psi(\btheta_t)]_{ij}-[\bg]_{ij}|\ge|[\nabla \Psi(\btheta_t)]_{ij}|]
\\
\le&\frac{\mathbb{E}[|[\nabla \Psi(\btheta_t)]_{ij}-[\bg]_{ij}|]}{|[\nabla \Psi(\btheta_t)]_{ij}|}
\\
\le&\frac{\sqrt{\mathbb{E}[([\nabla \Psi(\btheta_t)]_{ij}-[\bg]_{ij})^2]}}{|[\nabla \Psi(\btheta_t)]_{ij}|}
\\
\mathop{\le}\limits^{\eqref{eq.unb}}&\frac{\sqrt{\mathbb{E}[([\bar{\mathbf{g}}_t]_{ij}-[\bgo^*]_{ij} + [\bgo^*]_{ij}- [\bgo]_{ij}+[\bgo]_{ij}-[\bg]_{ij})^2]}}{|[\nabla \Psi(\btheta_t)]_{ij}|}
\\
\mathop{\le}\limits^{(a)}&\sqrt{3}\frac{\sqrt{\frac{(1+\lambda)\sigma^2_{ij}}{MB}+\varepsilon_{ij}+\delta^2_{ij}}}{|[\nabla \Psi(\btheta_t)]_{ij}|} \label{eq.bdw}
\end{align}
where $(a)$ is true due to the following relations:
\emph{i}) from \eqref{eq.bdggp}, we have
\begin{equation}\label{eq.key}
 \mathbb{E}[([\bar{\mathbf{g}}_t]_{ij}-[\bgo^*]_{ij})^2] \le\frac{(1+\lambda)\sigma^2_{ij}}{MB};
\end{equation}
\emph{ii}) from \eqref{eq.ue}, we can get
\begin{equation}
  \mathbb{E}[([\bgo]_{ij}-[\bgo^*]_{ij})^2]\le \varepsilon_{ij};
\end{equation}
and \emph{iii}) from \eqref{eq.bdq}, we know
\begin{equation}
    \mathbb{E}[([\bg]_{ij}-[\bgo]_{ij})^2]\le\delta^2_{ij}.
\end{equation}



Therefore, combining \eqref{eq.Phibd} with the upper bound of $\boldsymbol{\mathcal{U}}$ shown in \eqref{eq.ubd} and $\boldsymbol{\mathcal{V}}$ shown in \eqref{eq.vbd}\eqref{eq.bdw}, we have
\begin{align}\notag
\mathbb{E}[\Psi(\btheta_{t+1})]\le& \mathbb{E}[\Psi(\btheta_t)]-\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\frac{1}{4}\mathbb{E}\|\nabla \Psi(\btheta_t)\|^2 +\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\left(\varepsilon+\frac{(1+\lambda)\sigma^2}{MB}\right)
\\
&+\eta_t c_l\sqrt{\frac{h(1-\beta_2)}{G^2d}}\delta^2+\eta_t c_u\sqrt{3}\sum^h_{i=1}\sum^{d_i}_{j=1}\sqrt{\frac{(1+\lambda)\sigma^2_{ij}}{MB}+\varepsilon_{ij}+\delta^2_{ij}}+\frac{\eta^2_t c^2_u \sum^h_{i=1}L_i}{2}.
\end{align}

Note that the error vector $\boldsymbol{\chi}$ is defined as the following
\begin{equation}
    \boldsymbol{\chi}=\left[\begin{matrix}\sqrt{\frac{(1+\lambda)\sigma^2_{11}}{M|\mathcal{B}|}+\varepsilon_{11}+\delta^2_{11}} \\
    \vdots \\
    \sqrt{\frac{(1+\lambda)\sigma^2_{ij}}{M|\mathcal{B}|}+\varepsilon_{ij}+\delta^2_{ij}}
    \\
    \vdots \\
    \sqrt{\frac{(1+\lambda)\sigma^2_{hd_h}}{M|\mathcal{B}|}+\varepsilon_{hd_h}+\delta^2_{hd_h}}
    \end{matrix}\right]\in\mathbb{R}^{d},
\end{equation}
and we have
\begin{equation}
    L=\left[\begin{matrix}L_1\\ \vdots \\ L_h\end{matrix}\right]\in\mathbb{R}^h.
\end{equation}


Recall
\begin{equation}
\kappa=\frac{c_u}{c_l}.
\end{equation}
Rearranging the terms, we can arrive at
\begin{align}\notag
\underbrace{\sqrt{\frac{h(1-\beta_2)}{G^2d}}\frac{1}{4}}_{:=C}\left(\|\nabla \Psi(\btheta_t)\|^2\right)\le & \frac{\mathbb{E}[\Psi(\btheta_t)]-\mathbb{E}[\Psi(\btheta_{t+1})]}{\eta_t c_l}+ 4C\delta^2+ 2C\left(\varepsilon+\frac{(1+\lambda)\sigma^2}{MB}\right)
\\
&+\sqrt{3}\kappa\|\boldsymbol{\chi}\|_1+\frac{\eta_t c_u\kappa \|L\|_1}{2}.
\end{align}

Applying the telescoping  sum over $t=1,\ldots,T$, we have
\begin{align}\notag
\frac{1}{T}\sum^{T}_{t=1}\mathbb{E}\|\nabla_{\btheta} \Psi(\btheta_t)\|^2\le&\frac{\mathbb{E}[\Psi(\btheta_1)]-\mathbb{E}[\Psi(\btheta_{T+1})]}{\eta_t c_l CT}+2\left(\varepsilon+\frac{(1+\lambda)\sigma^2}{MB}\right)+4\delta^2
\\
&+\frac{\kappa\sqrt{3}}{C}\|\boldsymbol{\chi}\|_1+\frac{\eta_t c_u\kappa \|L\|_1}{2C}.
\end{align}
%When batch size is large, then $\|\sigma\|$ is small. So when we choose $\eta_t=\sqrt{1/T}$, we will have
%\begin{equation}
%\frac{1}{T}\sum^{\top}_{t=1}\mathbb{E}\|\nabla_{\btheta} \Psi(\btheta_t)\|^2\le\mathcal{O}\left(\frac{1}{\sqrt{T}}\right)+\mathcal{O}(\delta^2+\delta).
%\end{equation}
%Also, when the levels of the quantization are large, then $\delta^2$ is also small.
\end{proof}




\section{Additional Experiments}\label{app: sect}
% \subsection{Discussion on cyclic learning rate}
%  \label{app: CLR}
%  \SL{[Check if you still need it if Figure 1 is presented.]}
% It was shown in \citep{Wong2020Fast} that the use of a cyclic learning rate (CLR) trick can further accelerate  the Fast  AT algorithm in the small-batch   setting \citep{Wong2020Fast}. In Figure\,\ref{fig: cyclic_lr_batch_size}, we present the performance of Fast AT with CLR versus batch sizes.
% We observe that when CLR meets the large-batch setting, it becomes significantly worse than its performance in the small-batch setting. The   reason is that  CLR requires a certain number of iterations
% to proceed with the cyclic schedule. However, the use of large data batch  only results in  a small amount of iterations by fixing   the number of epochs. 
% % However, such a trick   becomes less effective
% % when the batch size becomes   larger (namely, the number of iterations gets smaller); see \SL{Appendix\,xxx}. Meanwhile,  the sensitivity of  adversarially model training to   step size  can be mitigated 
% % by using   early-stop  remedy  due to the existence of   robust overfitting  %in adversarially model training  
% % \citep{rice2020overfitting}. Spurred by that, we   use  the standard piecewise decay step size   and an early-stop strategy during robust training.

%  \begin{figure}[htb]
%     \vspace*{-0.0in}
% \centerline{
% \begin{tabular}{c}
% \includegraphics[width=.5\textwidth,height=!]{Figures/app.pdf} 
% %&
% %\includegraphics[width=.5\textwidth,height=!]{Figures/pltfigure21.pdf}
% %\\
% %\footnotesize{(a)} &   \footnotesize{(b)}
% \end{tabular}}
% \caption{\footnotesize{
% TA/RA of Fast AT with CLR versus batch sizes.
% }}
%   \label{fig: cyclic_lr_batch_size}
%   \vspace*{-0.00in}
% \end{figure}


\subsection{Training details}
\label{app: train_setting}
% CIFAR-10 AT and Fast AT experiments are conducted at a single computing node with 16-core CPU, 128GB RAM and 1 Nvidia P100 GPU. The training epoch is $100$ by calling for the momentum SGD optimizer. The weight decay and momentum parameters are set to $0.0005$ and $0.9$. 
% The initial learning rate  is set with
% $0.05$ (tuned over $\{0.005, 0.01, 0.05, 0.1\}$), which is
% %\{$0.01$, $0.05$\} 
% decayed by $\times 1/10$ at the training epoch $70$, $85$ and $95$, respectively.

% CIFAR-10 DAT experiments are conducted at $\{1, 6, 12, 18\}$ computing nodes with 16-core CPU, 128GB RAM and 1 Nvidia P100 GPU. The training epoch is $100$ by calling for the LAMB optimizer. The weight decay is set to $0.0005$.  $\beta_1$ and $\beta_2$ are set to $0.9$ and $0.999$. The initial learning rate  is tuned over $\{0.01, 0.05, 0.1, 0.2 \}$,
% %\{$0.01$, $0.05$, $0.1$, $0.2$\} 
% which is decayed by $\times 1/10$ at the training epoch $70$, $85$ and $95$, respectively.
% To execute algorithms with the initial learning rate $\eta_1$ greater than $0.1$, we choose the   model weights after $10$-epoch warm-up as its initialization for DAT, where each warm-up epoch $k$ uses the linearly increased learning rate $(k/10)\eta_1$. 

% If the learning rate is larger than $0.1$, we use 10 epochs as warm-up.
% \SL{[I did not follow the last sentence.]}

ImageNet AT and Fast AT experiments are conducted at a single computing node with dual 22-core CPU, 512GB RAM and 6 Nvidia V100 GPUs. The training epoch is $30$ by calling for the momentum SGD optimizer. The weight decay and momentum parameters are set to $0.0001$ and $0.9$. The initial learning rate is set to $0.1$ (tuned over $\{0.01, 0.05, 0.1, 0.2\}$), which is
decayed by $\times 1/10$ at the training epoch $20, 25, 28$, respectively.

ImageNet DAT experiments are conducted at % \{1, 3, 6\}
$\{ 1,3,6\}$
computing nodes with dual 22-core CPU, 512GB RAM and 6 Nvidia V100 GPUs. The training epoch is $30$ by calling for the  LAMB optimizer. The weight decay is set to $0.0001$. $\beta_1$ and $\beta_2$ are set to $0.9$ and $0.999$. The initial learning rate $\eta_1$  is tuned over \{$0.01$, $0.05$, $0.1$, $0.2$, $0.4$\}, which is decayed by $\times 1/10$ at the training epoch $20, 25, 28$, respectively. To execute algorithms with the initial learning rate $\eta_1$ greater than $0.2$, we choose the   model weights after $5$-epoch warm-up as its initialization for DAT, where each warm-up epoch $k$ uses the linearly increased learning rate $(k/5)\eta_1$. 

\subsection{Additional results}
\label{app: add_results}

\paragraph{Discussion on cyclic learning rate.}
 %\label{app: CLR}
It was shown by \citet{Wong2020Fast} that the use of a cyclic learning rate (CLR) trick can further accelerate the Fast AT algorithm in the small-batch setting. In Figure\,\ref{fig: cyclic_lr_batch_size}, we present the performance of Fast AT with CLR versus batch sizes.
We observe that when CLR is used in the large-batch setting, its performance becomes significantly worse than in the small-batch setting. The reason is that CLR requires a certain number of iterations
to proceed with the cyclic schedule. However, for a fixed number of epochs, the use of a large data batch results in only a small number of iterations.
% However, such a trick   becomes less effective
% when the batch size becomes   larger (namely, the number of iterations gets smaller); see \SL{Appendix\,xxx}. Meanwhile,  the sensitivity of  adversarially model training to   step size  can be mitigated 
% by using   early-stop  remedy  due to the existence of   robust overfitting  %in adversarially model training  
% \citep{rice2020overfitting}. Spurred by that, we   use  the standard piecewise decay step size   and an early-stop strategy during robust training.

 \begin{figure}[htb]
    \vspace*{-0.0in}
\centerline{
\begin{tabular}{c}
\includegraphics[width=.5\textwidth,height=!]{Figures/app.pdf} 
%&
%\includegraphics[width=.5\textwidth,height=!]{Figures/pltfigure21.pdf}
%\\
%\footnotesize{(a)} &   \footnotesize{(b)}
\end{tabular}}
\caption{\footnotesize{
TA/RA of Fast AT with CLR versus batch sizes on (CIFAR-10, ResNet-18).
}}
  \label{fig: cyclic_lr_batch_size}
  \vspace*{-0.00in}
\end{figure}

\mycomment{
\paragraph{{Empirical model convergence.}}
In Figure\,\ref{fig: loss_supp}, we present the training accuracy and the loss value of  DAT-PGD. As we can see, our proposal converges well within $100$ and $30$ epochs in the setting of 
(CIFAR-10, ResNet-18) and (ImageNet, ResNet-50), respectively.

\begin{figure}[htb]
    \vspace*{-0.0in}
\centerline{
\begin{tabular}{cc}
\includegraphics[width=.45\textwidth,height=!]{Figures/loss-cifar.pdf}  &
\includegraphics[width=.45\textwidth,height=!]{Figures/loss-imagenet.pdf}
\\
\footnotesize{(a)  CIFAR-10, ResNet-18} &   \footnotesize{(b) ImageNet, ResNet-50}
\end{tabular}}
\caption{\footnotesize{Training accuracy and objective value (loss) of   DAT-PGD against training epochs.
(a) DAT-PGD for (CIFAR-10, ResNet-18) using $6 \times 1$ computing configuration and $6 \times 2048$ batch size. (b) DAT-PGD for (ImageNet, ResNet-50) using $6 \times 6$ computing configuration and $6 \times 512$ batch size. 
% following the setting in Table\,\ref{table: overall}.
%   %using $6 \times 1$ computing resources and $2048 \times 6$ batch size.
% (Left) Training accuracy and loss against training epochs  for CIFAR-10, ResNet-18. (Right) Training accuracy and loss against training epochs  for ImageNet, ResNet-50.
%for training CIFAR-10 on ResNet-18 with different numbers of computing nodes. The batch size of each node is 2048 so the total batch size will be $(\text{\# of nodes}) \times 2048$. 
%(a) Fine-tuning over CIFAR-10. (b):  Fine-tuning over CIFAR-100.
% pre-trained xxx \SL{[model name]} over dataset $\mathcal A$ using DAT. Left: RA against PGD attacks of different perturbation sizes during testing. Right:  RA against PGD attacks of different steps during testing.
}}
  \label{fig: loss_supp}
  \vspace*{-0.00in}
\end{figure}


%\begin{wraptable}{r}{80mm}
%\vspace*{-0.5in}
\paragraph{{Tuning LALR hyperparameter $c_u$.}}
{We also evaluate the sensitivity of the performance of DAT to the choice of the hyperparameter $c_u$ in LALR.  In Table\,\ref{table: LALR_hyper_cu}, we fix $c_l = 0$ (this is a natural choice) but vary $c_u \in \{ 8, 9, 10, 11, 12 \}$ when DAT-FGSM is executed under CIFAR-10 using $18 \times 2048$ batch size, where $c_u = 10$ is our default choice.  As we can see, both RA and TA are not very sensitive to $c_u$, and the default choice yields the RA-best model (albeit with a minor improvement).
}


\begin{table}[htb]
\begin{center}
\caption{\footnotesize{
TA/RA of DAT-FGSM under   (CIFAR-10, ResNet-18) using $18 \times 2048$ batch size versus different choices of   $c_u$.
}
} 
\label{table: LALR_hyper_cu}
\begin{threeparttable}
\resizebox{0.3\textwidth}{!}{
\begin{tabular}{c|c|c}
\hline
\hline
Value of $ c_u $  &  TA (\%) & RA  (\%)
 \\ \hline
$ c_u = 8 $  
&  73.57    & 38.19   \\
$ c_u = 9 $  
& 73.72 & 38.00  \\
$ c_u = 10 $ 
& 73.42 &  38.55  \\
$ c_u = 11 $  
& 73.75 & 38.18 
  \\
$ c_u = 12 $  
& 73.63  &  37.87
 \\
\hline
\hline 
\end{tabular}}
\end{threeparttable}
\end{center}
\vspace{-3mm}
%\end{wraptable}
\end{table}
}

\paragraph{Additional details on HPC setups.}
To further reduce communication cost, we also conduct DAT on an HPC cluster. The computing nodes of the cluster are connected with InfiniBand (IB) and a PCIe Gen4 switch. To compare with results in Table\,\ref{table: overall}, we use 6 of the 57 nodes of the cluster. Each node has 6 Nvidia V100 GPUs, which are interconnected with NVLink. We use Nvidia NCCL as the communication backend. In Table\,\ref{table: quadtization_imagenet}, we have presented the performance of DAT for (ImageNet, ResNet-50) with the use of HPC compared to a standard (non-HPC) distributed system.

\iffalse
%\begin{wraptable}{r}{80mm}
\begin{table}[htb]
%\vspace*{-7mm}
\begin{center}
\caption{\small{DAT  in semi-supervised learning  under ResNet-18  with batch size $128$. 
%The relative improvement over RA  or TA obtained in    supervised learning (CIFAR-10 only) is marked by \textcolor{red}{red} color.
%\SL{RA}.
%\Gaoyuan{unlabeled data from tiny imagenet. obtained from } 
}
%Unlabeled data,  GPUs per node $\times$ nodes: $1\times 6$, batch size $12288$
} 
% \vspace{-2.5mm}
\label{table: unlabel_app}
\begin{threeparttable}
\resizebox{0.55\textwidth}{!}{
\begin{tabular}{c|c|c|c|c|c}
\hline
\hline
\multirow{2}{*}{\begin{tabular}[c]{@{}c@{}}Method\end{tabular}} & \multicolumn{5}{c}{\textbf{CIFAR-10 + 500K unlabeled Tiny Images,   ResNet-18}} \\ 
\cline{2-6}  
%& \begin{tabular}[c]{@{}c@{}}GPUs\end{tabular}
%& Batch size
& TA (\%) & RA (\%) & AA (\%) & \begin{tabular}[c]{@{}c@{}}Comm. %\\per epoch (s)
\end{tabular} & \begin{tabular}[c]{@{}c@{}}Tr. time (s)
%\\per epoch (s)
\end{tabular}
 \\ \hline
% DAT (CIFAR-10 only) 
% %& $1\times 6$  & 12288 
% & 79.77 & 38.93 & 8.5 & 42  \\
% Fast DAT (CIFAR-10 only) 
% %& $1\times 6$  & 12288 
%  & 75.58 & 40.91 & 8.5  & 14  \\
DAT-PGD  %& $1\times 6$  & 12288 
& 90.21 % (\textcolor{red}{$\uparrow$ 7.62})
& 55.89 % (\textcolor{red}{$\uparrow$ 8.40}) 
& 45.23 %
& 0  & 1266\\
DAT-FGSM 
%& $1\times 6$  & 12288 
& 90.73 % (\textcolor{red}{$\uparrow$ 12.42}) 
& 52.39 % (\textcolor{red}{$\uparrow$ 4.92})
& 43.19%
& 0  & 553\\
\hline
\hline
\end{tabular}
}
\end{threeparttable}
\end{center}
%\vspace*{-5mm}
%\end{wraptable}
\end{table}
\fi 




\mycomment{
\subsection{Robust accuracy (RA) versus number of computing nodes}
\label{app:RA_nodes}
\SL{Figure\,\ref{fig: acc_nodes} presents xxx} 

\begin{figure}[htb]
    \vspace*{-0.0in}
\centerline{
\begin{tabular}{cc}
\includegraphics[width=.5\textwidth,height=!]{Figures/pltfigure4.pdf} 
%\\
%\footnotesize{(a)} &   \footnotesize{(b)}
\end{tabular}}
\caption{\footnotesize{RA for training CIFAR-10 on ResNet-18 with different numbers of computing nodes. The batch size of each node is 2048 so the total batch size will be $(\text{\# of nodes}) \times 2048$. \SL{[updated?]}
%(a) Fine-tuning over CIFAR-10. (b):  Fine-tuning over CIFAR-100.
% pre-trained xxx \SL{[model name]} over dataset $\mathcal A$ using DAT. Left: RA against PGD attacks of different perturbation sizes during testing. Right:  RA against PGD attacks of different steps during testing.
}}
  \label{fig: acc_nodes}
  \vspace*{-0.00in}
\end{figure}




}


\iffalse
\subsection{Experiment results on CIFAR-10}
\label{app: all_cifar-10}

\subsubsection{Overall performance}
\label{app: overall-cifar10}


In Table\,\ref{table: overall_supplement}, we observe that 
in the large-batch setting, the proposed DAT-PGD and DAT-FGSM algorithms outperform   the baseline algorithms, and result in competitive performance to centralized AT and Fast AT, which call for more iterations by using a smaller batch size.


%\begin{wraptable}{r}{95mm}
% \vspace{-9mm}
\begin{table}[htb]
\begin{center}
\caption{\small{Overall performance of DAT  (in gray color), compared with baselines, in TA (\%), RA (\%), communication time per epoch (seconds), and total training time (including communication time)  per epoch (seconds). For brevity, `$p \times q$' represents `\# nodes $\times$  \# GPUs per node', 
`Comm.' represents communication cost, and `Tr. Time' represents training time.
\mycomment{In the columns `TA' and `RA', we present the relative \increase{improvement} (\%)  or \decrease{degradation} (\%) upon   the performance of AT (first row). In the columns `Comm.' and `Tr. time', we {highlight} the \high{worst} communication and training time used in  \textit{distributed} setting.} 
%SL{we need to show relative improvement compared to AT and Fast AT, like fine-tuning example.}
}
%All the training methods are conducted in \SL{xxx [@Gaoyuan]} epochs .
} 
% \vspace{-2.5mm}
\label{table: overall_supplement}
\begin{threeparttable}
\resizebox{0.68\textwidth}{!}{
\begin{tabular}{c|c|c|c|c|c|c}
\hline
\hline
\multirow{2}{*}{\begin{tabular}[c]{@{}c@{}}Method\end{tabular}} & \multicolumn{6}{c}{\textbf{CIFAR-10, ResNet-18}} \\ 
\cline{2-7}  & \begin{tabular}[c]{@{}c@{}}
$p \times q$
%Nodes $\times$\\GPUs per node
\end{tabular}
& Batch size
& TA (\%) & RA (\%) & \begin{tabular}[c]{@{}c@{}}Comm. 
%\\per epoch (s)
\end{tabular} & \begin{tabular}[c]{@{}c@{}}
Tr. time
%Training time\\per epoch (s)
\end{tabular}
 \\ \hline
AT & $1\times 1$ & $2048$ & 82.94 & 38.54 & NA & {218}  \\
{Fast AT} & $1\times 1$  & $2048$ & 81.58 % \decrease{(1.36)}
& 38.34  %\decrease{(0.20)}
& NA & 52  \\
\mycomment{ 
{DAT-PGD w/o LALR} & $6\times 1$  & {$2048$} & 82.02 \decrease{(0.92)}  & 38.18 \decrease{(0.36)} & \high{40.9}  & \high{87}\\
}
\mycomment{
DAT-PGD w/o LALR & $6\times 1$  & $6 \times 2048$ &  69.30  \decrease{(13.64)} %% 74.29
& 33.86 \decrease{(4.68)} & 8.5  & 42\\}
DAT-PGD w/o LALR & $18\times 1$  & $18 \times 2048$ &  55.59   %% 74.29
& 26.83    & 3.4  & 22\\
\mycomment{
DAT-FGSM w/o LALR & $6\times 1$  & $6 \times 2048$ & 64.46 \decrease{(18.48)} & 33.96 \decrease{(4.58)}  & 8.5  & 14\\}
DAT-FGSM w/o LALR & $18\times 1$  & $18 \times 2048$ & 52.35   & 28.90 
& 3.1  & 8\\
{DAT-LSGD} & $18\times 1$  &   $18 \times 2048$ & 64.15  & 34.12  
& 3.2 & 22  \\
\mycomment{\rowcolor{Gray}
{DAT-PGD} & $6\times 1$  & $6 \times 2048$ & 80.38 \decrease{(2.56)}  & 38.94 \increase{(0.40)}  & 8.5  & 42\\}
\rowcolor{Gray}
{DAT-PGD} & $18\times 1$  & $18 \times 2048$ & 80.28   & 38.44
& 3.4  & 22\\
\mycomment{\rowcolor{Gray}
{DAT-FGSM} & $6\times 1$  & $6 \times 2048$ & 75.58   \decrease{(7.36)} & 40.92
\increase{(2.38)}
& 8.5  & 14\\}
\rowcolor{Gray}
{DAT-FGSM} & $18\times 1$  & $18 \times 2048$ & 73.42  & 38.55
& 3.1  & 8\\
\rowcolor{Gray}
{DAT-FGSM} & $24\times 1$  & $24 \times 2048$ & 72.76 & 39.82
& 2.0  & 5\\
% \hline
%   & \multicolumn{6}{c}{\textbf{CIFAR-10, ResNet-50} \SL{Appendix}} 
% %   \\ 
% % \cline{2-7}  &  \begin{tabular}[c]{@{}c@{}}GPUs per node\\ $\times$ nodes\end{tabular}
% % & Batch size
% % & TA (\%) & RA (\%) & \begin{tabular}[c]{@{}c@{}}Communication\\per epoch (s)\end{tabular} & \begin{tabular}[c]{@{}c@{}}Computation\\per epoch (s)\end{tabular}
%  \\ \hline
% AT & $1\times 1$  & $256$ &85.94   & 43.06 & NA &  894  \\
% {Fast AT} &$1\times 1$ & $256$ & 75.28 & 40.48 & NA & 288  \\
% DAT-PGD w/o LALR & $6\times 1$ & $6 \times 256$ & 74.45 & 33.35 & 68  & 236\\
% \rowcolor{Gray}
% DAT-PGD & $6\times 1$  & $6 \times 256$ & 84.79 & 42.16 & 68  & 236\\
% \rowcolor{Gray}
% DAT-FGSM & $6\times 1$  & $6 \times 256$ & 75.72 & 40.09 & 68  & 116\\
% \hline
%   & \multicolumn{6}{c}{\textbf{ImageNet, ResNet-50}} \\
%   \hline
%   AT & $1\times 6$  & $512$ & 62.70 & 40.38 & NA & 6022  \\
% {Fast AT} & $1\times 6$  & $512$ & 58.99  % \decrease{(3.71)}
% & 40.78 % \increase{(0.4)}
% & NA & 1544  \\
% \mycomment{{DAT-PGD w/o LALR} & $6\times 6$  & $512$ & 62.36 %\decrease{(0.34)}
% %65.06 %\SL{check?} 
% &  %39.28
% 39.86 % \decrease{(0.52)}
% & \high{4324}  & \high{5663}\\}
% DAT-PGD w/o LALR & $6\times 6$  & $6 \times 512$ & 
% 57.09   %\decrease{(5.61)}
% %60.09 
% %\SL{check?} 
% &  34.02 %35.02 
% %\decrease{(6.36)}
% & 865  & 1932\\
% {DAT-FGSM w/o LALR} & $6\times 6$  & $6 \times 512$ & 55.04 %\decrease{(7.66)}
% %57.04 \SL{check?} 
% &  35.03 %39.03 \SL{check?} 
% %\decrease{(5.35)}
% & 863  & 1080\\
% \rowcolor{Gray}
% {DAT-PGD} & $6\times 6$  & $6 \times 512$ & 63.75 %\increase{(1.05)}
% & 38.45 
% %\decrease{(1.93)}
% & 898  & 1960\\
% \rowcolor{Gray}
% {DAT-FGSM} & $6\times 6$  & $6 \times 512$ & 58.32  %\decrease{(4.38)}
% & 41.48
% %\increase{(1.1)}
% & 859  & 1109\\
\hline
& \multicolumn{6}{c}{\textbf{CIFAR-10, ResNet-50}} %\\ 
% \cline{2-7}  & \begin{tabular}[c]{@{}c@{}}$p \times q$\end{tabular}
% & Batch size
% & TA (\%) & RA (\%) & \begin{tabular}[c]{@{}c@{}}Comm.\\per epoch (s)\end{tabular} & \begin{tabular}[c]{@{}c@{}}Tr. Time\\per epoch (s)\end{tabular}
 \\ \hline
% AT & $1\times 1$ & $2048$ & 82.94 & 38.54 & NA & {218}  \\
% {Fast AT} & $1\times 1$  & $2048$ & 81.58 & 38.33 & NA & 52  \\
% \SL{DAT-PGD w/o LALR} & $6\times 1$  & \SL{$2048$} & 83.01 & 38.18 & 23.9  & 70\\
% DAT-PGD w/o LALR & $6\times 1$  & $6 \times 2048$ &  69.29 %% 74.29
% & 33.85 & 8.5  & 42\\
% DAT-FGSM w/o LALR & $6\times 1$  & $6 \times 2048$ & 64.45 & 34.95 & 8.5  & 14\\
% \rowcolor{Gray}
% {DAT-PGD} & $6\times 1$  & $6 \times 2048$ & 79.77 & 38.93 & 8.5  & 42\\
% \rowcolor{Gray}
% {DAT-PGD} & $18\times 1$  & $18 \times 2048$ & 80.27 & 38.43 & 3.4  & 22\\
% \rowcolor{Gray}
% {DAT-FGSM} & $6\times 1$  & $6 \times 2048$ & 75.58 & 40.91 & 8.5  & 14\\
% \hline
%   & \multicolumn{6}{c}{\textbf{CIFAR-10, ResNet-50} \SL{Appendix}} 
%   \\ 
% \cline{2-7}  &  \begin{tabular}[c]{@{}c@{}}GPUs per node\\ $\times$ nodes\end{tabular}
% & Batch size
% & TA (\%) & RA (\%) & \begin{tabular}[c]{@{}c@{}}Communication\\per epoch (s)\end{tabular} & \begin{tabular}[c]{@{}c@{}}Computation\\per epoch (s)\end{tabular}
 %\\ \hline
AT & $1\times 1$  & $256$ &85.94   & 43.06 & NA &  894  \\
{Fast AT} &$1\times 1$ & $256$ & 75.28 & 40.48 & NA & 288  \\
DAT-PGD w/o LALR & $6\times 1$ & $6 \times 256$ & 74.45 & 33.35 & 68  & 236\\
\rowcolor{Gray}
DAT-PGD & $6\times 1$  & $6 \times 256$ & 84.79 & 42.16 & 68  & 236\\
\rowcolor{Gray}
DAT-FGSM & $6\times 1$  & $6 \times 256$ & 75.72 & 40.09 & 68  & 116\\
% \hline
%   & \multicolumn{6}{c}{\textbf{ImageNet, ResNet-50}} \\
%   \hline
%   AT & $1\times 6$  & $512$ & 62.70 & 40.38 & NA & 6022  \\
% {Fast AT} & $1\times 6$  & $512$ & 58.99 & 41.78 & NA & xxx  \\
% \SL{DAT-PGD w/o LALR} & $6\times 6$  & $512$ &65.06 & 39.28 & xxx  & xxx\\
% DAT-PGD w/o LALR & $6\times 6$  & $6 \times 512$ & 60.09 & 35.02 & 42  & 2121\\
% \SL{DAT-FGSM w/o LALR} & $6\times 6$  & $6 \times 512$ &57.04 & 39.03 & xxx  & xxx\\
% \rowcolor{Gray}
% {DAT-PGD} & $6\times 6$  & $6 \times 512$ & 63.75 & 38.45 & xxx  & xxx\\
% \rowcolor{Gray}
% {DAT-FGSM} & $6\times 6$  & $6 \times 512$ & 58.32 & 41.48 & 42  & 1134\\
\hline
\hline
\end{tabular}}
\end{threeparttable}
\end{center}
%\vspace{-5mm}
%\end{wraptable}
\end{table}


% \begin{table}[ht]
% \begin{center}
% \caption{\small{Overall performance of DAT  (in gray color), compared with baselines, in TA (\%), RA (\%), communication time per epoch (seconds), and total training time (including communication time)  per epoch (in seconds). For brevity, `$p \times q$' represents `\# nodes $\times$  \# GPUs per node', 
% `Comm.' represents communication cost, and `Tr. Time' represents training time.
% %In the columns `TA' and `RA', we present the relative \increase{improvement} (\%)  or \decrease{degradation} (\%) upon   the performance of AT (first row). In the columns `Comm.' and `Tr. time', we {highlight} the \high{worst} communication and training time used in  \textit{distributed} setting.
% }
% %All the training methods are conducted in \SL{xxx [@Gaoyuan]} epochs .
% } 
% % \vspace{-2.5mm}
% \label{table: overall_supplement}
% \begin{threeparttable}
% \resizebox{1\textwidth}{!}{
% \begin{tabular}{c|c|c|c|c|c|c}
% \hline
% \hline
% \multirow{2}{*}{\begin{tabular}[c]{@{}c@{}}Method\end{tabular}} & \multicolumn{6}{c}{\textbf{CIFAR-10, ResNet-50}} \\ 
% \cline{2-7}  & \begin{tabular}[c]{@{}c@{}}$p \times q$\end{tabular}
% & Batch size
% & TA (\%) & RA (\%) & \begin{tabular}[c]{@{}c@{}}Comm.\\per epoch (s)\end{tabular} & \begin{tabular}[c]{@{}c@{}}Tr. Time\\per epoch (s)\end{tabular}
%  \\ \hline
% % AT & $1\times 1$ & $2048$ & 82.94 & 38.54 & NA & {218}  \\
% % {Fast AT} & $1\times 1$  & $2048$ & 81.58 & 38.33 & NA & 52  \\
% % \SL{DAT-PGD w/o LALR} & $6\times 1$  & \SL{$2048$} & 83.01 & 38.18 & 23.9  & 70\\
% % DAT-PGD w/o LALR & $6\times 1$  & $6 \times 2048$ &  69.29 %% 74.29
% % & 33.85 & 8.5  & 42\\
% % DAT-FGSM w/o LALR & $6\times 1$  & $6 \times 2048$ & 64.45 & 34.95 & 8.5  & 14\\
% % \rowcolor{Gray}
% % {DAT-PGD} & $6\times 1$  & $6 \times 2048$ & 79.77 & 38.93 & 8.5  & 42\\
% % \rowcolor{Gray}
% % {DAT-PGD} & $18\times 1$  & $18 \times 2048$ & 80.27 & 38.43 & 3.4  & 22\\
% % \rowcolor{Gray}
% % {DAT-FGSM} & $6\times 1$  & $6 \times 2048$ & 75.58 & 40.91 & 8.5  & 14\\
% % \hline
% %   & \multicolumn{6}{c}{\textbf{CIFAR-10, ResNet-50} \SL{Appendix}} 
% %   \\ 
% % \cline{2-7}  &  \begin{tabular}[c]{@{}c@{}}GPUs per node\\ $\times$ nodes\end{tabular}
% % & Batch size
% % & TA (\%) & RA (\%) & \begin{tabular}[c]{@{}c@{}}Communication\\per epoch (s)\end{tabular} & \begin{tabular}[c]{@{}c@{}}Computation\\per epoch (s)\end{tabular}
%  %\\ \hline
% AT & $1\times 1$  & $256$ &85.94   & 43.06 & NA &  894  \\
% {Fast AT} &$1\times 1$ & $256$ & 75.28 & 40.48 & NA & 288  \\
% DAT-PGD w/o LALR & $6\times 1$ & $6 \times 256$ & 74.45 & 33.35 & 68  & 236\\
% \rowcolor{Gray}
% DAT-PGD & $6\times 1$  & $6 \times 256$ & 84.79 & 42.16 & 68  & 236\\
% \rowcolor{Gray}
% DAT-FGSM & $6\times 1$  & $6 \times 256$ & 75.72 & 40.09 & 68  & 116\\
% % \hline
% %   & \multicolumn{6}{c}{\textbf{ImageNet, ResNet-50}} \\
% %   \hline
% %   AT & $1\times 6$  & $512$ & 62.70 & 40.38 & NA & 6022  \\
% % {Fast AT} & $1\times 6$  & $512$ & 58.99 & 41.78 & NA & xxx  \\
% % \SL{DAT-PGD w/o LALR} & $6\times 6$  & $512$ &65.06 & 39.28 & xxx  & xxx\\
% % DAT-PGD w/o LALR & $6\times 6$  & $6 \times 512$ & 60.09 & 35.02 & 42  & 2121\\
% % \SL{DAT-FGSM w/o LALR} & $6\times 6$  & $6 \times 512$ &57.04 & 39.03 & xxx  & xxx\\
% % \rowcolor{Gray}
% % {DAT-PGD} & $6\times 6$  & $6 \times 512$ & 63.75 & 38.45 & xxx  & xxx\\
% % \rowcolor{Gray}
% % {DAT-FGSM} & $6\times 6$  & $6 \times 512$ & 58.32 & 41.48 & 42  & 1134\\
% \hline
% \hline
% \end{tabular}}
% \end{threeparttable}
% \end{center}
% \vspace{-3mm}
% \end{table}




\subsubsection{Robustness against PGD and C\&W attacks}\label{app:RA_pgd}
In Figure\,\ref{fig: robust_PGD_supplement}, we evaluate the adversarial robustness of ResNet-18 at CIFAR-10 learned by DAT-PGD and DAT-FGSM against   PGD attacks of different steps
and perturbation sizes (namely, values of $\epsilon$).
We consistently observe that DAT matches  robust accuracies of standard AT even against PGD attacks at different values of $\epsilon$ and steps. Specifically,  DAT has slightly smaller   RA than AT when facing  weak  PGD attacks with $\epsilon$ less than $(5/255)$ and steps less than $5$. Moreover, although DAT-FGSM has the worst RA  against weak PGD attacks (which reduces to TA at $\epsilon = 0$), it outperforms other methods when the attacks become stronger in CIFAR-10 experiments. 
In Figure\,\ref{fig: acc_CW}, we present the additional robust accuracies against C\&W attacks \citep{carlini2017towards} of different perturbation sizes. As we can see, the results are consistent with the aforementioned ones against PGD attacks. 


   \begin{figure}[htb]
    \vspace*{-0.0in}
\centerline{
\begin{tabular}{cc}
\includegraphics[width=.45\textwidth,height=!]{Figures/pltfigure10.pdf}  &
\includegraphics[width=.45\textwidth,height=!]{Figures/pltfigure11.pdf}
%\\
%\footnotesize{(a)} &   \footnotesize{(b)}
\end{tabular}}
\caption{\footnotesize{RA against different
PGD attacks for   the model
 trained by DAT-PGD,     DAT-FGSM, and AT  under (CIFAR-10, ResNet-18).
  %using $6 \times 1$ computing resources and $2048 \times 6$ batch size.
(Left) RA against PGD attacks with different perturbation sizes (over the divisor $255$). (Right)  RA against PGD attacks with different steps.
%\SL{Update legends: AT, DAT-PGD, DAT-FGSM}
}}
  \label{fig: robust_PGD_supplement}
   \vspace*{-0.00in}
\end{figure}


\begin{figure}[htb]
    \vspace*{-0.0in}
\centerline{
\begin{tabular}{cc}
\includegraphics[width=.45\textwidth,height=!]{Figures/pltfigure10-cw.pdf}  &
\includegraphics[width=.45\textwidth,height=!]{Figures/pltfigure10imagenet-cw.pdf}
\\
\footnotesize{(a)  CIFAR-10, ResNet-18} &   \footnotesize{(b) ImageNet, ResNet-50}
\end{tabular}}
\caption{\footnotesize{RA against different
C\&W attacks for the model
 trained by DAT-PGD,     DAT-FGSM, and AT  
 under the setting (CIFAR-10, ResNet-18) and 
 (ImageNet, ResNet-50), respectively.
%  following the setting as 8th and 17th row in Table\,\ref{table: overall}.
%   %using $6 \times 1$ computing resources and $2048 \times 6$ batch size.
% (Left) RA against PGD attacks with different perturbation sizes (over the divisor $255$) for CIFAR-10, ResNet-18. (Right) RA against PGD attacks with different perturbation sizes (over the divisor $255$) for ImageNet, ResNet-50.
Here for ease of C\&W attack generation at ImageNet, we randomly select 1000 test  ImageNet images (1 image per class)
to generate C\&W attacks.
%for training CIFAR-10 on ResNet-18 with different numbers of computing nodes. The batch size of each node is 2048 so the total batch size will be $(\text{\# of nodes}) \times 2048$. 
%(a) Fine-tuning over CIFAR-10. (b):  Fine-tuning over CIFAR-100.
% pre-trained xxx \SL{[model name]} over dataset $\mathcal A$ using DAT. Left: RA against PGD attacks of different perturbation sizes during testing. Right:  RA against PGD attacks of different steps during testing.
}}
  \label{fig: acc_CW}
  \vspace*{-0.00in}
\end{figure}




\subsubsection{DAT from pre-training to fine-tuning}
\label{app:dat_pretrain}
In Figure\,\ref{fig: transfer_supplement}, 
we investigate if a DAT pre-trained  model (ResNet-50) over a source dataset (ImageNet)  can offer a fast fine-tuning to a down-stream target dataset (CIFAR-10).
Compared with the direct application of DAT to the target dataset (without pre-training), the pre-training enables a fast adaptation to the down-stream CIFAR-10 task in both TA and RA within just $5$ epochs.


  \begin{figure}[htb]
    \vspace*{-0.0in}
\centerline{
\begin{tabular}{c}
\includegraphics[width=.5\textwidth,height=!]{Figures/pltfigure2.pdf} 
%&
%\includegraphics[width=.5\textwidth,height=!]{Figures/pltfigure21.pdf}
%\\
%\footnotesize{(a)} &   \footnotesize{(b)}
\end{tabular}}
\caption{\footnotesize{Fine-tuning  ResNet-50 (pre-trained on ImageNet)  under CIFAR-10.  Here DAT-PGD is   used for both  pre-training and fine-tuning at $6$ nodes with   batch size $6 \times 128$. 
% Here the fine-tuning is conducted by AT at a single computing node, and the target datasets are CIFAR-10 and CIFAR-100.
%   Accuracy (TA or RA)  of the fine-tuned model is compared to that of    the end-to-end  training at CIFAR by DAT-PGD from scratch using  the same number of epochs   as the  fine-tuning. \SL{Update legends: pre-training on ImageNet, no pre-training}
%(a) Fine-tuning over CIFAR-10. (b):  Fine-tuning over CIFAR-100.
% pre-trained xxx \SL{[model name]} over dataset $\mathcal A$ using DAT. Left: RA against PGD attacks of different perturbation sizes during testing. Right:  RA against PGD attacks of different steps during testing.
}}
  \label{fig: transfer_supplement}
  \vspace*{-0.00in}
\end{figure}

% 
% 

\subsubsection{Quantization effect}
\label{app:quantization}

In Table\,\ref{table: quadtization}, we present the performance of DAT  by making use of gradient quantization.  Two quantization scenarios are covered: 1) quantization is conducted  at each  worker (Step\,7 of Algorithm\,\ref{alg: DAT}), and 2) quantization is conducted at  both worker and server sides (Step\,7 and 10 of Algorithm\,\ref{alg: DAT}).
As we can see, when the number of bits is reduced from $32$ to $8$, the communication cost and the amount of transmitted data are reduced by {2} and $4$ times, respectively. Although the use of gradient quantization introduces a performance loss to some extent, the resulting TA and RA are still comparable to the best $32$-bit case. In the worst case of CIFAR-10 ($8$-bit $2$-sided quantization), TA drops by $0.91\%$ and $6.33\%$ for DAT-PGD and DAT-FGSM, respectively, and RA drops by $4.73\%$ and $5.22\%$, respectively.  {However, 8-bit 2-sided quantization transmits the least amount of data per iteration.}

\begin{table}[ht]
\begin{center}
\caption{\footnotesize{Effect of gradient quantization on the performance of DAT  for various numbers of bits. 
The training settings of (CIFAR-10, ResNet-18) are consistent with those in Table\,\ref{table: overall}. 
%Here the quantization operation is conducted  at each  worker except the scenario ($8$ bits, $2$ sided), which implies  Algorithm\,\ref{alg: DAT} with use of   quantization operations at both worker and server sides.
% In     (CIFAR-10, ResNet-18), $6 \times 1$ computing resources and $6 \times 2048$ batch size are used. In  (ImageNet, ResNet-50), $6 \times 6$ computing resources and $6 \times 512$ batch size are used.
}
} 
% \vspace{-2.5mm}
\label{table: quadtization}
\begin{threeparttable}
\resizebox{0.8\textwidth}{!}{
\begin{tabular}{c|c|c|c|c|c}
\hline
\hline
% \multirow{2}{*}{\begin{tabular}[c]{@{}c@{}}Method\end{tabular}} & \multicolumn{6}{c}{CIFAR-10, ResNet-18} \\ 
% \cline{2-7}  & \begin{tabular}[c]{@{}c@{}}GPUs per node\\ $\times$ nodes\end{tabular}
% & Batch size
% & TA (\%) & RA (\%) & \begin{tabular}[c]{@{}c@{}}Communication\\per epoch (s)\end{tabular} & \begin{tabular}[c]{@{}c@{}}Computation\\per epoch (s)\end{tabular}
%  \\ \hline
% AT & $1\times 1$  & 2048 & 82.94 & 38.54 & NA & 218  \\
% DAT w/o LALR & $1\times 6$  & 12288 & 76.29 & 35.55 & 8.5  & 42\\
% DAT & $1\times 6$  & 12288 & 79.77 & 38.93 & 8.5  & 42\\
% Fast DAT & $1\times 6$  & 12288 & 75.58 & 40.91 & 8.5  & 14\\
% \hline
% \hline
\multirow{2}{*}{\begin{tabular}[c]{@{}c@{}}Method\end{tabular}} & \multicolumn{5}{c}{\textbf{CIFAR-10, ResNet-18}
%\SL{Appendix}
%under $6 \times 1$ computing resources and $6 \times 2048$ batch size
}
\\ 
\cline{2-6}  
%& \begin{tabular}[c]{@{}c@{}}GPUs per node\\ $\times$ nodes\end{tabular}
%& Batch size
& \# bits
& TA (\%) & RA (\%) & \begin{tabular}[c]{@{}c@{}}Comm.\\per epoch (s)\end{tabular} & \begin{tabular}[c]{@{}c@{}}Data transmitted\\ per iteration (MB)\end{tabular}
 \\ \hline
DAT-PGD & $32$  & 
%$1\times 6$   & $2048\times 6$ & 
80.38 & 38.94 & 8.5  & 1278\\
DAT-PGD & $16$ 
& 
%$1\times 6$  & $2048\times 6$ & 
79.38 & 38.32 & 8.3  & 639\\
DAT-PGD & $8$  & 
%$1\times 6$  & $2048\times 6$ & 
78.18 & 37.34 & 4.3  & 320\\
DAT-PGD & $8$  ($2$-sided) & 
%$1\times 6$  & $2048\times 6$ & 
78.86 & 34.2 & 5.0  & 107\\
DAT-FGSM &  $32$  & 
%$1\times 6$  & $2048\times 6$  &
75.58 & 40.92 & 8.5  & 1278\\
DAT-FGSM & $16$  & 
%$1\times 6$  & $2048\times 6$ & 
75.74 & 40.86 & 8.3 & 639  \\
DAT-FGSM &  $8$  & 
%$1\times 6$  & $2048\times 6$ & 
72.48 & 38.98 & 4.3 & 320  \\
DAT-FGSM & $8$ ($2$-sided) &
%$1\times 6$  & $2048\times 6$ &
69.26 & 35.34 & 5.0 & 107  \\
% \hline 
% \multirow{2}{*}{\begin{tabular}[c]{@{}c@{}}Method\end{tabular}} & \multicolumn{5}{c}{\textbf{ImageNet, ResNet-50} 
% %under $6 \times 6$ computing resources and $6 \times 512$ batch size
% } \\ 
% \cline{2-6}  
% %&  \begin{tabular}[c]{@{}c@{}}GPUs per node\\ $\times$ nodes\end{tabular}
% %& Batch size
% & \# bits
% & TA (\%) & RA (\%) & \begin{tabular}[c]{@{}c@{}}Comm.\\per epoch (s)\end{tabular} & \begin{tabular}[c]{@{}c@{}}Data transmitted\\ per iteration (MB)\end{tabular}
%  \\ \hline
% DAT-PGD & $32$  % & $6\times 6$  & $512\times 6$ 
% & 63.75 & 38.45 & 898  & 2924\\
% DAT-PGD & $16$  % & $6\times 6$  & $512\times 6$
% & 61.77 & 38.40 & 850  & 1462\\
% DAT-PGD & $8$ % & $6\times 6$  & $512\times 6$ 
% & 56.53 & 37.90 & 592  & 731\\
% DAT-PGD & $8$ ($2$-sided)  % & $6\times 6$  & $512\times 6$
% & 53.09 & 34.59 & 1091  & 244\\
% DAT-FGSM & 32 %& $6\times 6$  & $512\times 6$ 
% & 58.32 & 41.48 & 859  & 2924\\
% DAT-FGSM & 16 % & $6\times 6$  & $512\times 6$
% & 54.71 & 39.29 & 849 & 1462  \\
% DAT-FGSM & 8 % & $6\times 6$ & $512\times 6$ 
% & 50.11 & 36.38 & 594 & 731  \\
% DAT-FGSM & $8$ ($2$-sided) % & $6\times 6$ & $512\times 6$ 
% & 48.27 & 33.20 & 1013 & 244  \\
\hline
\hline 
\end{tabular}}
\end{threeparttable}
\end{center}
\vspace{-3mm}
\end{table}





%: with 6 times computing resources we can get almost 6 times faster than AT in table \ref{table: overall}.

In Table\,\ref{table: centralized_8bit},
we conduct an additional experiment by integrating a centralized method with gradient quantization operation on CIFAR-10 under the batch size $2048$ and $6 \times 2048$, respectively. We specify the centralized method as Fast AT with LALR, where LALR is introduced to improve the scalability of Fast AT under the larger batch size $6 \times 2048$. Due to the centralized implementation, we only need 1-sided gradient quantization (namely, no server-worker communication is involved). 
As we can see, when the batch size $2048$ is used, Fast AT w/ LALR performs as well as Fast AT even in the presence of 8-bit gradient quantization. On the other hand, when the larger batch size $6 \times 2048$ is used, Fast AT w/ LALR can still preserve the performance in the absence of gradient quantization. By contrast, Fast AT w/ LALR in the presence of quantization encounters a $6.05\%$ TA drop. This suggests that even in the non-DAT setting, 8-bit gradient quantization hurts the performance as the batch size becomes large. Thus, in DAT it is not surprising that 8-bit quantized gradients could cause a non-trivial accuracy drop, particularly when using 2-sided gradient quantization and a much larger data batch size ($\geq 18 \times 2048$ on CIFAR-10). One possible reason is that the quantization error cannot easily be mitigated as the number of iterations decreases (due to the increased batch size under a fixed number of epochs).

\mycomment{
 \begin{figure}[htb]
    \vspace*{-0.0in}
\centerline{
\begin{tabular}{c}
\includegraphics[width=.5\textwidth,height=!]{Figures/app.pdf} 

\end{tabular}}
\caption{\footnotesize{
Robust accuracy for fast adversarial training 15 epochs using different batch size with cyclic learning rate.
}}
  \label{fig: cyclic_lr_batch_size}
  \vspace*{-0.00in}
\end{figure}
}

\begin{table}[ht]
\begin{center}
\caption{\footnotesize{{Effect of 8-bit  quantization on centralized robust training Fast AT w/ LALR.}
}
} 
\label{table: centralized_8bit}
\begin{threeparttable}
\resizebox{0.7\textwidth}{!}{
\begin{tabular}{c|c|c|c|c}
\hline
\hline
\multirow{2}{*}{\begin{tabular}[c]{@{}c@{}}
Method\end{tabular}} & \multicolumn{4}{c}{\textbf{CIFAR-10, ResNet-18} 
%under $6 \times 6$ computing resources and $6 \times 512$ batch size
} \\ 
\cline{2-5}  
%&  \begin{tabular}[c]{@{}c@{}}GPUs per node\\ $\times$ nodes\end{tabular}
%& Batch size
& \begin{tabular}[c]{@{}c@{}}8-bit \\ quantization (s)\end{tabular}& 
\begin{tabular}[c]{@{}c@{}}Batch size \end{tabular}
& TA (\%) & RA (\%) 
 \\ \hline
Fast AT & No % & $6\times 6$  & $512\times 6$ 
& 2048  & 81.58    & 38.34
\\
Fast AT w/ LALR  & Yes & 2048 & 80.66 &  38.60 \\
Fast AT w/ LALR & No & $6 \times 2048$ &  80.08 &  38.51 \\
Fast AT w/ LALR& Yes& $6 \times 2048$ &  75.53 & 38.45 \\
\hline
\hline 
\end{tabular}}
\end{threeparttable}
\end{center}
\vspace{-3mm}
\end{table}
\fi 




{
{{
\bibliography{zhang_207}
}}
}

\end{document}