% \documentclass{uai2023} % for initial submission
\documentclass[accepted]{uai2023}
% after acceptance, for a revised
                                    % version; also before submission to
                                    % see how the non-anonymous paper
                                    % would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2023} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2023} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

\usepackage[usenames,dvipsnames]{xcolor}

%% Choose your variant of English; be consistent
\usepackage[american]{babel}
% \usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{tikz} % nice language for creating drawings and diagrams

%%%%% PREAMBLE %%%%%

% --- Calligraphic capitals (\mathcal) ---
\newcommand*{\X}{\mathcal{X}}
\newcommand*{\Y}{\mathcal{Y}}
\newcommand*{\Z}{\mathcal{Z}}
\newcommand*{\cO}{\mathcal{O}}
\newcommand*{\cL}{\mathcal{L}}
\newcommand*{\cW}{\mathcal{W}}
\newcommand*{\cT}{\mathcal{T}}
\newcommand*{\cF}{\mathcal{F}}
\newcommand*{\cM}{\mathcal{M}}

% --- Blackboard-bold capitals (\mathbb) ---
\newcommand*{\E}{\mathbb{E}}
\newcommand*{\R}{\mathbb{R}}
\newcommand*{\N}{\mathbb{N}}
\newcommand*{\bV}{\mathbb{V}}
% \P already has a definition, so it is overridden rather than created.
\renewcommand*{\P}{\mathbb{P}}

% --- Bold symbols (\mathbf) ---
\newcommand*{\1}{\mathbf{1}}
\newcommand*{\bc}{\mathbf{c}}
\newcommand*{\boldf}{\mathbf{f}}
\newcommand*{\bp}{\mathbf{p}}
% Despite the name, \blambda expands to a plain \lambda in a brace group.
\newcommand*{\blambda}{{\lambda}}

% --- Arrow shorthand ---
% \> already has a definition; it is overridden to mean \rightarrow.
\renewcommand*{\>}{\rightarrow}

% Capitalised variants: the argument is set underneath the operator name
% via \underset.
\newcommand{\Argmax}[1]{%
  \underset{#1}{\operatorname{argmax}}%
}
\newcommand{\Argmin}[1]{%
  \underset{#1}{\operatorname{argmin}}%
}
% Lower-case variants: the argument is attached as an ordinary subscript
% to the braced operator name.
\newcommand{\argmax}[1]{%
  {\operatorname{argmax}}_{#1}%
}
\newcommand{\argmin}[1]{%
  {\operatorname{argmin}}_{#1}%
}


% Upright roman text labels (\textrm + \textup), one macro per label.

% Upper-case labels and `val'.
\newcommand*{\err}{\textrm{\textup{ERR}}}
\newcommand*{\wer}{\textrm{\textup{WER}}}
\newcommand*{\ber}{\textrm{\textup{BER}}}
\newcommand*{\val}{\textrm{\textup{val}}}

% Lower-case labels; a trailing `d' in the macro name maps to a `-d' suffix
% in the printed label.
\newcommand*{\std}{\textrm{\textup{std}}}
\newcommand*{\stdd}{\textrm{\textup{std-d}}}
\newcommand*{\bal}{\textrm{\textup{bal}}}
\newcommand*{\bald}{\textrm{\textup{bal-d}}}
\newcommand*{\rob}{\textrm{\textup{rob}}}
\newcommand*{\robd}{\textrm{\textup{rob-d}}}
\newcommand*{\tdf}{\textrm{\textup{tdf}}}
\newcommand*{\tdfd}{\textrm{\textup{tdf-d}}}

% Remaining labels.
\newcommand*{\xent}{\textrm{\textup{xent}}}
\newcommand*{\zo}{\textrm{\textup{0-1}}}
\newcommand*{\dis}{\textrm{\textup{dis}}}
\newcommand*{\wc}{\textrm{\textup{wor}}} % note: \wc prints `wor'
\newcommand*{\mar}{\textrm{\textup{mar}}}
\newcommand*{\oh}{\textrm{\textup{oh}}}
\newcommand*{\softmax}{\textrm{\textup{softmax}}}

% \AddLabel{<key>}: advance the equation counter (making it referable),
% print its value in parentheses, and attach \label{<key>} to it.
\newcommand{\AddLabel}[1]{%
  \refstepcounter{equation}%
  (\theequation)%
  \label{#1}%
}

% To control spacing in itemized lists
\usepackage{enumitem}

% Algorithm command
\usepackage{algorithm}
\usepackage{algorithmic} 

% Theorems
\usepackage{amsmath}
\usepackage{amssymb}
\usepackage{mathtools}
\usepackage{amsthm}

% Theorem-like environments (amsthm). Each \theoremstyle applies to the
% \newtheorem declarations that follow it, so the order below matters.
% All numbered environments share the `theorem' counter; the starred
% `theorem*' and `proposition*' environments are unnumbered.
\theoremstyle{plain}
\newtheorem{theorem}{Theorem}
\newtheorem*{theorem*}{Theorem}
\newtheorem{proposition}[theorem]{Proposition}
\newtheorem*{proposition*}{Proposition}
\newtheorem{lemma}[theorem]{Lemma}
\newtheorem{corollary}[theorem]{Corollary}
% Definition-style environments, numbered with the `theorem' counter.
\theoremstyle{definition}
\newtheorem{definition}[theorem]{Definition}
\newtheorem{assumption}[theorem]{Assumption}
% Remark-style environment, numbered with the `theorem' counter.
\theoremstyle{remark}
\newtheorem{remark}[theorem]{Remark}

% Todonotes is useful during development; simply uncomment the next line
%    and comment out the line below the next line to turn off comments
% \usepackage[disable,textsize=tiny]{todonotes}
% \usepackage[textsize=tiny]{todonotes}
 
% Added by Serena: for tables
\usepackage{multirow}
\setlength{\tabcolsep}{3pt}
\usepackage{colortbl}

\usepackage[utf8]{inputenc} % allow utf-8 input
\usepackage[T1]{fontenc}    % use 8-bit T1 fonts
\usepackage{hyperref}       % hyperlinks
\usepackage{url}            % simple URL typesetting
\usepackage{booktabs}       % professional-quality tables
\usepackage{amsfonts}       % blackboard math symbols
\usepackage{nicefrac}       % compact symbols for 1/2, etc.
\usepackage{microtype}      % microtypography
\usepackage{xcolor}         % colors

\usepackage{capt-of}
% \usepackage{floatrow}
% % Table float box with bottom caption, box width adjusted to content
% \newfloatcommand{capbtabbox}{table}[][\FBwidth]

\usepackage{boldline}

%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

%% Self-defined macros
% \swap[<sep>]{<a>}{<b>} expands to <b><sep><a>; <sep> defaults to `-'.
\newcommand{\swap}[3][-]{#3#1#2} % just an example

% xr package for referring to appendix externally
\usepackage{xr}

\makeatletter
% \addFileDependency{<file>}: announce <file> as a dependency of this document.
%   - writes "(<file>)" to the terminal/log via \typeout,
%   - hands <file> to the kernel's \@addtofilelist,
%   - writes "No file <file>." if the file cannot be found.
% (Presumably the "(<file>)" line lets build tools that scan the log pick up
% the dependency -- confirm against the build setup.)
% Every line inside the body ends with `%' so that the end-of-line characters
% do not become space tokens in the expansion; without them the macro would
% typeset stray spaces if it were ever called outside vertical mode.
\newcommand*{\addFileDependency}[1]{% argument=file name and extension
  \typeout{(#1)}%
  \@addtofilelist{#1}%
  \IfFileExists{#1}{}{\typeout{No file #1.}}%
}
\makeatother

% \myexternaldocument{<basename>}: call \externaldocument on <basename>
% (no extension) and register both <basename>.tex and <basename>.aux
% through \addFileDependency.
\newcommand*{\myexternaldocument}[1]{%
    \externaldocument{#1}%
    \addFileDependency{#1.tex}%
    \addFileDependency{#1.aux}%
}

\myexternaldocument{wang_601-supp}

\title{Robust Distillation for Worst-class Performance: \\ On the Interplay Between Teacher and Student Objectives}

% The standard author block has changed for UAI 2023 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors

\author[1,2]{\href{mailto:<serenawang@google.com>?Subject=Your UAI 2023 paper}{Serena~Wang}{}}
\author[1]{\href{mailto:<hnarasimhan@google.com>?Subject=Your UAI 2023 paper}{Harikrishna~Narasimhan}{}}
\author[1]{\href{mailto:<yichenzhou@google.com>?Subject=Your UAI 2023 paper}{Yichen~Zhou}{}}
\author[3]{\href{mailto:<sarahooker@cohere.com>?Subject=Your UAI 2023 paper}{Sara~Hooker}{}}
\author[1]{\href{mailto:<mlukasik@google.com>?Subject=Your UAI 2023 paper}{Michal~Lukasik}{}}
\author[1]{\href{mailto:<adityakmenon@google.com>?Subject=Your UAI 2023 paper}{Aditya~Krishna~Menon}{}}
% Add affiliations after the authors
\affil[1]{%
    Google Research\\
    Mountain View, California and New York, New York, USA\\
}
\affil[2]{%
    University of California, Berkeley\\
    Berkeley, California, USA\\
}
\affil[3]{%
    Cohere For AI\\
    Palo Alto, California, USA
}
  
  \begin{document}
\maketitle

\begin{abstract}
Knowledge distillation is a popular technique that has been shown to produce remarkable gains in average accuracy. However,
recent work has shown that these gains 
are not uniform across subgroups in the data, and can often come at the cost of accuracy on rare subgroups and classes. 
Robust optimization is a common remedy to improve worst-class accuracy in standard learning settings, but in distillation
it is unknown whether it is best to apply
robust objectives when training the teacher, the student, or both. 
This work studies the interplay between robust objectives for the teacher and student.
Empirically, we show that
jointly modifying the teacher and student objectives can lead to better worst-class student performance and even Pareto improvement in the trade-off between worst-class and overall performance.
Theoretically, we show that the \emph{per-class calibration} of teacher scores is key when training a robust student.
Both the theory and experiments support the surprising finding that applying a robust teacher training objective does not always yield a more robust student.
\end{abstract}

%%%%% INTRO %%%%%
\section{Introduction}
\label{sec:intro}

% Recent progress in computer vision and natural language processing has been characterized by ever larger models \citep{rae2021scaling,raffel2020exploring,bommasani2021opportunities}. While larger models have been shown to improve performance on a variety of tasks \citep{roberts2020}, the sheer size often makes deployment infeasible in resource constrained environments \citep{menghani2021efficient}. 

Knowledge distillation, wherein one trains a \emph{teacher} model and uses its predictions to train a \emph{student} model
of similar or smaller capacity, has proven to be a powerful tool that improves efficiency while achieving state-of-the-art classification accuracies %on a variety of problems 
\citep{Hinton:2015, Radosavovic:2018,46642,pham2021meta}.
% \todoAKM{(Can ignore, nit) The previous sentence has already mentioned KD improves SOTA accuracy, so perhaps here the thing to say is just ``Remarkably, the student accuracy under KD is capable of even surpassing that of the teacher.''}
Remarkably, the student accuracy under distillation is capable of even surpassing that of the teacher (e.g.\ \citet{xie2020self}). 

However, recent work has shown that the gains in average accuracy may not be uniform across subgroups, and can hurt performance on subgroups that are rarer or more difficult to classify. This is particularly true of long-tailed classification settings, where the improved average accuracy often comes at the cost of poorer accuracies on the tail classes
%\todoAKM{Do we mean to cite the Teacher's pet paper, not the Label smoothing paper?}
\citep{lukasik2021teachers, du2021compressed},
%or under-represented groups \citep{sagawa2020investigation, menon2020overparameterisation}, 
and model compression can further amplify these performance disparities
% %across groups 
\citep{hooker2020characterising, xu-etal-2021-beyond}.

%This is further exacerbated by the fact that many real-world datasets exhibit imbalance in the number of training examples belonging to different subgroups, such as \textit{long-tailed} classification settings where each target label delineates a subgroup \citep{menon2020long,feldman2021does}.%,d2021tale}.
% or under-represented groups \citep{sagawa2020investigation, menon2020overparameterisation}, and model compression can amplify performance disparities 
% %across groups 
% \citep{hooker2020characterising, xu-etal-2021-beyond}


% While originally devised for model compression, distillation has also been widely applied to improve the performance of 
% fixed capacity models and in semi-supervised learning settings \citep{rusu2016policy, Furlanello2018BornAN, yang2019training, xie2020self}. 

% Evaluating the trade-offs incurred by distillation techniques has largely focused on canonical measures such as average error over a given data distribution. Along this dimension, distillation has proven to be a remarkably effective %compression 
% technique, with the student achieving even better test performance than the teacher (e.g.\ \citet{xie2020self}). 
% % However, real-world deployments often involve multiple desiderata and constraints. These can be varied and involve minimizing worst-case error \citep{}, imposing a tailored group fairness constraint \citep{} or a floor on the desired recall and precision \citep{}.
% %
% % Recent work has shown the design objectives like compression can be at odds with minimizing sub-group objectives \citep{}. 
% % Motivate long tail learning
% %
% However, average error may differ significantly from the performance on individual subgroups, 
% %in the data,
% and thus may be an inadequate metric to optimize in real-world settings that also require good performance on different subgroups. For instance, subgroups may be defined by attributes such as country, language, or race \citep{Hardt:2016, zafar2017fairness, agarwal2018reductions}, in which case performance on the individual subgroups becomes a policy and fairness concern. % In a multi-class classification setting, each target class may be regarded a separate subgroup.

% The potential for mismatch between average and worst-group error is further exacerbated by the fact that many real-world datasets exhibit imbalance in the number of training examples belonging to different subgroups. This is particularly true with \textit{long-tailed} multi-class classification settings where each target label delineates a subgroup \citep{menon2020long,feldman2021does,d2021tale}. %VanHorn:2017,
% In such cases, the improved average error of the student often comes at the cost of poorer performance on the tail classes \citep{lukasik2021teachers,du2021compressed}
% or under-represented groups \citep{sagawa2020investigation, menon2020overparameterisation}, and model compression can amplify performance disparities 
% %across groups 
% \citep{hooker2020characterising, xu-etal-2021-beyond}. Further hurting %test
% performance is the possibility of the training data having poor representation of subgroups that occur more frequently at test time.
% %This distribution shift may be due to bias in training data collection and labeling.robust optimization

% To address this goal of achieving good worst-case performance across subgroups in a standard single model training setting, 
To mitigate the disparity between average and subgroup accuracy, 
a common remedy is to train a model to achieve low \emph{worst-group} test error. 
% Modified objectives for 
Suitably modified
robust optimization techniques have successfully achieved state-of-the-art worst-class performance with manageable computational overhead \citep{Sagawa2020Distributionally, sohoni2020no}. 
However, the evaluation of these techniques has thus far primarily focused on the standard training setting involving a single model.
In the increasingly popular distillation setting,
which involves both a teacher and student model,
% Indeed, 
there is limited understanding of how these approaches can be applied 
% in distillation 
to achieve the best trade-offs between average and worst-class performance.
In particular, it is unknown if the best results come from using a robust objective for the teacher, the student or {\emph{both}}. 

This work studies the interplay between robust training objectives for the teacher and  student. We focus on a multi-class classification setting where we define worst-class accuracy as the lowest per-class recall. Empirically, we show that jointly modifying \emph{both} the teacher and student objectives with robust objectives not only improves the worst-class accuracy of the student, but can provide Pareto improvements in the trade-off between average and worst-class performance. Theoretically, we analyze what makes a good teacher when training a robust student, and give to our knowledge the first concrete characterization of this by showing that the student's robustness depends on how \textit{well-calibrated} the teacher's scores are for the individual classes. %Both these results shed light on a surprising finding: using a robust training objective for the teacher does not necessarily yield a more robust student, and \proposalHari{indeed combining a standard teacher with a robust objective for the student often produces competitive results.}

Our contributions proceed as follows:
\begin{enumerate}[label=(\roman*),itemsep=0pt,topsep=0pt,leftmargin=16pt,nolistsep]
\item We begin with the problem setup (\S\ref{sec:prelims}), and adapt existing robust optimization objectives to a  distillation setting, allowing for different combinations of modifications to \emph{both} the teacher and student objectives (\S\ref{sec:distillation}). We provide adapted algorithms to address practical training issues that arise when applying robust objectives to both the teacher and student (such as margin-based surrogate losses and shared validation set usage).
%\item We empirically compare these robust distillation algorithms for introducing robust optimization into distillation, exploring the trade-off between overall performance and worst-class performance. We compare these to simpler baselines such as post-shifting and balancing the objective function.
\item We demonstrate empirically on benchmark image datasets that the different combinations of student and teacher objectives not only improve the student's worst-class accuracy, but yield better trade-offs between average and worst-class performance than baselines (\S\ref{sec:expts}). Perhaps surprisingly, we find that 
% using a robust objective to train the teacher does not necessarily result in a robust student, and
the teacher's worst-class accuracy is not always predictive of the teacher's ability to yield robust students.
%In both self-distillation and compression settings, we empirically demonstrate that the proposed modifications can train student models that yield better worst-class accuracies than the teacher and other recent baselines. We also demonstrate gains in the trade-off between overall and worst-class performance.
% (when training students of similar capacity).
% We show that our robust objectives 
% also achieve better trade-offs between overall and worst-class performance compared to baselines such as post-shifting and simpler loss modifications.
\item We show theoretically that the worst-class robustness of the student depends on the \emph{per-class calibration} of the teacher, and additionally  derive robustness guarantees for the student in terms of the teacher's errors % under  different algorithmic design choices 
(\S\ref{sec:theory}).
% , and % provide theoretical insights into what makes a ``good'' teacher when the goal is to train a robust student,
% and %derive robustness guarantees for the student in terms of the teacher's approximation errors.
% provide  insights into when the student yields better worst-class robustness than the teacher.
% .  These serve as a framework for understanding the reasons behind the comparative performance of different algorithms, and can guide the development of future improvements on robust distillation algorithms.
\end{enumerate}

\subsection{Related work}

\textbf{Worst-group robustness:}\ The goal of achieving good worst-case performance across subgroups can be framed as a 
(group) distributionally robust optimization (DRO) problem, 
and can be solved by iteratively updating costs on the individual groups and minimizing the resulting cost-weighted loss \citep{chen2017robust}. Recent variants of this approach have sought to avoid over-fitting through
group-specific regularization \citep{Sagawa2020Distributionally,sagawa2020investigation} or
margin-based losses \citep{narasimhan2021training, kini2021labelimbalanced}, and
to handle unknown subgroups \citep{sohoni2020no}.
% , and to balance between average and worst-case performance \cite{piratla2021focus}. Other approaches include the use of Conditional Value at Risk for worst-class robustness \citep{xu2020class}.
In the context of distillation, \citet{lukasik2021teachers} propose simple modifications to robustify the student's objective by controlling the strength of the teacher's labels for different groups. In contrast, we propose a more direct and theoretically-grounded procedure that
seeks to explicitly optimize for the student's worst-case error. %, and explore modifications to both the teacher and student objectives.


\textbf{Relationship to \citet{narasimhan2021training}:}\ 
%The work that most closely relates to our paper
This paper builds on the margin-based DRO framework of \citet{narasimhan2021training},
who also include preliminary distillation experiments on training the teacher with
standard ERM and the student with a robust objective. 
However, this and other prior work \citep{lukasik2021teachers} have only explored modifications to the student loss, while training the teacher using a standard procedure.
{Our robust distillation proposals build on this method,
but carry out a more extensive analysis, exploring different combinations of teacher-student objectives 
and different trade-offs between average and worst-class performance}. Additionally,
we provide robustness guarantees for the student, equip the DRO algorithms to achieve different trade-offs between overall and worst-case error, and provide a rigorous analysis of different design choices, such as the use of teacher labels for the multiplier updates.
% {we explore \emph{two different variants} of their algorithm and
% derive robustness guarantees for the student in each case.}

\textbf{Long-tail learning.} There has been much work on training classifiers from long-tail data, ranging from loss modifications \citep{Cao:2019,menon2020long,cui2021parametric} %to data modifications \citep{Yin:2018,Zhang:2019} 
to architectural changes \citep{wang2020long, cui2022reslt}. All these methods focus on the standard single model training setup, and seek to maximize the balanced (and not the worst-class) accuracy. Recent attempts have sought to modify standard distillation for long-tail learning, by re-balancing the student loss \citep{zhang2021balanced}, temperature-scaling the teacher predictions \citep{he2021distilling}, employing multiple teachers \citep{xiang2020learning}, or leveraging the teacher's intermediate embeddings \citep{iscen2021class}. The common goal in most of these papers is to modify the student's objective to incorporate different forms of supervision from the teacher. In contrast, we seek to explore modifications to the teacher's training objective to improve the student's robustness.

\textbf{Role of the teacher's objective.} Few previous works have studied how the objective of the teacher affects the student performance.
For example, multiple works have studied the effect of label smoothing objectives for the teacher model, some finding it to harm the student performance \citep{muller2019labelsmoothing}, others finding it to improve the student \citep{shen2021is} or to have varying impact depending on the temperature value \citep{RevisitingLabelSmoothing22}.
In another work, \citet{Lukasik2020} showed that applying noise correction objectives to the teacher often yields better results than only applying noise correction objectives in the student.
We are not aware of a previous work studying the \emph{interplay} between the student and the teacher objectives on the robustness of the student.

%%%%% PRELIMS %%%%%
\section{Problem Setup}
\label{sec:prelims}

We consider a multi-class classification problem with instance space $\X$
and output space $[m] = \{1, \ldots, m\}$. Let $D$ denote the underlying data distribution over $\X \times [m]$, and 
$D_{\X}$ denote the marginal distribution over $\X$.
Let $\Delta_m$ denote the $(m-1)$-dimensional probability simplex over $m$ classes. 
We define the
conditional-class probability as $\eta_y(x) = \P(Y=y|X=x)$ and the class priors $\pi_y = \P(Y=y)$.
Note that $\pi_y = \E_{X \sim D_\X}\left[ \eta_y(X) \right]$. 

\textbf{Learning objectives.}
Our goal is to learn a multiclass classifier $h: \X \> [m]$ that maps an instance $x \in \X$ to one of $m$ classes.
We will do so by first learning a scoring function $f: \X \> \R^m$ that assigns scores $[f_1(x), \ldots, f_m(x)] \in \R^m$  to a given instance $x$,
and construct the classifier by predicting the class with the highest score: $h(x) = \argmax{j \in [m]}\,f_j(x)$. We will 
%compute confidence scores or probabilities by applying a
denote a softmax transformation of $f$ by $\softmax_y(f(x)) = \frac{\exp( f_y(x) ) }{\sum_j \exp( f_j(x) )}$,
and use the notation  $\softmax_y(f(x)) \propto z_y$ to indicate that $\softmax_y(f(x)) = \frac{z_y}{\sum_{j=1}^m z_j}$.

We measure the efficacy of the scoring function $f$ using a loss function $\ell: [m] \times \R^m \> \R_+$
that assigns a penalty $\ell(y, z)$ for predicting score vector $z \in \R^m$ for true label $y$. Examples of loss functions include the 
0-1 loss: $\ell^\zo(y, z) = \1\left(y \ne \argmax{j}\,z_j \right)$,
% \begin{align}
% \ell^\zo(y, z) = \1\left(z \ne \argmax{j}\,f_j(x) \right),
% \label{eq:0-1}
% \end{align}
and the softmax cross-entropy loss: $\ell^\xent(y, z) = \textstyle -z_y + \log\big( \sum_{j \in [m]} \exp\left( z_j \right) \big)$.
% \begin{align}
% \ell^\xent(y, z) = \textstyle -f_y(x) + \log\left( \sum_{j \in [m]} \exp\left( f_j(x) \right) \right).
% \label{eq:xent}
% \end{align}

\textit{Standard objective:} A standard machine learning goal entails minimizing the overall expected risk: 
%This is reflected in the objective:
\begin{equation}
L^\std(f) = \E\left[\ell(Y, f(X))\right].
\label{eq:std}
\end{equation}
% In a standard setup, one often picks the classifier $h$ to minimize the expected 0-1 error:
% \[
% \err(h) = \E\left[\1(h(X) \ne Y)\right].
% \]
% Because this objective is non-continuous in $h(X)$, a common proxy is to
% pick a scoring function $f$ that minimizes the expected cross-entropy loss:
% One often chooses a scoring function $f$ that minimizes a 
% \begin{align}
% L^\std(f) = \E\left[\ell(Y, f(X))\right],
% \label{eq:std}
% \end{align}
% where $\ell: [m] \times \R^m \> \R_+$ is a loss function, 
% and to construct a classifier from the minimizer $f^*$ using: $h^*(x) = \argmax{i \in [m]}\,f^*_i(x)$.
%which we will refer to as the \emph{standard} objective.
%
\textit{Balanced objective:} In applications where the classes are severely imbalanced, i.e.,
the class priors $\pi_y$ are non-uniform and significantly skewed,
one may wish to instead optimize a \emph{balanced} version of the above objective,
where we average over the conditional loss for each class. Notice that  the
conditional loss for class $y$ is weighted by the inverse of its prior:
\allowdisplaybreaks
\begin{align}
    L^\bal(f) %&= \max_{y \in [m]} \frac{1}{\pi_y}\E_{x, y}\left[ \ell(y, f(x)) \right], \nonumber\\
            % &= \max_{y \in [m]} \frac{1}{\pi_y}\E_{X,Y}\left[\mathbf{1}(Y=y)\ell(Y, f(X))\right]\nonumber\\
            &= \frac{1}{m}\sum_{y \in [m]} \E\left[\ell(y, f(X))\,|\, Y = y \right]  \nonumber \\
            % \nonumber\\
            % ~=~ \frac{1}{m}\sum_{y \in [m]}  \frac{1}{\P(Y=y)}\E_{X, Y}\left[\1(Y=y)\ell(y, f(X))\right]\nonumber\\ 
         &= \frac{1}{m}\sum_{y \in [m]}  \frac{1}{\pi_y}\E_X\left[ \eta_y(X)\,\ell(y, f(X)) \right].
    \label{eq:balanced}
\end{align}
%
\textit{Robust objective:} A more stringent objective would be to focus on the worst-performing class, and 
minimize a \emph{robust} version of
 \eqref{eq:std} that computes
the worst among the $m$
conditional losses:
% \[
% \wer(h) = \max_{y \in [m]} \E\left[\1(h(X) \ne y) \,|\, Y = y\right].
% \]
% As a proxy for this objective, we seek to learn
% a scoring function that minimizes the following \emph{robust} version of \eqref{eq:std}:
\begin{align}
    L^\rob(f) %&= \max_{y \in [m]} \frac{1}{\pi_y}\E_{x, y}\left[ \ell(y, f(x)) \right], \nonumber\\
            % &= \max_{y \in [m]} \frac{1}{\pi_y}\E_{X,Y}\left[\mathbf{1}(Y=y)\ell(Y, f(X))\right]\nonumber\\
            % &= \max_{y \in [m]} \E\left[\ell(y, f(X))\,|\, Y = y \right]\nonumber\\
            % &= \max_{y \in [m]} \frac{1}{\P(Y=y)}\E\left[\1(Y=y)\ell(y, f(X))\right]\nonumber\\
         &= \max_{y \in [m]} \frac{1}{\pi_y}\E\left[ \eta_y(X)\,\ell(y, f(X)) \right].
    \label{eq:robust}
\end{align}
%which we will refer to as the \emph{robust} objective.
In practice, focusing solely on either the average or the worst-case performance may not 
be an acceptable solution, and therefore, in this paper, we will additionally seek to characterize the
trade-off between the balanced and robust objectives. One way to achieve this trade-off is to minimize the robust objective, while constraining the balanced objective to
be within an acceptable range. This constrained optimization can be equivalently formulated as optimizing a convex combination of the balanced and robust objectives, for trade-off %parameter 
$\alpha \in [0,1]$:
\begin{align}
    L^\tdf(f)
         &= (1 - \alpha) L^\bal(f) + \alpha L^\rob(f).
    \label{eq:trade-off}
\end{align}
 A similar trade-off can also be specified between the standard and robust objectives.
% A compromise between optimizing the overall error and the worst-case error may be
% to minimize the average conditional error, often referred to as \emph{balanced error rate}:
% \[
% \ber(h) = \frac{1}{m}\sum_{y \in [m]} \E\left[\1(h(X) \ne y) \,|\, Y = y\right],
% \]
% a proxy for which is given by the following \emph{balanced} loss:
% \begin{align}
%     L^\bal(f) %&= \max_{y \in [m]} \frac{1}{\pi_y}\E_{x, y}\left[ \ell(y, f(x)) \right], \nonumber\\
%             % &= \max_{y \in [m]} \frac{1}{\pi_y}\E_{X,Y}\left[\mathbf{1}(Y=y)\ell(Y, f(X))\right]\nonumber\\
%             % &= \max_{y \in [m]} \E\left[\ell(y, f(X))\,|\, Y = y \right]\nonumber\\
%             % &= \max_{y \in [m]} \frac{1}{\P(Y=y)}\E\left[\1(Y=y)\ell(y, f(X))\right]\nonumber\\
%          &= \frac{1}{m}\sum_{y \in [m]} \frac{1}{\pi_y}\E\left[ \eta_y(X)\,\ell(y, f(X)) \right].
%     \label{eq:balanced}
% \end{align}
%
%\textbf{Bayes-optimal scorers.} 
To better understand the differences between the standard, balanced and robust objectives in \eqref{eq:std}--\eqref{eq:trade-off},
we look at the optimal scoring function for each given a cross-entropy loss:
\begin{theorem}[\textbf{Bayes-optimal scorers}]
\label{thm:bayes}
When $\ell$ is %a proper-composite loss 
%that is strictly convex in its second argument (e.g., 
the cross-entropy loss $\ell^\xent$, %in \eqref{eq:xent},
the minimizers of \eqref{eq:std}--\eqref{eq:trade-off} over all measurable functions $f: \X \> \R^m$ are given
by:

\begin{tabular}{ll}
    \textit{(i)} $L^\std(f)$: &  $\softmax_y( f^{*}(x) ) \,=\, \eta_y(x)$\\
    \textit{(ii)} $L^\bal(f)$: &  $\softmax_y( f^{*}(x) ) \,\propto\, \frac{1}{\pi_y}\eta_y(x)$ \\
    \textit{(iii)} $L^\rob(f)$: & $\softmax_y( f^{*}(x) ) \,\propto\, \frac{\lambda_y}{\pi_y}\eta_y(x)$ \\
    \textit{(iv)} $L^\tdf(f)$: & $\softmax_y( f^{*}(x) ) \,\propto\, \frac{(1 - \alpha)\frac{1}{m} + \alpha\lambda'_y}{\pi_y}\eta_y(x),$
\end{tabular}

% \begin{tabular}{ll}
%     \emph{(a)} $L^\std(f)$: &  $\softmax_y( f^{*}(x) ) \,=\, \eta_y(x)$ \\
%     \emph{(b)} $L^\bal(f)$: &  $\softmax_y( f^{*}(x) ) \,\propto\, \frac{1}{\pi_y}\eta_y(x)$ \\
%     \emph{(c)} $L^\rob(f)$: & $\softmax_y( f^{*}(x) ) \,\propto\, \frac{\lambda_y}{\pi_y}\eta_y(x)$\\
%     \emph{(d)} $L^\tdf(f)$: & $\softmax_y( f^{*}(x) ) \,\propto\, \frac{(1 - \alpha)\frac{1}{m} + \alpha\lambda'_y}{\pi_y}\eta_y(x),$
% \end{tabular}

% \begin{enumerate}
%     \item The minimizer of $L^\std(f)$ is given by $f^{*}_y(x) = \log(\eta_y(x))$;
%     \item The minimizer of $L^\bal(f)$ is given by $f^{*}_y(x) = C\log\left(\frac{1}{\pi_y}\eta_y(x)\right)$ for some constant $C > 0$;
%     \item The minimizer of $L^\rob(p)$ is given by $f^{*}_y(x) = C'\log\left(\frac{\lambda_y}{\pi_y}\eta(x)\right)$ for constant $C' > 0$.
% \end{enumerate}
for class-specific constants $\lambda, \lambda' \in \R^m_+$ that depend on distribution $D$.
\end{theorem}

All proofs are provided in Appendix \ref{app:proofs}.
%The first result follows from the property of the cross-entropy loss \citep{williamson2016composite}; see \citet{menon2020long} for a proof of the second result; we provide a proof for the third and fourth results in Appendix \ref{app:proof-bayes}. 
Interestingly, the optimal scorers for all four objectives involve a simple scaling of the conditional-class probabilities $\eta_y(x)$.

%%%%% DISTILLATION %%%%%


\section{Distillation for worst-class performance}
\label{sec:distillation}


% \subsection{Combining Distillation Objectives for Worst-class Performance}

% Of particular interest to us is the use of 
% knowledge distillation for achieving robust performance.
We adopt the common
practice of training both the teacher and student on the same dataset. Specifically, given a training sample $S = \{(x_1, y_1), \ldots, (x_n, y_n)\}$
drawn from $D$, we first train a teacher model $p^t: \X \> \Delta_m$, and use it to generate a student dataset
$S' = \{(x_1, p^t(x_1)), \ldots, (x_n, p^t(x_n))\}$
by replacing the original labels with the teacher's predictions. %Subsequently, we 
We then train a student scorer $f^s: \X \> \R^m$ using the re-labeled dataset,
and use it to construct the final classifier.

\textbf{Teacher and student objectives.}
%
% \subsection{Teacher and student objectives}
In a typical setting, both the teacher and student
are trained to optimize a version of the standard objective in \eqref{eq:std}, i.e.,
the teacher is trained to minimize the average loss against the original training labels,
and the student is trained to minimize an average loss against the teacher's predictions:
\begin{align}
        \text{Teacher: } & \hat{L}^\std(f^t) = \frac{1}{n}\sum_{i=1}^n \ell\left( y_i, f^t(x_i) \right); \hspace{10pt} \label{eq:standard-objectives} \\[-5pt]
        \text{Student: } & \hat{L}^\stdd(f^s) = \frac{1}{n}\sum_{i=1}^n \sum_{y=1}^m p_y^t(x_i)\, \ell\left(y , f^s(x_i) \right),\nonumber
    % \\[-10pt]
\end{align}
~\\[-8pt]
% \begin{tabular}{ll}
% Teacher: & 
% \begin{equation}
%     \displaystyle \hat{L}^\std(f^t) = \frac{1}{n}\sum_{i=1}^n \ell\left( y_i, f^t(x_i) \right)\hspace{35pt}
%     \label{eq:standard-objectives}
% \end{equation}
% \\
% Student: & $\displaystyle \hat{L}^\stdd(f^s) = \frac{1}{n}\sum_{i=1}^n \sum_{y=1}^m p_y^t(x_i)\, \ell\left(y , f(x_i) \right),$
% \end{tabular}
% \begin{align}
% \text{Teacher:}\hspace{5pt} \hat{L}^\std(f^t) &= \frac{1}{n}\sum_{i=1}^n \ell\left( y_i, f^t(x_i) \right)
% \label{eq:standard-objectives}
% \\
% \text{Student:}\hspace{5pt} \hat{L}^\stdd(f^s) &= \frac{1}{n}\sum_{i=1}^n \sum_{y=1}^m p_y^t(x_i)\, \ell\left(y , f(x_i) \right),
% \nonumber
% \end{align}
where $p^t(x) = \softmax(f^t(x))$. It is also common to have the student use a mixture of the teacher and one-hot labels. For concreteness, we consider a simpler distillation setup without this mixture, though extensions with this mixture would be straightforward to add.
This work takes a wider view and explores \textit{what combinations of student and teacher objectives} facilitate better worst-class performance for the student. Our experiments evaluate all \emph{nine} combinations of standard, balanced, and robust teacher objectives, paired with standard, balanced, and robust student objectives.

% One option is to change the teacher's objective to either the balanced or robust objectives in \eqref{eq:balanced}--\eqref{eq:robust}, with the goal of yielding better student performance on the tail or more difficult classes. Alternatively, we could alter the student objective to better focus on the more difficult classes.
Given the choice of teacher objective, the student will either optimize a distilled version of the balanced objective in \eqref{eq:balanced}:
\begin{align}
    {\hat{L}^\bald(f^s)}
    &= \frac{1}{m}\sum_{y \in [m]} \frac{1}{\hat{\pi}^t_y}\frac{1}{n}\sum_{i=1}^n  p_y^t(x_i)\, \ell\left( y , f^s(x_i) \right),
    \label{eq:balanced-distilled-empirical}
\end{align}
or a distilled version of the robust objective in \eqref{eq:robust}:
\begin{align}
    {\hat{L}^\robd(f^s)}
    &= \max_{y \in [m]} \frac{1}{\hat{\pi}^t_y}\frac{1}{n}\sum_{i=1}^n  p_y^t(x_i)\, \ell\left( y , f^s(x_i) \right).
    \label{eq:robust-distilled-empirical}
\end{align}
In practice, the teacher's predictions may have a different marginal distribution from the underlying class priors, particularly when temperature scaling is applied to the teacher's logits to soften the predicted probabilities \citep{narasimhan2021training}.  To address this, in both  \eqref{eq:balanced-distilled-empirical} and \eqref{eq:robust-distilled-empirical} we have replaced the class priors $\pi_y$ with the marginal distribution 
 $\hat{\pi}^t_y = \frac{1}{n}\sum_{i=1}^n p^t_y(x_i)$ from the teacher's predictions. 

% balanced
% or robust objectives in \eqref{eq:balanced}--\eqref{eq:robust}, so that
% it yields better performance on the tail or more difficult classes, and in turn 
% enables the student to also perform well on those classes.  Alternatively, 
% we could change the student objective to better focus on the more difficult classes.
% We could either have the student optimize 
% a distilled version of the balanced objective in \eqref{eq:balanced}:
% \begin{align}
%     {\hat{L}^\bald(f^s)}
%     &= \frac{1}{m}\sum_{y \in [m]} \frac{1}{\hat{\pi}^t_y}\frac{1}{n}\sum_{i=1}^n  p_y^t(x_i)\, \ell\left( y , f^s(x_i) \right),
%     \label{eq:balanced-distilled-empirical}
% \end{align}
% or a distilled version of the robust objective in \eqref{eq:robust}:
% \begin{align}
%     {\hat{L}^\robd(f^s)}
%     &= \max_{y \in [m]} \frac{1}{\hat{\pi}^t_y}\frac{1}{n}\sum_{i=1}^n  p_y^t(x_i)\, \ell\left( y , f^s(x_i) \right),
%     \label{eq:robust-distilled-empirical}
% \end{align}
% where we have replaced the class priors $\pi_y$ with the marginal distribution 
%  $\hat{\pi}^t_y = \frac{1}{n}\sum_{i=1}^n p^t_y(x_i)$ from the teacher's predictions. 
% This is because, in practice the teacher's predictions often have a different marginal distribution from the underlying
% class priors, particularly when temperature scaling is applied to the teacher's logits to soften the predicted probabilities \cite{narasimhan2021training}. 

In addition to exploring the combination of objectives that facilitates better worst-class performance for the student, we evaluate a more flexible approach -- having both the teacher and the student trade off between the balanced and robust objectives:
% resulting in the following optimization objectives:
\begin{align}\label{eq:trade-off-distilled}
        \text{Teacher: } & \hat{L}^\tdf(f^t) = (1-\alpha^t)\hat{L}^\bal(f^t) + \alpha^t\hat{L}^\rob(f^t) \\
        \text{Student: } & \hat{L}^\tdfd(f^s) = (1-\alpha^s)\hat{L}^\bald(f^s) + \alpha^s\hat{L}^\robd(f^s),\nonumber
\end{align}
% \begin{tabular}{ll}
% Teacher: & 
% \hspace{-5pt}
% $
%     \displaystyle\hat{L}^\tdf(f^t) = (1-\alpha^t)\hat{L}^\bal(f^t) + \alpha^t\hat{L}^\rob(f^t)~~~~
%     % \label{eq:trade-off-distilled}
% $ \\
% Student: & $\displaystyle \hat{L}^\tdfd(f^s) = (1-\alpha^s)\hat{L}^\bald(f^s) + \alpha^s\hat{L}^\robd(f^s),$
% \end{tabular}
where $\hat{L}^\bal(f^t)$ and $\hat{L}^\rob(f^t)$ are the respective empirical estimates of
\eqref{eq:balanced} and \eqref{eq:robust} from the training sample, and $\alpha^t, \alpha^s \in [0,1]$
are the respective trade-off parameters for the teacher and student. %Through these experiments,
We are thus able to evaluate the Pareto-frontier of balanced and worst-case accuracies, obtained from  different combinations of the teachers and students, and 
trained with different trade-off parameters.
%  Of particular interest to us would be the use of the robust objective $\hat{L}^\robd(f)$ for the student,
% which we consider to be the more direct approach. We would first like
% to understand which teachers $p^t$ work well with this objective.

%%%%% ALGORITHMS %%%%%
\subsection{Robust Distillation Algorithms}
\label{sec:algorithms}
The different objectives we consider -- standard, balanced and robust -- entail different loss objectives to ensure efficient optimization during training.
% Having outlined our main distillation proposal, 
% we now describe the algorithms we use to optimize the different teacher and student objectives.
% During the training phase, however, we employ differentiable surrogates for efficient optimization.
For example, while training
the standard teacher and student in \eqref{eq:standard-objectives},
we take $\ell$ to be the softmax cross-entropy loss,
and optimize it using SGD. For the balanced and robust models,
we employ the margin-based surrogates that we detail below, which have been shown to be more effective
in training over-parameterized networks \citep{cao2019learning, menon2020long, kini2021labelimbalanced}. 
Across all objectives, at evaluation we take the loss $\ell$ 
in the student and teacher objectives
to be the 0-1 loss. 



\textbf{Margin-based surrogate for balanced objective.}\
When the teacher or student model being trained is over-parameterized,
i.e., has sufficient capacity to correctly classify all examples in the training set,
the use of an outer weighting term 
in the objective
(such as the inverse class marginals in \eqref{eq:balanced-distilled-empirical})
can be ineffective. In other words, a model that yields zero training objective
would do so irrespective of what outer weights we choose. 
To remedy this problem,
we make use of the margin-based surrogate of \citet{menon2020long},
and incorporate the outer weights as margin terms within the loss. 
For
the balanced student objective in \eqref{eq:balanced-distilled-empirical}, this would look like:
\begin{align}
    {\widetilde{L}^\bald(f^s)}
    &= \frac{1}{n}\sum_{i=1}^n \cL^\mar\left(p^t(x_i), f^s(x_i); \1 / \hat{\pi}^t \right),
    \label{eq:margin-la}
\end{align}
~\\[-25pt]
%where $\hat{\pi}_y^t = \frac{1}{\frac{1}{n}\sum_{i=1}^n p^t_y(x_i)}$, and
\begin{align}
\lefteqn{
\text{where}~~~\cL^\mar\left(\bp, \boldf; \bc \right) =} \nonumber\\
&\frac{1}{m}\sum_{y \in [m]} p_y \log\bigg(1 + \sum_{j \ne y}\exp\left(\log(c_y / c_{j}) \,-\, (f_y - f_j) \right) \bigg),
\nonumber
\end{align}
for teacher probabilities $\bp \in \Delta_m$, student scores $\boldf \in \R^m$,
and per-class costs $\bc \in \R^m_+$. 
 For the balanced teacher, the margin-based objective would take a similar form, but %would instead
 with one-hot labels. 


We include a proof in Appendix \ref{app:calibration-mar} showing that
a scoring function that minimizes the surrogate objective
in \eqref{eq:margin-la}
also minimizes the
balanced objective in \eqref{eq:balanced-distilled-empirical}
(when $\ell$ is the cross-entropy loss, and the student is chosen from a sufficiently flexible function class). In practice, the margin term $\log(c_y / c_{j})$ encourages a larger margin of separation for classes $y$ for which the cost $c_y$ is relatively higher.
 
\begin{figure}[t]
\begin{algorithm}[H]
\caption{Distilled Margin-based DRO}% for Robust Student}
\label{algo:dro}
\begin{algorithmic}
\STATE \textbf{Inputs:} Teacher $p^t$, Student hypothesis class $\cF$, Training set $S$, Validation set $S^\val$, Step-size $\gamma \in \R_+$,
Number of iterations $K$, Loss $\ell$,
Initial student $f^0 \in \cF$, Initial multipliers $\blambda^0 \in \Delta_m$
\STATE Compute $\hat{\pi}^{t}_j = \frac{1}{n}\sum_{(x, y) \in S} p_j^t(x),~ \forall j \in [m]$
\STATE Compute $\hat{\pi}^{t,\val}_j = \frac{1}{n^\val}\sum_{(x, y) \in S^\val} p_j^t(x),~ \forall j \in [m]$
\STATE \textbf{For}~{$k = 0 $ to $K-1$}
\STATE ~~~$\tilde{\lambda}^{k+1}_j \,=\, \lambda^k_j\exp\big( \gamma \hat{R}_j \big), \forall j \in [m]$%\todo{[Hari] Change step size in appendix to $\gamma$}
~~~~~\text{where} $\hat{R}_j =$ $\frac{1}{n^\val}\frac{1}{\hat{\pi}^{t,\val}_j}\underset{(x, y) \in S^\val}{\sum} p_j^t(x)\, \ell( j , f^k(x) )$\\[-20pt]
% \STATE ~~~~~~~~~$\begin{cases} 
%             \displaystyle\frac{1}{|S^\val|}\frac{1}{\hat{\pi}^t_j}\sum_{(x, y) \in S^\val} p_j^t(x_i)\, \ell( j , f^k(x) ) & \text{:A}\\
%             \displaystyle\frac{1}{|S^\val|}\frac{1}{\hat{\pi}_j}\sum_{(x, y) \in S^\val} \1(y=j)\,\ell( j , f^k(x) ) & \text{:B}
%         \end{cases}$
%\frac{1}{n^\val\hat{\pi}^t_y}\sum_{ p_y^t(x_i)\, \ell\left( y , f^k(x_i) \right)
\STATE ~~~$\lambda^{k+1}_y \,=\, \frac{\tilde{\lambda}^{k+1}_y}{\sum_{j=1}^m \tilde{\lambda}^{k+1}_j}, \forall y$
\STATE ~~~$f^{k+1} \,\in\, \Argmin{f \in \cF}\, \frac{1}{n}\sum_{i=1}^n \cL^\mar\left(p^t(x_i), f(x_i); \frac{\lambda^{k+1}}{\hat{\pi}^t} \right)$
~~~~~~// Replaced with a few steps of SGD
\STATE \textbf{End For}
\STATE \textbf{Output:} $\bar{f}^{s}: x \mapsto \frac{1}{K}\sum_{k =1}^K f^k(x)$
\end{algorithmic}
\end{algorithm}
\end{figure} 
 
 
\textbf{Margin-based DRO for robust objective.}\
% For the robust teacher, we the 
% procedure
% proposed by \citet{narasimhan2021training}, 
% which we refer to as worst-class distributionally-robust optimization (DRO).
% For the robust student, we consider two variants of this procedure, we elaborate next.
Minimizing the robust objective with plain SGD can be difficult due to the presence of
the outer ``max'' over $m$ classes. The key difficulty is in computing reliable stochastic gradients for
the max objective, especially given a small batch size. The
standard approach is to instead use a (group) distributionally-robust optimization (DRO) procedure,
which comes in multiple flavors \citep{chen2017robust, Sagawa2020Distributionally, kini2021labelimbalanced}.
We employ  the margin-based variant of group DRO  \citep{narasimhan2021training} as it naturally extends the 
margin-based objective used in the balanced setting.

We illustrate below how this applies to the robust student objective in \eqref{eq:robust-distilled-empirical}. The procedure for the robust teacher is similar, but involves one-hot labels.
For a student hypothesis class $\cF$,
%Parameterizing student $f$ with $\theta \in \R^d$,
we
first re-write the minimization in \eqref{eq:robust-distilled-empirical} 
over $f \in \cF$ into an equivalent min-max optimization using per-class multipliers $\lambda \in \Delta_m$:
\[
\min_{f \in \cF}\max_{\lambda \in \Delta_m} \sum_{y \in [m]}\frac{\lambda_y}{\hat{\pi}^t_y}\frac{1}{n}\sum_{i=1}^n  p_y^t(x_i)\, \ell\left( y , f(x_i) \right),
\]
and then
 alternate between maximizing over $\lambda$ for fixed $f$ and minimizing over $f$ for fixed $\lambda$:
%  $\lambda^{k+1}_y \propto 
%     \lambda^k_y\exp\bigg( \gamma\frac{1}{n\hat{\pi}^t_y}\sum_{i=1}^n  p_y^t(x_i) \ell\left( y , f^k(x_i) \right)\bigg)$; \\ $f^{k+1} \in \Argmin{f \in \cF} \sum_{y \in [m]}\frac{\lambda^{k+1}_y}{n\hat{\pi}^t_y}\sum_{i=1}^n  p_y^t(x_i) \ell\left( y , f(x_i) \right)$,
\allowdisplaybreaks
\begin{align*}
    \lambda^{k+1}_y &\propto\, 
    \lambda^k_y\exp\bigg( \gamma\frac{1}{n\hat{\pi}^t_y}\sum_{i=1}^n  p_y^t(x_i)\, \ell\left( y , f^k(x_i) \right) \bigg), \forall y\\[-4pt]
    f^{k+1} &\in\, \Argmin{f \in \cF} \sum_{y \in [m]}\frac{\lambda^{k+1}_y}{n\hat{\pi}^t_y}\sum_{i=1}^n  p_y^t(x_i)\, \ell\left( y , f(x_i) \right),
\end{align*}
%
where $\gamma > 0$ is a step-size parameter. The updates on $\lambda$ 
implement exponentiated gradient (EG) ascent to maximize over the simplex \citep{shalev2011online}.

Following \citet{narasimhan2021training}, we make two modifications to the above updates 
when used to train over-parameterized networks that can fit the training set perfectly.
First, we perform the updates on $\lambda$ using a
small held-out validation set $S^\val \,=\, \{(x_1,y_1), \ldots, (x_{n^\val}, y_{n^\val})\}$, 
instead of the training set,
so that the $\lambda$s reflect how well the model generalizes out-of-sample.
Second, in keeping with the balanced objective, we modify the weighted objective
in the $f$-minimization step to include a margin-based surrogate. 
Algorithm \ref{algo:dro}
provides a summary of these steps
and returns a scorer
that averages over the $K$ iterates: $\bar{f}^s(x) = \frac{1}{K}\sum_{k=1}^K f^k(x)$.
While the averaging is needed for our theoretical analysis,
in practice, we find it sufficient to return the last scorer $f^{K}$. 
In Appendix \ref{app:dro-general-algo},
we describe how Algorithm \ref{algo:dro} can be easily modified to trade off
between the balanced and robust objectives, as shown in \eqref{eq:trade-off-distilled}.


%%%%
%%% Move to appendix
\if 0
\textbf{To distill the validation set or not?}\
The updates on $\lambda$
in Algorithm \ref{algo:dro} use a validation set labeled by the teacher. 
One could instead perform these updates with a curated validation set containing the original one-hot labels.
 Each of these choices presents different merits. The use of a teacher-labeled validation set is useful
 in many real world scenarios where labeled data is hard to obtain, % \citep{Khan2021,dixon2017,veale_2017}, 
 while unlabeled data abounds.
  In contrast, the use of one-hot validation labels, although more expensive to obtain,  may
make the student more immune to errors in the teacher's predictions, 
as the coefficients $\lambda$s are now based on an unbiased estimate of 
the student's performance on each class.

%
% The use of a validation set for updates on $\lambda$ brings up an interesting
% question: 
% \textit{Would the student be better-off having the validation set also labeled by the teacher?}
% This would be advantageous in many real world scenarios where labeled data is hard to obtain \citep{Khan2021,dixon2017,veale_2017}, while unlabeled data abounds.
% We compare this variant to performance given a curated validation set with the original one-hot labels. While more expensive to obtain a curated validation set, the use of one-hot validation labels may
% %enjoy the best of both worlds: it would benefit from the teacher's 
% % confidence  scores on the training examples, and also be able to make use of the one-hot labels 
% % in the validation set to get an unbiased estimate of its performance on each class. This way, the student
% make the student more immune to errors in the teacher's predictions, 
% as the coefficients $\lambda$s are now based on an unbiased estimate of 
% the student's performance on each class.
%
%  or retaining the original one-hot labels in the validation set, if available?
%  We will experiment with both these versions of DRO, andfor the latter
%
With a one-hot validation set, we update $\lambda$s as follows: %when using a one-hot validation set:
\begin{align}
    &\lambda^{k+1}_j \,=\, \lambda^k_j\exp\big( \eta \hat{R}_j \big), \forall j \in [m],\nonumber\\ 
    &\text{where}~
\hat{R}_j = \frac{1}{n^\val}\frac{1}{\hat{\pi}_j}\sum_{(x, y) \in S^\val} \1(y=j)\,\ell( j , f^k(x) ),
\label{eq:one-hot-val}
\end{align}
for estimates $\hat{\pi}_y \approx \pi_y$ of the original class priors.
% : (A) we
% can either use
% the teacher's predictions $p^t(x)$ to label the validation set, and (B) or retain
% the original one-hot labels in the validation set (along with estimates of the
% original class priors $\hat{\pi}_y \approx \pi_y$ for the normalizations). 
%
We analyze both the variants in our experiments, and will discuss robustness guarantees for each.
% Interestingly, \citet{narasimhan2021training} also work with one-hot validation labels
% in their distillation experiments, but do not provide a reasoning for their choice. 
%We revisit the use of
%  one-hot validation labels in the next section. %, and attempt to provide a theoretical justification for it.
\fi

%%%%% EXPERIMENTS %%%%%
\section{Experiments}
\label{sec:expts}
% We evaluate our proposal on both balanced and long-tail image datasets with varying numbers of classes.
% To empirically understand the role of the teacher in training robust students, we work with a standard distillation setup where the student receives soft labels from the teacher. 
% (we discuss other specialized forms of supervision, such as intermediate teacher embeddings, in the conclusion).
To empirically understand the interplay of teacher and student objectives, we explore the following questions:
% Empirically, we explore questions such as, 
\textit{what combinations of teacher and student objectives yield the highest worst-class accuracy?} 
\textit{Can some combinations improve worst-class accuracy without sacrificing average accuracy?}

% Such applications of robust optimization for satisfying fairness constraints across subgroups have for the most part focused on smaller models in a standard classification setting with experiments on small datasets such as COMPAS, Adult, and other UCI datasets, but the performance of these robust and constrained optimization techniques have not been explored in the setting of large datasets and models that necessitate distillation and compression. 
\textbf{Datasets.} 
% We evaluate each robust distillation objective across different image dataset benchmarks: 
We evaluate the proposed distillation protocols on benchmark image datasets: \textit{(i)} CIFAR-10, \textit{(ii)} CIFAR-100 \citep{Krizhevsky09learningmultiple}, \textit{(iii)} TinyImageNet  (a subset of ImageNet with 200 classes) \citep{le2015tiny}, and \textit{(iv)} ImageNet \citep{ILSVRC15}. We also include long-tailed versions of the first three datasets created by downsampling tail classes \citep{cui2019class}. For both the original and long-tailed versions of the datasets,
there are often biases in worst-class performance, possibly due to some classes being easier to learn \citep{lukasik2021teachers, hooker2020characterising}.
% Details on sampling the long tailed versions are given in Appendix \ref{app:experiment_details}.
%\textbf{Test/validation splits:} 
For all datasets, as done in prior work \citep{menon2020long, narasimhan2021training}, we randomly split the original default test set in half to create a validation set and test set, and use the same validation and test sets for the long-tailed training sets as for the original versions. % \citep{menon2020long, narasimhan2021training}. 
% For the long tailed training sets, this simulates a scenario where the training data follows a long tailed distribution, but the test distribution of interest still comes from the original balanced data distribution. 
%Thus, for the long tailed experiments, the average accuracy on test is measured on a dataset with balanced classes.


\textbf{Architectures.}  We evaluate our distillation protocols in both a self-distillation and compression setting. On all CIFAR datasets, all teachers were trained with the ResNet-56 architecture and students were trained with either ResNet-56 or ResNet-32. 
On TinyImageNet and ImageNet, teachers and students were trained with ResNet-18.
More details on these architectures can be found in \citet{lukasik2021teachers} and \citet{he2016deep} (see, e.g., Table 7 in \citet{lukasik2021teachers}). Self-distillation results are reported in the main paper (teacher/student share the same architecture), and we include results with compressed students in Appendix \ref{app:experiment_results}.

\textbf{Hyperparameters.} We apply temperature scaling to the teacher scores, i.e., compute $p^t(x) = \softmax(f^t(x) / \gamma)$, and  vary the temperature parameter $\gamma$  over a range of $\{1, 3, 5\}$.
A higher temperature produces a softer probability distribution over classes \citep{Hinton2015DistillingTK}. 
Unless otherwise specified, the temperature hyperparameters were chosen to achieve the highest worst-class accuracy on the validation set. 
We closely mimic the learning rate and regularization settings from prior work \citep{menon2020long,narasimhan2021training} (see Appendix \ref{app:experiment_details} for details). 
% In keeping with the theory, the regularization ensures that the losses are bounded (see Appendix \ref{app:experiment_details} for further details). 
% For all experiments in the main paper, we use teacher labels in the validation set as discussed at the end of Section \ref{sec:algorithms}, with additional results using onehot labels in the validation set in Appendix \ref{app:experiment_results}.

\textbf{Which objective combinations are most robust?}
We begin by exploring the effect of the interaction between student and teacher objectives on worst-class accuracy. In Table \ref{tab:combos_shortened}, we search over combinations of the standard, balanced, and robust objectives for the teacher ($L^{\std}, L^{\bal}, L^{\rob}$) and the student ($L^{\stdd}, L^{\bald}, L^{\robd}$) (note that on the original datasets, $L^{\std}$ is equivalent to $L^{\bal}$). For each combination, 
following prior conventions in long-tailed learning \citep{menon2020long,lukasik2021teachers},
we report the \textit{average accuracy} over all classes, and the \textit{worst-class accuracy}, or minimum per-class recall over all classes (see \eqref{eq:robust}). For datasets with a long tail or high number of classes, we also report the \textit{worst-$k$ accuracy}, which is the average of the worst $k$ per-class recalls. 

The first surprising finding in Table \ref{tab:combos_shortened} is that \textit{applying the robust objective twice isn't always best.} For all but one dataset, the $L^{\rob}/L^{\robd}$ teacher/student combination was outperformed by some other combination of either $L^{\std}/L^{\robd}$, $L^{\rob}/L^{\stdd}$, or $L^{\bal}/L^{\robd}$. Still, in the winning combination, at least one of the objectives was robust. This suggests that while the robust objective is effective for controlling worst-class accuracy, there may be some information loss in applying it twice to both the teacher and student. 

To understand this information loss on the teacher's side, we highlight a second surprising finding that \textit{the teacher with the best worst-class accuracy alone did not always produce the student with the best worst-class accuracy}. The robust teacher had the highest worst-class accuracy across all datasets, but for CIFAR-10 and all three long-tailed datasets, it was actually the $L^{\std}$ or $L^{\bal}$ teacher that produced the best robust student. This shows that there is more to a good teacher than just having good worst-class performance -- in fact, we show theoretically in Section \ref{sec:theory} that the property of the teacher that is most important for robust student performance is a form of \textit{calibration} of per-class scores.

% \input{tables/table_combos_shortened} %%%%%
\begin{table*}[!ht]
\caption{Worst-class accuracy comparisons for different combinations of teacher/student objectives. Worst-1 test accuracy is reported (worst-10 for TinyImageNet-LT) (best in \textbf{bold}), and average test accuracy is shown in parentheses. Mean accuracies are reported over repeat trainings (see extended table in Appendix for standard errors). Note that on the original datasets, $L^{\std}$ and $L^{\stdd}$ are equivalent to $L^{\bal}$ and $L^{\bald}$.
%. 
% We include results for the robust student using either a teacher labeled validation set (``teacher val''), or true one-hot class labels in the validation set (``one-hot val''), as outlined in Eq.~(\ref{eq:one-hot-val}). 
% Perhaps surprisingly, the teacher with the best worst-class accuracy alone (the ``None'' row) did not always produce the student with the best worst-class accuracy.
}
\label{tab:combos_shortened}
\begin{center}
\setlength{\tabcolsep}{6pt}
\begin{tabular}{p{0.2cm}cV{2.5}c|cV{2.5}c|cV{2.5}c|cV{2.5}}
\toprule
& & \multicolumn{2}{cV{2.5}}{\small{\textbf{CIFAR-10} Teacher Obj.}} & \multicolumn{2}{cV{2.5}}{\small{\textbf{CIFAR-100} Teacher Obj.}} & \multicolumn{2}{cV{2.5}}{\small{\textbf{TinyImageNet} Teacher Obj.}} \\
\multirow{4}{*}{\rotatebox{90}{\small{Student Obj.}}} & & $L^{\std}$ & $L^{\rob}$ & $L^{\std}$ & $L^{\rob}$ & $L^{\std}$ & $L^{\rob}$ \\
\cline{2-8}
& None & $86.48$ \tiny{($93.74$)}  & $90.09 $ \tiny{($92.67$)} &  $42.22 $ \tiny{($72.42$)}  & $43.42 $ \tiny{($68.81$)}   & $8.42$  \tiny{($56.79$)} & $11.87$ \tiny{($48.40$)} \\
\cline{2-8}
& $L^{\stdd}$ & $87.66 $ \tiny{($94.34$)}  & $90.12 $ \tiny{($94.07$)}  & $43.81$ \tiny{($74.61$)} & $\mathbf{45.33}$ \tiny{($73.67$)} & $6.32$ \tiny{($57.83 $)} & $10.53$ \tiny{($55.36$)}\\
\cline{2-8}
& $L^{\robd}$  & $\mathbf{90.94} $ \tiny{($92.54$)} & $85.14 $ \tiny{($89.58$)} & $42.96$ \tiny{($68.71$)} & $27.59$ \tiny{($54.79$)} & $9.98$ \tiny{($49.84$)} & $\mathbf{16.58}$ \tiny{($46.11$)} \\
% \cline{2-8}
% & $L^{\robd}$ & $89.37 $  & $87.32 $  & $40.36$   & $42.68$ & $16.27$ & $17.36$ \\
% &\tiny{(one-hot val)} & \tiny{($91.63$)} & \tiny{($91.16$)} & \tiny{($61.49$)} & \tiny{($62.03$)} & \tiny{($48.06$)} & \tiny{($43.92$)}  \\
\bottomrule
\end{tabular}

\setlength{\tabcolsep}{3pt}
\begin{tabular}{p{0.15cm}cV{2.5}c|c|cV{2.5}c|c|cV{2.5}c|c|cV{2.5}}
\toprule
& & \multicolumn{3}{cV{2.5}}{\small{\textbf{CIFAR-10-LT} Teacher Obj.}} & \multicolumn{3}{cV{2.5}}{\small{\textbf{CIFAR-100-LT} Teacher Obj.}} & \multicolumn{3}{cV{2.5}}{\small{\textbf{TinyImageNet-LT} Teacher Obj.}} \\
% & & \multicolumn{3}{c||}{Teacher Obj.} & \multicolumn{3}{c||}{Teacher Obj.} \\
\multirow{6}{*}{\rotatebox{90}{\small{Student Obj.}}}  & & $L^{\std}$ & $L^{\bal}$ & $L^{\rob}$ & $L^{\std}$ & $L^{\bal}$ & $L^{\rob}$ & $L^{\std}$ & $L^{\bal}$ & $L^{\rob}$\\
\cline{2-11}
& None & $57.26 $ \tiny{($76.27 $)} & $68.52 $ \tiny{($79.85 $)} & $74.80 $ \tiny{($80.29 $)}  &   $0.00 $ \tiny{($43.33 $)} & $3.75 $ \tiny{($47.55 $)} & $10.33 $ \tiny{($44.27 $)} &  $0.00 $ \tiny{($33.15 $)} & $2.11$ \tiny{($35.96 $)} & $4.92$ \tiny{($27.23 $)}\\
\cline{2-11}
& $L^{\stdd}$ & $36.67  $ \tiny{($69.50 $)} & $66.96  $ \tiny{($79.25 $)}  & $71.15  $ \tiny{($80.95 $)} & $0.00 $ \tiny{($43.86 $)}  & $2.39  $ \tiny{($48.95 $)} & $7.32  $ \tiny{($47.93 $)} &  $0.00 $ \tiny{($26.05 $)} & $0.00$ \tiny{($27.21 $)} & $1.87$ \tiny{($25.34 $)}\\
\cline{2-11}
& $L^{\bald}$ & $71.23  $ \tiny{($80.50 $)} & $70.52  $ \tiny{($81.12 $)} & $72.96  $ \tiny{($80.71 $)} &$4.39  $ \tiny{($50.40 $)} & $7.08  $ \tiny{($50.10$)} & $7.19  $ \tiny{($47.51$)} &$0.20 $ \tiny{($30.43 $)}  & $2.82$ \tiny{($39.41 $)} & $4.77$ \tiny{($38.41 $)} \\
\cline{2-11}
& $L^{\robd}$  & $63.85  $ \tiny{($76.81 $)}  &$\mathbf{75.56} $ \tiny{($80.81$)} & $69.21  $  \tiny{($76.72 $)} & $9.05  $ \tiny{($33.75 $)}  & $\mathbf{12.52}  $ \tiny{($34.05 $)} & $10.32  $ \tiny{($36.83 $)} & $0.00 $ \tiny{($22.66 $)} & $\mathbf{4.93}$ \tiny{($35.43 $)} & $3.32 $  \tiny{($25.11 $)} \\
\bottomrule
\end{tabular}
\end{center}
\vskip -0.2in
\end{table*}

% Perhaps surprisingly, it did not always benefit the robust student ($L^{\robd}$) to utilize the true one-hot labels in the validation set. Instead, training the robust student with teacher labels on the validation set was sufficient to achieve the best worst-class performance. This is promising from a data efficiency standpoint, since it can be expensive to build up a labeled dataset for validation, especially if the training data is long-tailed.

\textbf{Trading off accuracy and robustness.} Table \ref{tab:combos_shortened} focuses on worst-class accuracy, but practitioners often must consider the trade-off between average accuracy and worst-class accuracy when deploying any model. To address this, we introduced the $L^{\tdf}/L^{\tdfd}$ objectives for the teacher/student with trade-off parameters $\alpha^t, \alpha^s$. Figure \ref{fig:alphas_cifar10} plots average and worst-class accuracies for a full spread of $\alpha^t, \alpha^s$ parameters. First, we note that lower $\alpha^s$ usually leads to higher average accuracy (this is not always the case for $\alpha^t$, which we show in more detail in Appendix \ref{app:experiment_results}). Figure \ref{fig:alphas_cifar10} also shows that combinations of $\alpha^t, \alpha^s$ yield a roughly concave Pareto frontier of solutions with different average and worst-class accuracies to choose from.
Selecting the best combination of trade-off parameters $\alpha^t, \alpha^s$ in practice depends on domain-specific decisions regarding the importance of worst-class vs. average accuracy. Any selection criteria based on some trade-off of worst-class vs. average accuracy can be applied over the validation set to select $\alpha^t, \alpha^s$ as hyperparameters. We demonstrate one such set of selection criteria here: in Table \ref{tab:baselines}, we select $\alpha^t, \alpha^s$ to maximize worst-class accuracy on validation, subject to having at least as high average accuracy as standard distillation (within error margin) on the validation set. Other candidate criteria include weighted sums of worst-class accuracy and average accuracy, or constrained optimization criteria from \cite{cotter2019optimization}.

% \input{figures/alphas_cifar10}
\begin{figure}[t]
    \centering
        \includegraphics[width=0.9\columnwidth]{plot_students_only_cf10_arch56_teacher} 
    \caption{All $\alpha^t, \alpha^s$ combinations for CIFAR-10 on test. The black line traces out the Pareto frontier. Average accuracy is roughly determined by $\alpha^s$. The labeled point corresponds to the ``best'' combination selected in Table \ref{tab:baselines} based on validation criteria, but other domain-specific trade-off criteria could yield any of these other points.}
    \label{fig:alphas_cifar10}
\end{figure}

\textbf{Comparison to baselines.}
Finally, we contextualize the performance of the proposed $L^\tdf/L^{\tdfd}$ objectives and the training protocol in Algorithm \ref{algo:dro} by comparing to several state-of-the-art methods. In addition to \textit{standard distillation} (training the teacher with $L^{\std}$ and the student with $L^{\stdd}$), we compare the proposed objective combinations with two recent works focusing on robust distillation \citep{lukasik2021teachers, narasimhan2021training}, both of which use a standard objective for the teacher and modify only the student objective for worst-class performance. 
%Compared to these works, our experiments illustrate the effects of modifying the teacher's learning objective. 
From \cite{narasimhan2021training}, we consider the following two methods: \textit{(i)} \textit{Post-shifting:} this non-distillation approach directly constructs a new scoring model by making post-hoc adjustments to the teacher, so as to maximize the robust accuracy on the validation sample. \textit{(ii)} \textit{Robust student:}
this approach trains a student using $L^{\robd}$ from a standard teacher. 
From \cite{lukasik2021teachers}, we compare to their two proposed \textit{AdaMargin} and \textit{AdaAlpha} methods.
% This mitigation technique proposes a class specific weight that mixes a weighted combination of teacher
% prediction and the true one-hot labels,
% with the weights chosen using the validation set. 
%When the weight is high, the student leans more heavily unto the one-hot training labels themselves. 
Both methods are motivated by the observation that the margin defined for each class $y$ by $\gamma_{\rm avg}( y, p^{\rm t}( x ) ) = p^{\rm t}_y( x ) - \frac{1}{m - 1} \sum_{y' \neq y} p^{\rm t}_{y'}( x )$ 
correlates with whether distillation improves over one-hot training \citep{lukasik2021teachers}. 
AdaMargin uses that quantity as a margin in the distillation loss, whereas AdaAlpha uses it to adaptively mix between the one-hot and distillation losses.
Additionally, for long-tailed datasets, we include a comparison to \citet{menon2020long} which we refer to as \textit{balanced student}, where the student is distilled with a balanced objective $L^{\bald}$ from a standard teacher.
Finally, we also include a comparison to 
the
\textit{Group DRO} method for
subgroup robustness
without distillation  (Algorithm 1 in \citet{Sagawa2020Distributionally}). This method differs from our DRO procedure in that it does not apply a margin-based loss. %, and do not apply an additional validation set to compute per-class accuracies during optimization.

% \input{tables/table_baselines}
%TODO: try to combine these into one table
\begin{table*}[!ht]
    \centering
    \caption{Comparison to baselines for the selected $\alpha^t, \alpha^s$ combination on test data.}\label{tab:baselines}
    \begin{tabular}{lcc|cc|cc}
    \toprule
    & \multicolumn{2}{c}{\textbf{CIFAR-10}} & \multicolumn{2}{c}{\textbf{CIFAR-100}} & \multicolumn{2}{c}{\textbf{TinyImageNet}} \\
    Method & Average acc.  & Worst-1 acc. & Average acc.  & Worst-1 acc. & Average acc.  & Worst-1 acc. \\
    \midrule 
    Selected $\alpha^t, \alpha^s$ combo & $\mathbf{94.28} \pm 0.06$ & $\mathbf{90.11}\pm 0.23$   & $73.22 \pm 0.26$ & $\mathbf{48.40} \pm 1.47$ & $\mathbf{58.09} \pm 0.13$ & $9.47 \pm 1.76$  \\
    Standard distillation & $\mathbf{94.34} \pm 0.07$ & $87.66 \pm 0.40$ & $\mathbf{74.61} \pm 0.15$ & $43.81 \pm 0.58$ & $57.83 \pm 0.13$ & $6.32 \pm 2.31$ \\ 
    Post shift \scriptsize{\text{[NM'21]}} & $92.16 \pm 0.18$ & $88.60 \pm 0.35$ & $61.22 \pm 0.36$ & $38.19 \pm 0.40$  & $43.02 \pm 0.79$ & $14.39 \pm 1.13$\\ 
    Robust student \scriptsize{\text{[NM'21]}} & $92.72 \pm 0.05$ & $89.90 \pm 0.21$ & $68.45 \pm 0.13$ & $43.62 \pm 1.27$ & $48.06 \pm 0.24$ & $\mathbf{16.27} \pm 0.43$\\ 
    AdaMargin \scriptsize{\text{[LBMK'22]}} & $93.69 \pm 0.06$ & $88.42 \pm 0.36$ & $73.58 \pm 0.11$ & $43.91 \pm 1.11$ & $52.45 \pm 0.08$ & $\mathbf{15.41} \pm 0.71$\\ 
    AdaAlpha \scriptsize{\text{[LBMK'22]}} & $\mathbf{94.31} \pm 0.01$ & $88.33 \pm 0.14$ & $74.15 \pm 0.08$ & $45.46 \pm 0.67$ & $57.22 \pm 0.08$ & $7.62 \pm 2.17$\\ 
    Group DRO \scriptsize{\text{[SKHL'20]}} & $92.34 \pm 0.07$ & $89.32 \pm 0.21$ & $65.18 \pm 0.08$ & $43.89 \pm 1.12$ & $48.78 \pm 0.21$ & $11.38 \pm 1.79$\\ 
    \bottomrule
    \end{tabular}

    \begin{tabular}{lcc|cc|cc}
    \toprule
    & \multicolumn{2}{c}{\textbf{CIFAR-10-LT}} & \multicolumn{2}{c}{\textbf{CIFAR-100-LT}} &  \multicolumn{2}{c}{\textbf{TinyImageNet-LT}}\\
    Method & Average acc.  & Worst-1 acc. & Average acc.  & Worst-1 acc. & Average acc.  & Worst-10 acc. \\
    \midrule 
    Selected $\alpha^t, \alpha^s$ combo & $79.02 \pm 0.08$ & $\mathbf{75.43} \pm 0.39$  & $43.94 \pm 0.16$ & $\mathbf{14.52} \pm 0.68$  & $26.91 \pm 0.16$ & $\mathbf{6.04} \pm 0.25$  \\
    Standard distillation & $77.39 \pm 0.10$ & $60.12 \pm 0.56$ & $46.01 \pm 0.16$ & $0.00 \pm 0.00$ & $26.05 \pm 0.18$ & $0.00 \pm 0.00$ \\ 
    Post shift \scriptsize{\text{[NM'21]}} & $78.28 \pm 0.05$ & $74.33 \pm 0.09$ & $29.88 \pm 0.61$ & $10.01 \pm 0.72$ & $21.32 \pm 0.49$ & $2.58 \pm 0.42$ \\ 
    Robust student \scriptsize{\text{[NM'21]}} & $80.05 \pm 0.13$ & $74.91 \pm 0.24$ & $30.79 \pm 0.18$ & $12.28 \pm 0.46$ & $21.59 \pm 0.19$ & $1.55 \pm 0.37$ \\ 
    Bal. student \scriptsize{\text{[MJRJVK'21]}} & $\mathbf{81.36} \pm 0.14$ & $71.60 \pm 0.38$ & $\mathbf{50.40} \pm 0.12$ & $4.39 \pm 0.66$ & $\mathbf{30.43} \pm 0.06$ & $0.20 \pm 0.18$ \\ 
    AdaMargin \scriptsize{\text{[LBMK'22]}} & $72.69 \pm 0.24$ & $47.52 \pm 0.95$ & $31.26 \pm 0.21$ & $0.00 \pm 0.00$ & $4.41 \pm 0.09$ & $0.00 \pm 0.00$ \\ 
    AdaAlpha \scriptsize{\text{[LBMK'22]}} & $70.83 \pm 0.28$ & $43.64 \pm 1.09$ & $42.52 \pm 0.08$ & $0.00 \pm 0.00$ & $27.95 \pm 0.14$ & $0.00 \pm 0.00$ \\ 
    Group DRO \scriptsize{\text{[SKHL'20]}} & $74.39 \pm 0.17$ & $59.93 \pm 0.59$ & $40.47 \pm 0.17$ & $0.19 \pm 0.17$ & $27.78 \pm 0.13$ & $0.00 \pm 0.00$ \\ 
    \bottomrule
    \end{tabular}
\end{table*}

Table \ref{tab:baselines} shows the average and worst-class accuracies on test for these baselines compared to the combination of $\alpha^t, \alpha^s$ selected using the selection criteria previously described. The selection criteria for $\alpha^t, \alpha^s$ are applied over the validation set, and thus do not directly translate to test performance: the selected $\alpha^t, \alpha^s$ combination sometimes has lower average test accuracy than standard distillation. Still, overall, the selected $\alpha^t, \alpha^s$ combination is Pareto efficient compared to all other baselines (dominant in at least one of average accuracy or worst-$k$ accuracy). Among the rest of the different $\alpha^t, \alpha^s$ candidates (as in Figure \ref{fig:alphas_cifar10}), there actually exist combinations that Pareto dominate all baselines in test performance (additional plots in Appendix \ref{app:experiment_results}). While we only show results from our simple example selection criteria in Table \ref{tab:baselines}, this suggests that there is room for alternative selection criteria to yield even better results. The challenge, as with all hyperparameter selection, is that selection on the validation set comes with a generalization gap between validation and test.

%%%%% THEORY %%%%%

\section{Theoretical Analysis}
\label{sec:theory}
Complementing our empirical findings, our theoretical analysis explores 
what constitutes a good teacher and how it
aids a student in achieving robustness. 
To simplify our exposition, we present our theoretical analysis 
for a student trained using Algorithm \ref{algo:dro} to yield good worst-class performance. Our results easily extend to the case where the student seeks to trade off between average and worst-case performance. 

\textbf{What constitutes a good teacher?}
%\label{sec:theory_good_teacher}
We first characterize the properties of a good teacher 
when the student's goal is to minimize the robust population objective $L^\rob(f^s)$ in \eqref{eq:robust}. In particular, does the student's ability to perform well on this worst-case objective depend on
the teacher also performing well on the same objective?
%
Given scores from a teacher $p^t$, the student minimizes the robust distillation objective $\hat{L}^\robd(f^s)$ in \eqref{eq:robust-distilled-empirical}, and uses this as a proxy for the actual objective $L^\rob(f^s)$ we care about.
% , the
% student minimizes the distilled objective $\hat{L}^\robd(f^s)$ in \eqref{eq:robust-distilled-empirical}
Intuitively, an \emph{ideal} teacher 
would then be one that provides a good proxy for the student, and ensures that
the difference
$
|\hat{L}^\robd(f^s) - L^\rob(f^s)|
$
is as small as possible. Below, we provide a simple bound on this difference:

\begin{theorem}
\label{thm:good-teacher}
% Define population version of the student's objective in \eqref{eq:robust-distilled-empirical} as:
% \begin{align*}
%     {L}^\robd(f^s) &= \max_{y \in [m]} \frac{1}{\pi^t_y}\E_x\left[ p_y^t(x)\, \ell\left( y , f^s(x) \right)\right].\hspace{-5pt}
%     \label{eq:robust-distilled-population}
% \end{align*}
Suppose $\ell(y, f^s(x)) \leq B, \forall x \in \X, y \in [m]$, for some $B > 0$. 
Let $\pi^t_y = \E_x\left[ p_y^t(x) \right]$,
and let the following denote the per-class expected and empirical student losses respectively: 
% $\phi_y(f^s) = \textstyle\frac{1}{\pi^t_y}\E_x\left[ p_y^t(x)\, \ell\left( y , f^s(x) \right)\right];$ \\
% $\hat{\phi}_y(f^s) = \textstyle\frac{1}{\hat{\pi}^t_y}\frac{1}{n}\sum_{i=1}^n  p_y^t(x_i)\,\ell\left( y , f^s(x_i) \right)$. \\
\begin{align*}
&\phi_y(f^s) = \textstyle\frac{1}{\pi^t_y}\E_x\left[ p_y^t(x)\, \ell\left( y , f^s(x) \right)\right]; \\ 
&\hat{\phi}_y(f^s) = \textstyle\frac{1}{\hat{\pi}^t_y}\frac{1}{n}\sum_{i=1}^n  p_y^t(x_i)\,\ell\left( y , f^s(x_i) \right).
\end{align*}
% \begin{align*}
% \phi_y(f^s) &= \textstyle\frac{1}{\pi^t_y}\E_x\left[ p_y^t(x)\, \ell\left( y , f^s(x) \right)\right]\\[-2pt]
% \hat{\phi}_y(f^s) &= \textstyle\frac{1}{\hat{\pi}^t_y}\frac{1}{n}\sum_{i=1}^n  p_y^t(x_i)\,\ell\left( y , f^s(x_i) \right).
% \end{align*}
Then for teacher $p^t$ and student $f^s$:
% for any $\delta \in (0,1)$,
% with probability $\leq 1 - \delta$ (over draw of $S$), 
% {\color{green}TBD: Derive $\Delta(n,\delta)$}
\begin{align*}
|\hat{L}^\robd(f^s) - L^\rob(f^s)| \leq &\underbrace{
    B\max_{y \in [m]} \E_x\left[
        \left| \frac{p^t_y(x)}{\pi^t_y} - \frac{\eta_y(x)}{\pi_y} \right|\right]}_{\text{Calibration error}} \\
        &+ \underbrace{
\max_{y \in [m]} \big|\phi_y(f^s) - \hat{\phi}_y(f^s)\big|}_{\text{Estimation error}}.
\end{align*}
% \begin{align*}
% |\hat{L}^\robd(f^s) - L^\rob(f^s)|
% &\leq
% \underbrace{
%     B\max_{y \in [m]} \E_x\left[
%         \left| \frac{p^t_y(x)}{\pi^t_y} - \frac{\eta_y(x)}{\pi_y} \right|\right]}_{\text{Calibration error}} \\
% & + \underbrace{
% \max_{y \in [m]} \big|\phi_y(f^s) - \hat{\phi}_y(f^s)\big|}_{\text{Estimation error}}.
% \end{align*}
% where $\pi^t_y = \E_x\left[p^t_y(x)\right]$.
\end{theorem}
%
% See Appendix \ref{app:proof-good-teacher} for the proof.

The \emph{calibration error} %in the above bound 
captures 
how well the teacher's predictions mimic the conditional-class distribution $\eta(x) \in \Delta_m$, 
up to per-class normalizations $\pi$.
This suggests that even if $p^t$ does not 
achieve good worst-class performance, as long as
it is \textit{well-calibrated} within each class (as measured by the calibration error), 
it will serve as a good teacher.

The \emph{estimation error} captures how well the teacher aids in the student's out-of-sample generalization. 
The prior work by 
\citet{menon2021statistical} studies this question in detail for 
the standard student objective, and 
provides a bound that depends on the variance induced by the teacher's predictions on the
student's objective: the lower the variance, the better the student's generalization.
In Appendix \ref{app:student-gen-bound}, we carry out a similar analysis  with the %per-class loss terms 
estimation error
in the theorem.

 
% On the other hand, 
% it is not
% immediate if a teacher that optimizes the balanced or robust objectives
% will yield low approximation errors (as it is not clear what
% the normalization terms will evaluate to for them). %In our experiments, we probe further into the efficacies of these models as teachers.
% This is a  less stringent requirement on the teacher 
% compared to the condition in \citet{menon2021statistical}, where 
% the goal is to optimize the standard objective, and the teacher
% is evaluated based on how well it approximates the exact conditional probabilities $\eta(x)$.
% Consequently, any teacher of the form $p_y^t(x) = \alpha_y\eta_y(x)$, 
% for scaling factor $\alpha \in \R_+^m$, would yield an approximation error of zero. Note that
% this is indeed the case with each of the Bayes-optimal scorers in Theorem \ref{thm:bayes}. 
% This suggests that one can train a teacher by optimizing any of the objectives in \eqref{eq:std}--\eqref{eq:robust},
% apply a softmax transformation to them to predict probabilities, 
% and have the student optimize the distilled robust objective in \eqref{eq:robust-distilled-empirical} against
% the teacher's probabilities.\todo{This claim is not entirely correct when the teacher scores are normalized to sum to 1. TODO[Hari]: rephrase this!}


% \ref{thm:good-teacher} (more details in Appendix \ref{app:student-gen-bound}).
%In Appendix \ref{app:gen-bound}, we extend results from \citet{menon2021statistical} to derive a uniform convergence bound on  $|\hat{L}^\robd(f) - L^\robd(f)|$ 
% in terms of the variance of the loss induced by the teacher's predictions. 
% The lower the variance the faster is the convergence of the empirical risk to the population risk.

% \begin{theorem}
% \textup{\color{green}TBD}
% \todo{Add generalization bound on $|\hat{L}^\robd(f) - L^\robd(f)|$}
% \label{thm:gen-bound}
% \end{theorem}


\textbf{Calibration and worst-case error.} We illustrate how, perhaps counterintuitively, a teacher with low worst-class accuracy might still have scores $p^t$ that are well calibrated to match the true conditional-class distributions $\eta$. For this, we use a hypothetical ``image classification'' task with labels $y \in \{\text{cat}, \text{panda}, \text{other}\}$, and a single one-dimensional feature $x \in [0,1]$ representing the fraction of black pixels in the image, uniformly distributed over the interval. Suppose the solid lines in Figure \ref{fig:conditional_preds} below give the conditional-class distributions $\eta_y(x)$ for the cat and panda classes (pandas are rarer than cats in the dataset, with $\pi_{\text{cat}} = \frac{1}{2}$ and $\pi_{\text{panda}} = \frac{1}{4}$). 
Suppose the dashed lines in Figure \ref{fig:conditional_preds} also give hypothetical teacher model scores $p^t_y(x)$, where $p^t_{\text{cat}}(x) = 2\eta_{\text{cat}}(x)$, and $p^t_{\text{panda}}(x) = \frac{1}{2}\eta_{\text{panda}}(x)$ (these arbitrary teacher scores do not necessarily correspond to softmax outputs from a neural network). This teacher model always outputs a higher score for the cat label than the panda label. However, the model still satisfies the necessary calibration property: $\frac{p^t_y(x)}{\E_x[p^t_y(x)]} = \frac{\eta_y(x)}{\pi_y}\;\; \text{ for } y \in \{\text{cat}, \text{panda}\},$
% \begin{equation*}
%     \frac{p^t_y(x)}{E_x[p^t_y(x)]} = \frac{\eta_y(x)}{\pi_y}\;\; \text{ for } y \in \{\text{cat}, \text{panda}\},
% \end{equation*}
despite the fact that the argmax predictions from this model have zero recall for the panda class. % if the predicted class were the class with the max score. 
This illustrates that the important property of the teacher's scores is how well they mimic the \textit{shape} of the conditional-class distributions, and not necessarily their worst-class predictive accuracy.

\begin{figure}[!ht]
    \centering
    \includegraphics[width=0.9\columnwidth]{calibration_example}
    \caption{Hypothetical conditional-class distributions $\eta_y(x)$ (solid lines) and hypothetical teacher model scores $p^t_y(x)$ (dashed lines) for $y \in \{\text{cat}, \text{panda}\}$.}
    \label{fig:conditional_preds}
\end{figure}

\textbf{Relation to Bayes-optimal scorers.} 
When the teacher outputs the conditional-class probabilities, 
i.e.\ $p^t(x) = \eta(x)$, the  calibration error is trivially zero (recall that the normalization term
$\pi^t_y = \pi_y$ in this case). 
Theorem \ref{thm:bayes} shows that the Bayes-optimal scorer for the standard cross-entropy loss achieves this; however, in practice with finite data and model class limitations, a teacher trained with the cross-entropy loss is often far from approximating $\eta(x)$ exactly.
% We  know from Theorem \ref{thm:bayes} that this would be the case with a teacher trained to optimize the standard cross-entropy objective (provided we use an unrestricted model class). 
% In practice, however, we do not expect the teacher to approximate $\eta(x)$ very well,
% this opens the door for training the teacher with the other objectives described in Section \ref{sec:prelims}, each
% of which encourage the teacher to approximate a scaled (normalized) version of $\eta(x)$. 
% While Theorem \ref{thm:bayes} says that the Bayes optimal scorer satisfies $p^t(x) = \eta(x)$ for the standard cross-entropy objective with an unrestricted model class, 
In practice, it remains an open question what methodology might produce a teacher that most closely mimics these conditional-class distribution shapes for all classes in finite samples. For example, while the standard cross-entropy objective might lead to well calibrated model scores for a majority class, the scores may not match for rare classes.
%after running SGD for a finite number of steps.
Our experiments explored training with different losses from Section \ref{sec:prelims} that encourage the teacher to approximate scaled versions of $\eta(x)$; however, future exploration of other practical training possibilities would be interesting to compare.
% Our experiments explore training the teacher with different losses described in Section \ref{sec:prelims} that trade off between standard and robust losses, each of which encourages the teacher to approximate a scaled (normalized) version of $\eta(x)$. However, future exploration of other practical training possibilities would be interesting to compare.
% different objectives that trade off between standard and robust losses; however, future exploration of other practical training possibilities would be interesting to compare.

\textbf{Robustness guarantee for  student.}
% \todo{To be done}
%Having chosen a teacher model $p^t$, we would next like to understand the  form of the optimal student for $\hat{L}^\robd(f)$.
% (in the limit of infinite training samples).
%We next seek to understand if the student can match or outperform the teacher's worst-class performance.
We next provide robustness guarantees
for the student output by Algorithm \ref{algo:dro} 
in terms of the calibration and estimation errors
described above.
We do so for a fixed teacher $p^t$, and a \emph{self-distillation} setup where 
the student is chosen from the same function class $\cF$ as the teacher,
and can thus exactly mimic the teacher's predictions.
% Under this setup, 
% we provide robustness gurarantees
% for the student output by Algorithm \ref{algo:dro} 
% in terms of the approximation and estimation errors
% described above. %, and the convergence rate
%of the exponentiated gradient (EG) ascent algorithm. 

%We assume a fixed teacher $p^t$ and focus 
% on a \emph{self-distillation} setting where both
% the student and teacher are chosen from the same hypothesis class $\cF$. 
\begin{proposition}
\label{prop:student-form}
% For a fixed teacher $p^t: \X \> \Delta_m$, when $\ell$ is the cross-entropy loss,
% the minimizer of $\hat{L}^\robd(f)$ in \eqref{eq:robust-distilled-empirical} over all 
% measurable functions $f: \X \> \R^m$ is of the form:
% $$
% \softmax_y(f^*(x')) \propto \gamma_y p^t_y(x'),~~\forall (x', y') \in S,
% $$
%  for some class-specific constants $\gamma_1, \ldots, \gamma_m \in \R$.
Suppose $p^t \in \cF$ and $\cF$ is closed under linear transformations. Let $\bar{\lambda}_y = (\prod_{k=1}^K {\lambda_y^k}/{\pi^t_y})^{1/K}, \forall y$.
%, \forall y$. 
Then the scoring function $\bar{f}^s(x) = \frac{1}{K} \sum_{k=1}^K f^{k}(x)$ output by Algorithm~\ref{algo:dro} is of the form: $\softmax_j(\bar{f}^s(x)) \propto \bar{\lambda}_j p_j^t(x),~\forall j \in [m],\, \forall (x, y) \in S.$
%is of the form:
    % \[
    %     \softmax_j(\bar{f}^s(x)) \propto \bar{\lambda}_j p_j^t(x),~\forall j \in [m],\, \forall (x, y) \in S.
    % \]
% where $\bar{\lambda}_j = \left(\prod_{k=1}^K {\lambda_j^k}/{\pi^t_j}\right)^{1/K}$. 
\end{proposition}
%
\begin{theorem}
\label{thm:dro}
Suppose $p^t \in \cF$ and $\cF$ is closed under linear transformations. 
Suppose
$\ell$ is the cross-entropy loss $\ell^{\xent}$,
$\ell(y, z) \leq B$
and $\max_{y \in [m]}\frac{1}{\pi^t_y} \leq Z$,
for some $B, Z > 0$. 
Furthermore, suppose for any $\delta \in (0,1)$, the following bound holds on the
estimation error in Theorem \ref{thm:good-teacher}:
with probability at least $1 - \delta$ (over draw of $S \sim D^n$),
$\forall f \in \cF$, $\textstyle
\max_{y \in [m]} \big|\phi_y(f) - \hat{\phi}_y(f)\big| \leq \Delta(n, \delta),$
% \[
% \textstyle
% \max_{y \in [m]} \big|\phi_y(f) - \hat{\phi}_y(f)\big| \leq \Delta(n, \delta),
% \]
for some $\Delta(n, \delta) \in \R_+$ that
is increasing in $1/\delta$, and goes to 0 as $n \> \infty$. 
Then when the step size $\gamma = \frac{1}{2BZ}\sqrt{\frac{\log(m)}{{K}}}$
and $n^\val \geq 8Z\log(2m/\delta)$, we have that with
probability at least $1-\delta$ (over draw of $S \sim D^n$ and $S^\val \sim D^{n^\val}$),
\begin{align*}
L^\rob(\bar{f}^s) \,\leq\,
    \min_{f \in \cF}L^\rob(f)
+
        \underbrace{2\Delta(n^\val, \delta/2) \,+\,  2\Delta(n, \delta/2)}_{\text{Estimation error}} \\
+
\underbrace{
2B\max_{y \in [m]} \E_x\left[
        \left| \frac{p^t_y(x)}{\pi^t_y} \,-\, \frac{\eta_y(x)}{\pi_y} \right|\right]}_{\text{Calibration error}}
        \,+\,
        \underbrace{4BZ\sqrt{\frac{\log(m)}{{K}}}}_{\text{EG convergence}}.
\end{align*}
\end{theorem}


% The proof is provided in Appendix \ref{app:convergence-dro}, and 
% builds on the
% convergence guarantee for
% exponentiated gradient (EG) ascent \citep{shalev2011online}
% and calibration properties of the margin loss \citep{narasimhan2021training}. %Therefore when the student optimizes over e.g.\ an over-parameterized function class, %and has access to unlimited training data, 

% Theorem \ref{thm:dro}
Proposition \ref{prop:student-form} shows the student 
not only learns to mimic the teacher on the training set, but improves upon it by making per-class adjustments to its predictions.
 Theorem \ref{thm:dro} shows that these adjustments are chosen to close in on the
gap to the optimal robust scorer in $\cF$.
% The form of the student
%  suggests that it can not only match the teacher's performance,
% but can potentially improve upon it by making %scaling 
% adjustments to its scores. 
However, the student's convergence to the optimal scorer in $\cF$
would still be limited by the teacher's calibration error:
even when the sample sizes and number of iterations $n, n^\val, K \> \infty$,
the student's optimality gap may still be non-zero when the teacher is  poorly calibrated. 


%%% Move to appendix
\if 0
\textbf{Connection to post-hoc adjustment.}\
%\label{sec:post-hoc-adjustment}
The form of the student %in the self-distillation setup 
in Proposition \ref{prop:student-form} raises an interesting question. Instead of training an explicit student model, 
why not directly construct a new scoring model by making post-hoc adjustments
to the teacher's predictions? Specifically, one could optimize over functions of the form $f^s_y(x) = \log(\gamma_y p^t_y(x)),$ where the teacher $p^t$ is fixed, and pick 
the coefficients $\gamma \in \R^m$ so that resulting scoring function yields the best worst-class accuracy on a held-out dataset.  
This simple \emph{post-hoc adjustment} strategy 
may not be feasible if the goal is to distill to a student that is considerably smaller than the teacher. Often, this is the case in settings where distillation is used as a compression technique.
% feasible if the teacher model is too complex to deploy in practice, and one desires a student with smaller complexity. 
Yet, this post-hoc method %may still be useful in providing a rough estimate of the worst-case performance that a student can hope to 
% achieve when using a particular teacher, and will 
serves as a good baseline to compare with.

\textbf{One-hot validation labels.}\ In Appendix \ref{app:one-hot-vali}, we discuss robustness guarantees for Algorithm \ref{algo:dro} when it uses one-hot labels in the validation set instead of the teacher labels.
\fi

%%%%% CONCLUSION %%%%%
\section{Conclusions and future work}
\label{sec:conclusion}
We have demonstrated the value of applying different combinations of teacher/student objectives, not only for improving worst-class accuracy, but also to achieve efficient trade-offs between average and worst-class accuracy. Surprisingly, the teacher's and student's objective functions can interact with each other in nontrivial ways: for example, applying a robust objective to both the teacher and the student does not always achieve the best worst-class accuracy (Table~\ref{tab:combos_shortened}).
Further exploring the trade-off between worst-class and average accuracy, we provided simple modifications to the teacher and student objectives that boosted worst-class accuracy with less degradation in average accuracy than prior methods that focus on worst-class accuracy. This confirms the key takeaway that the teacher's objective plays a crucial role in the student's robustness.

% Incorporating some amount of robustness in the teacher results in Pareto improvements over using just a standard teacher, and even simple hyperparameter selection on $\alpha^t, \alpha^s$ yields competitive results with state-of-the-art methods.

In a broader sense, our theory provides a better understanding of the interplay between teacher and student objectives, and thus serves as a starting point for further development of methods to modify both the teacher's and student's objectives jointly. An interesting future avenue for exploration would be to extend our distillation setup to incorporate other forms of teacher supervision such as intermediate embeddings or ensembled scores (e.g., \citet{iscen2021class}).

Training efficiency is another avenue for improvement, and future work in reducing the hyperparameter search space would be practically valuable. For settings where teacher retraining is particularly expensive, one could modify a given fixed teacher with some form of post-hoc logit adjustment \citep{narasimhan2021training}, or only fine-tune a subset of the teacher parameters with different values of $\alpha^t$. These reductions in computational cost would improve the practicality of joint exploration of teacher and student objectives.

\begin{acknowledgements}
We are grateful to Luca Zappella for the detailed constructive feedback on this manuscript. We also thank Erik Vee for valuable discussions and pointers.
\end{acknowledgements}

% References
\bibliography{main}

\end{document}
