% \documentclass{uai2023} % for initial submission
\documentclass[accepted]{fathullah_460} % after acceptance, for a revised
                                    % version; also before submission to
                                    % see how the non-anonymous paper
                                    % would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2023} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2023} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
%\usepackage[american]{babel}
\usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
\usepackage{booktabs} % commands to create good-looking tables
\usepackage{tikz} % nice language for creating drawings and diagrams

% Maths equations
\usepackage{amssymb}
\usepackage{amsmath}
\usepackage{amsfonts}
\usepackage{amsthm}
\usepackage{commath}
\usepackage{bm}
\usepackage{physics}
\usepackage{multicol}
\usepackage{bbm}
\usepackage{scalerel}
\usepackage{comment}

% Formatting tables and figures
\usepackage{multirow, booktabs}
\usepackage[bf]{caption}
\setlength{\captionmargin}{10pt}
\usepackage{subcaption}
\usepackage{makecell}
\usepackage{graphicx}
\usepackage{stfloats}
\usepackage[all=normal, mathspacing=tight, floats=tight, paragraphs=tight]{savetrees}
\DeclarePairedDelimiterX{\infdivx}[2]{(}{)}{%
	#1\;\delimsize\|\;#2%
}
% Notation
\newcommand{\GP}{\mathcal{GP}}
\newcommand{\tP}{{\tt P}}
\newcommand{\tQ}{{\tt Q}}
\newcommand{\tp}{{\tt p}}
\newcommand{\tq}{{\tt q}}
\newcommand{\tpost}{{\tt p}(\bm\theta\vert\mathcal{D})}
\newcommand{\tprior}{{\tt p}(\bm\theta)}
\newcommand{\bhy}{\hat{\bm y}}
\newcommand{\hy}{\hat{y}}
\newcommand{\KL}{{\tt KL}\infdivx}
\newcommand{\Softmax}{{\tt Softmax}}
\newcommand{\LSoftmax}{{\tt LogSoftmax}}
\newcommand{\LogSumExp}{{\tt LogSumExp}}
\newcommand{\Dir}{{\tt Dir}}
\newcommand{\GLEU}{{\tt GLEU}}
\newcommand{\AUC}{{\tt AUC}}
\newcommand{\AUCRR}{{\tt AUC_{RR}}}

% Custom commands 
\DeclareMathOperator*{\argmax}{arg\,max}
\DeclareMathOperator*{\argmin}{arg\,min}
\DeclareMathOperator*{\minimise}{\,minimise}
\newcommand{\tpm}{\tiny{$\pm$} }

\DeclarePairedDelimiter\ceil{\lceil}{\rceil}
\DeclarePairedDelimiter\floor{\lfloor}{\rfloor}
\providecommand{\tabularnewline}{\\}

\usepackage{pifont}
\newcommand{\cmark}{\ding{51}}%
\newcommand{\xmark}{\ding{55}}%

%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

%% Self-defined macros
\newcommand{\swap}[3][-]{#3#1#2} % just an example

\title{Logit-Based Ensemble Distribution Distillation for \\Robust Autoregressive Sequence Uncertainties}

% The standard author block has changed for UAI 2023 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\author[1]{\href{mailto:<yf286@cam.ac.uk>?Subject=L-EDD UAI 2023}{Yassir~Fathullah}{}}
\author[2]{\href{mailto:<g.xia21@imperial.ac.uk>}{Guoxuan~Xia}{}}
\author[1]{Mark~J.~F.~Gales}
% Add affiliations after the authors
\affil[1]{%
    Engineering Department\\
    University of Cambridge\\
    UK
}
\affil[2]{%
    Department of Electrical \& Electronic Engineering \\
    Imperial College London\\
    UK
}
  
\begin{document}
\maketitle

\begin{abstract}

    %% Guoxuan abstract
    Efficiently and reliably estimating uncertainty is an important objective in deep learning. It is especially pertinent to autoregressive sequence tasks, where training and inference costs are typically very high. However, existing research has predominantly focused on tasks with static data such as image classification. 
    In this work, we investigate Ensemble Distribution Distillation (EDD) applied to large-scale natural language sequence-to-sequence data. EDD aims to compress the superior uncertainty performance of an expensive (teacher) ensemble into a cheaper (student) single model. Importantly, the ability to separate knowledge (epistemic) and data (aleatoric) uncertainty is retained. Existing probability-space approaches to EDD, however, are difficult to scale to large vocabularies. 
    We show, for modern transformer architectures on large-scale translation tasks, that modelling the ensemble \textit{logits}, instead of softmax probabilities, leads to significantly better students. Moreover, the students surprisingly \textit{outperform Deep Ensembles} by up to $\sim$10\% AUROC on out-of-distribution detection, whilst matching them on in-distribution translation. 
    % Furthermore, through snapshot ensembling, we are also able to greatly reduce training costs compared to standard EDD.

    
\end{abstract}

\section{Introduction}
\label{sec:intro}

\begin{comment}
    
\begin{figure}
    \centering
    \includegraphics[width=.9\linewidth]{figures/EDD_preview.pdf}
    \caption{OOD detection performance (\%AUROC) against translation performance (BLEU) for different approaches. Higher is better for both. Our approach, Laplace Logit-Based Ensemble Distribution Distillation  (L-EDD Lap.) is able to outperform even the Deep Ensemble in both tasks, in particular, OOD detection. The ID data is En-De WMT’16, OOD data is Khresmoi-Summary, and the model used is Transformer.}
    \label{fig:preview}
\end{figure}

\end{comment}


The ability to produce reliable estimates of uncertainty is important to many tasks in deep learning. When it is costly to make mistakes, a model should know when to discard a prediction or defer it to a human expert. Although there is a significant body of research in this area \citep{ovadia_trust,duku,yang2021oodsurvey}, it tends to focus on \textit{static} data, where outputs are of a fixed dimension, with approaches most commonly evaluated on image classification \citep{pmlr-v97-geifman19a, sngp, yang2022openood,  Moon2020ConfidenceAwareLF, Xia_2022_ACCV}. In contrast, in this work, we aim to investigate uncertainty estimation for sequence prediction tasks, such as machine translation, which is a relatively under-explored domain \citep{structured}.

Large attention-based autoregressive neural networks have recently emerged as the most competitive approach to many structured sequence-prediction tasks, especially translation \citep{atten, aiayn, scaling-nmt}, and are increasingly being used in practice. However, since the computational and memory costs of these models are typically very large, it is particularly important that methods for improving their uncertainty estimates are also \textit{efficient}. We focus on one such efficient approach, Ensemble Distribution Distillation (EDD) \citep{endd}.

Ensembling multiple neural networks trained using different random seeds is a well-established approach for boosting uncertainty performance \citep{deepens}. Deep Ensembles have been shown to be effective over a wide range of data, tasks, and evaluation metrics \citep{ovadia_trust, structured, Kim2021AUB, gustafsson2020evaluating}. Moreover, they naturally decompose total uncertainty into knowledge (epistemic) and data (aleatoric) uncertainty \citep{duku}, which is useful for tasks such as active learning \citep{pmlr-v70-gal17a, seq-al2}, reinforcement learning \citep{decomp}, and out-of-distribution detection \citep{structured}. However, the costs of Deep Ensembles scale linearly with the number of members. EDD aims to tackle this by using Knowledge Distillation (KD) \citep{kd} to compress the (teacher) ensemble into a more efficient (student) single model. Crucially, in EDD the student learns not only the predictions of the ensemble but also the \textit{distribution} over the outputs of individual ensemble members. By explicitly modelling the diversity of the ensemble, the student is able to express knowledge and data uncertainty independently, just like the teacher ensemble \citep{endd}.



However, EDD is not without its challenges. Prior work has shown that EDD suffers from optimisation issues, meaning it can be difficult to scale to confident ensembles with large label spaces. Thus, EDD requires a number of practical modifications in order to be applied to large-scale tasks such as machine translation \citep{seq-endd, ryabin}. Despite these challenges, the concept behind EDD remains a promising approach for training single autoregressive models with smaller footprints and the ability to estimate high-quality, robust uncertainties.

\textbf{Summary of contributions:} In this paper, we focus on an underexplored area of uncertainty estimation: robust and efficient autoregressive sequence uncertainties. Specifically, we address the drawbacks of sequence EDD by using \textit{logit-based} ensemble distribution distillation (L-EDD). Instead of training a student to distil the distribution of ensemble predictions in the probability (softmax) space, we train it to do so in the pre-softmax logit space. Experiments on the En-De WMT'16 and En-Ru WMT'20 machine translation tasks show that L-EDD, in particular when using a Laplace distribution, produces strong estimates of sequence uncertainty. L-EDD outperforms EDD and, surprisingly, even Deep Ensembles on out-of-distribution (OOD) detection, while matching them in translation quality.
% (see Figure \ref{fig:preview}). 
%
Furthermore, by using Snapshot Ensembles \citep{snapshot}, we are able to greatly reduce the overall training costs of EDD compared to using a Deep Ensemble teacher.


\section{Background}
\label{sec:background}

In this section, we review ensemble-based uncertainty estimation. We follow with a discussion of how the limitations of ensembles can be addressed using recently developed distillation techniques for autoregressive sequence tasks such as machine translation. 

\subsection{Uncertainty Estimation}
\label{ssec:uncertainty}

We adopt a Bayesian perspective on ensembles, as this offers a flexible framework within which uncertainties have an information-theoretic justification. Given observed (training) data $\mathcal{D}$, the quantity of interest is the posterior over model parameters, $\tpost$. Unfortunately, this posterior is intractable for large non-linear networks, so an approximation $\tq(\bm\theta) \approx \tpost$ is used instead. Samples drawn from this approximate distribution then form an ensemble of models. 

Consider an ensemble $\{\tP(\bm y\vert\bm x, \bm\theta^{(m)})\}_{m = 1}^{M}$ sampled from an approximate posterior $\tq(\bm\theta)$, where each model maps a \textit{variable-length} input $\bm x \in \mathcal{X}$ to a \textit{variable-length} output $\bm y \in \mathcal{Y}$ of discrete units. The predictive distribution is obtained by:
%
\begin{align}
	\label{eq:predictive}
	\tP(\bm y\vert\bm x, \mathcal{D}) = \mathbb{E}_{\tq(\bm\theta)} \left[ \tP(\bm y\vert\bm x, \bm\theta) \right].
\end{align}
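With $M$ equally weighted members $\bm\theta^{(m)} \sim \tq(\bm\theta)$, the expectation above is, in practice, typically approximated by the ensemble average; as a sketch of this standard Monte Carlo estimate:
\begin{align*}
	\tP(\bm y\vert\bm x, \mathcal{D}) \approx \frac{1}{M} \sum_{m=1}^{M} \tP(\bm y\vert\bm x, \bm\theta^{(m)}).
\end{align*}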
%
From this predictive distribution, a measure of total uncertainty can be estimated using the entropy:
%
% \begin{align}
% 	\label{eq:total}
% 	\mathcal{H}\left[ \tP(\bm y\vert\bm x, \mathcal{D}) \right] = \mathbb{E}_{\tP(\bm y\vert\bm x, \mathcal{D})}\left[ \ln \frac{1}{\tP(\bm y\vert\bm x, \mathcal{D})} \right]
% \end{align}
\begin{align}
	\label{eq:total}
	\mathcal{H}\left[ \tP(\bm y\vert\bm x, \mathcal{D}) \right] = \mathbb{E}_{\tP(\bm y\vert\bm x, \mathcal{D})}\left[ -\ln \tP(\bm y\vert\bm x, \mathcal{D}) \right].
\end{align}
%
Furthermore, disagreement between models, also referred to as \textit{knowledge} or \textit{epistemic} uncertainty, can be measured by the mutual information between $\bm y$ and $\bm\theta$:
%
\begin{align}
	\label{eq:knowledge}
	\begin{split}
		\mathcal{I}\left[ \bm y, \bm\theta \vert \bm x, \mathcal{D} \right] = \mathbb{E}_{\tq(\bm\theta)}\left[ \KL[\big]{\tP(\bm y\vert\bm x, \bm\theta)}{\tP(\bm y\vert\bm x, \mathcal{D})} \right].
	\end{split}
\end{align}
%
As discussed in \citet{structured}, the mutual information can equivalently be written as the difference between the total uncertainty and the expected entropy of the individual models, $\mathbb{E}_{\tq(\bm\theta)}\big[\mathcal{H}[\tP(\bm y\vert\bm x, \bm\theta)]\big]$, which is a measure of data (aleatoric) uncertainty. Other measures of knowledge uncertainty exist, such as the expected pairwise KL-divergence or reverse mutual information \citep{structured}; for simplicity, we restrict our focus to Equations (\ref{eq:total}) and (\ref{eq:knowledge}), since these capture uncertainties of differing natures.
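As an illustration of Equations (\ref{eq:total}) and (\ref{eq:knowledge}) at the level of a single token, the minimal Python/NumPy sketch below computes total, data and knowledge uncertainty from the categorical outputs of $M$ ensemble members that share the same context. It assumes equally weighted members, so the expectation over $\tq(\bm\theta)$ becomes a simple average; the function name \texttt{token\_uncertainties} and the constant \texttt{eps} (which guards against $\ln 0$) are illustrative choices, not part of the original method.

\begin{verbatim}
import numpy as np


def token_uncertainties(member_probs, eps=1e-12):
    """Total, data and knowledge uncertainty of one token.

    member_probs: (M, K) array; row m is the categorical
    distribution pi^(m) predicted by ensemble member m.
    """
    p = np.asarray(member_probs, dtype=float)
    # Predictive distribution: equally weighted member average.
    mean = p.mean(axis=0)
    # Total uncertainty: entropy of the predictive distribution.
    total = -np.sum(mean * np.log(mean + eps))
    # Data uncertainty: expected entropy of individual members.
    data = -np.mean(np.sum(p * np.log(p + eps), axis=1))
    # Knowledge uncertainty: mutual information = total - data.
    knowledge = total - data
    return float(total), float(data), float(knowledge)


# Two confident members that disagree: high knowledge uncertainty.
print(token_uncertainties([[1.0, 0.0], [0.0, 1.0]]))
\end{verbatim}

Two members that are each fully confident but disagree give (up to \texttt{eps}) zero data uncertainty and knowledge uncertainty $\ln 2$, whereas identical members give zero knowledge uncertainty.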

\textbf{Limitations}: The discussion so far has assumed that one can enumerate all possible variable-length outputs $\bm y \in \mathcal{Y}$, which is intractable in autoregressive sequence tasks. Instead, the uncertainties can be approximated using Monte Carlo methods \citep{notin2021improving} and by exploiting the autoregressive structure of the predictions \citep{structured}:
%
\begin{align}
	\label{eq:ar}
	\begin{split}
		\tP(\bm y\vert\bm x, \bm\theta) = \prod_{l = 1}^{L} \tP(y_l \vert \bm y_{<l}, \bm x, \bm\theta).
	\end{split}
\end{align}
%
We refer to \citet{structured} for an in-depth discussion and analysis of approximations for predictive entropy and mutual information for autoregressive prediction.
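A minimal sketch of these two ingredients follows, assuming a trained autoregressive model is available to supply the per-step logits and to draw sample sequences (both outside the sketch). \texttt{sequence\_log\_prob} applies Equation (\ref{eq:ar}) in log space, and \texttt{mc\_sequence\_entropy} is the basic Monte Carlo estimate $-\frac{1}{S}\sum_{s}\ln\tP(\bm y^{(s)}\vert\bm x)$ from $S$ sequences $\bm y^{(s)}$ sampled from the model. Both names are hypothetical, and this is only one simple estimator, not the full set analysed in \citet{structured}.

\begin{verbatim}
import numpy as np


def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def sequence_log_prob(step_logits, tokens):
    """ln P(y | x, theta) via the chain rule of Eq. (ar).

    step_logits: (L, K) logits; row l is computed by the model
        conditioned on the prefix y_<l and the input x.
    tokens: the L token indices y_1, ..., y_L.
    """
    logp = log_softmax(np.asarray(step_logits, dtype=float))
    steps = np.arange(len(tokens))
    # The product of conditionals becomes a sum of log-probs.
    return float(logp[steps, tokens].sum())


def mc_sequence_entropy(sample_log_probs):
    """Monte Carlo entropy estimate from S sampled sequences:
    -(1/S) * sum_s ln P(y^(s) | x)."""
    return -float(np.mean(sample_log_probs))
\end{verbatim}

Working in log space avoids the numerical underflow that the raw product in Equation (\ref{eq:ar}) would suffer for long sequences.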


\subsection{Knowledge Distillation}
\label{ssec:distillation}

Ensembles $\{\tP(\bm y\vert\bm x, \bm\theta^{(m)})\}_{m = 1}^{M}$ sampled from an approximate posterior can be computationally demanding, since their cost grows with the number of members. One approach to efficiently exploiting the information in the ensemble is to use Knowledge Distillation (KD) to yield a single student model \citep{kd, seq-kd}. 

Given a reference data pair $(\bm x, \bm y) \sim \tilde{\tp}(\bm x, \bm y)$, a standard model might be trained using negative log-likelihood (NLL): 
%
\begin{align}
	\label{eq:nll}
	\mathcal{L}_{\tt NLL}(\bm\theta) = - \frac{1}{L} \sum_{l = 1}^{L} \ln\tP(y_l \vert \bm y_{<l}, \bm x, \bm\theta).
\end{align}
%
This is referred to as teacher forcing, since during training the model makes its prediction at step $l$ conditioned on the reference prefix $\bm y_{<l}$ (rather than on its own previous predictions). Similarly, a student model with parameters $\bm\phi$ can be trained to emulate a teacher ensemble by additionally using the \textit{average} ensemble categorical (softmax) outputs $\bm\pi_l, \pi_{l, k} = \tP(y_l = k \vert \bm y_{<l}, \bm x, \mathcal{D})$ as soft labels:
%
\begin{align}
	\label{eq:distillation}
	\begin{split}
		\mathcal{L}_{\tt KL}(\bm\phi) = \frac{1}{L}\sum_{l = 1}^{L} \KL[\big]{\bm\pi_l}{\tP(y_l \vert \bm y_{<l}, \bm x, \bm\phi)}.
	\end{split}
\end{align}
%
In practice, one optimises a convex combination of the likelihood and KL-divergence losses, $\mathcal{L}_{\tt KD}(\bm\phi) = \lambda \mathcal{L}_{\tt NLL}(\bm\phi) + (1-\lambda)\mathcal{L}_{\tt KL}(\bm\phi), \medspace \lambda \in [0, 1]$, for added supervision and stability. The probability mass functions in the KL-divergence can also be temperature-scaled to improve optimisation \citep{kd}. Note that this criterion only covers the teacher-forcing case; more sophisticated distillation approaches sample $(\bm x, \bm y)$ from alternative distributions, but these are outside the scope of this work (see \citet{seq-kd, gen-student-th, ood-training-1, ood-training-2} for details).
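The teacher-forced objective above can be sketched in Python/NumPy as follows. The sketch assumes that the ensemble supplies per-member logits, that $\bm\pi_l$ is the average of the member softmax outputs, and that the temperature \texttt{T} is applied to both teacher and student inside the KL term only; the names \texttt{kd\_loss}, \texttt{lam} (for $\lambda$) and \texttt{T} are illustrative, not a reference implementation.

\begin{verbatim}
import numpy as np


def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def kd_loss(student_logits, member_logits, targets,
            lam=0.5, T=1.0):
    """Teacher-forced L_KD = lam * L_NLL + (1 - lam) * L_KL.

    student_logits: (L, K) student logits for steps 1..L.
    member_logits: (M, L, K) logits of the M ensemble members.
    targets: (L,) reference token indices y_1..y_L.
    T: temperature applied inside the KL term only.
    """
    s = np.asarray(student_logits, dtype=float)
    t = np.asarray(member_logits, dtype=float)
    steps = np.arange(s.shape[0])
    # L_NLL, Eq. (nll): mean NLL of the reference tokens.
    nll = -log_softmax(s)[steps, targets].mean()
    # pi_l: average of the temperature-scaled member softmaxes.
    pi = np.exp(log_softmax(t / T)).mean(axis=0)
    log_q = log_softmax(s / T)
    # L_KL, Eq. (distillation): KL(pi_l || student), mean over l.
    log_pi = np.log(np.maximum(pi, 1e-12))
    kl = np.sum(pi * (log_pi - log_q), axis=-1).mean()
    return float(lam * nll + (1.0 - lam) * kl)
\end{verbatim}

Setting \texttt{lam=1} recovers plain NLL training, while a student that exactly reproduces every member makes the KL term vanish.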


\subsection{Ensemble Distribution Distillation}
\label{ssec:endd}

Whilst KD has been successful in many sequence tasks, the resulting student is not able to estimate knowledge uncertainty, since it only models the average ensemble output. To avoid this issue, \citet{endd, ryabin} consider the task of distilling the \textit{distribution} of sequence ensemble predictions onto a single student. This allows the student to retain both predictive performance and information about ensemble diversity. 

To explain the mechanics behind Ensemble Distribution Distillation (EDD), consider modelling a distribution over autoregressive ensemble predictions, in which all $M$ ensemble members share the same back-history $\bm y_{<l}$:
%
\begin{align}
    \label{eq:predictionsens}
    \{\bm\pi_{l}^{(m)}\}_{m=1}^M, \medspace \pi_{l, k}^{(m)} = \tP(y_l = \omega_k\vert\bm y_{<l}, \bm x, \bm\theta^{(m)}).
\end{align}
%
Now let an autoregressive student predict the parameters $\bm\alpha_l$ of a Dirichlet distribution $\Dir(\bm\pi_l\vert\bm\alpha_l) = \tp(\bm\pi_l\vert\bm y_{<l}, \bm x, \bm\phi)$. Since the Dirichlet is a distribution over categorical distributions, it is a natural candidate for this task. The student is then trained by minimising the negative log-likelihood of the ensemble predictions:
%
\begin{align}
    \label{eq:dirichletdd}
    \begin{split}
        \mathcal{L}_{\tt NLL}^{\tt DD}(\bm\phi) 
        & = -\frac{1}{MLK} \sum_{m, l} \ln\Dir(\bm \pi_l^{(m)}\vert\bm\alpha_l) \\
        & \equiv \frac{1}{LK} \sum_{l} \Big( \ln B(\bm\alpha_l) - \sum_{k} \alpha_{l, k}\ln\tilde{\pi}_{l, k} \Big),
    \end{split}
\end{align}
%
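where $K$ is the number of classes, $B(\bm\alpha) = \prod_k \Gamma(\alpha_k) / \Gamma(\sum_k \alpha_k)$ is the multivariate beta function and $\tilde{\pi}_{l, k} = \big(\prod_{m} \pi_{l, k}^{(m)}\big)^{1/M}$ is the geometric average of the individual ensemble softmax probabilities in Equation (\ref{eq:predictionsens}). The second line drops terms that do not depend on the student parameters $\bm\phi$.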
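For a single time step, Equation (\ref{eq:dirichletdd}) can be checked numerically with the sketch below (Python, using \texttt{math.lgamma} for $\ln\Gamma$). \texttt{dirichlet\_nll} averages the exact negative log-density over the $M$ members, while \texttt{dirichlet\_dd\_loss} is the simplified form that drops $\sum_k \ln\tilde{\pi}_{l,k}$, a term that does not depend on the student; both omit the $1/(LK)$ factor. The function names are illustrative.

\begin{verbatim}
from math import lgamma

import numpy as np


def log_beta(alpha):
    """ln B(alpha) = sum_k lnGamma(alpha_k) - lnGamma(sum alpha)."""
    return sum(lgamma(a) for a in alpha) - lgamma(sum(alpha))


def dirichlet_nll(alpha, member_probs):
    """Exact -(1/M) sum_m ln Dir(pi^(m) | alpha) for one step.

    alpha: (K,) positive concentrations from the student.
    member_probs: (M, K) member softmax outputs, all > 0.
    """
    alpha = np.asarray(alpha, dtype=float)
    # ln of the geometric average: (1/M) sum_m ln pi_k^(m).
    probs = np.asarray(member_probs, dtype=float)
    log_geo = np.log(probs).mean(axis=0)
    return log_beta(alpha) - float(np.sum((alpha - 1.0) * log_geo))


def dirichlet_dd_loss(alpha, member_probs):
    """Simplified form of Eq. (dirichletdd) without 1/(LK);
    drops sum_k ln pi~_k, which is constant in alpha."""
    alpha = np.asarray(alpha, dtype=float)
    probs = np.asarray(member_probs, dtype=float)
    log_geo = np.log(probs).mean(axis=0)
    return log_beta(alpha) - float(np.sum(alpha * log_geo))
\end{verbatim}

The two functions differ only by $\sum_k \ln\tilde{\pi}_{l,k}$, so they share the same gradients with respect to $\bm\alpha_l$.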

Whilst this approach was shown to be promising on a small-scale image classification task in \citet{endd}, subsequent work \citep{seq-endd, ryabin} found that direct application of Equation (\ref{eq:dirichletdd}) encounters optimisation issues when scaled to larger label spaces. 
% The main issue is based on how the resulting gradients from the distribution distillation loss correlate with the teacher probabilities. 
This arises from the way classwise loss gradients are related to teacher class probabilities.
It turns out that, unlike standard distillation, the loss in Equation (\ref{eq:dirichletdd}) induces small gradients for (important) high-probability classes and large gradients for (unimportant) low-probability classes. This negatively affects convergence as the number of low-probability classes increases.
% skewing the learning process and making it much more challenging to apply in most but the simplest tasks. 
\citet{ryabin} proposed an approach for scaling Dirichlet EDD in which the student minimises a normalised reverse KL-divergence to a \textit{proxy} Dirichlet; we use this approach as a baseline in this work.
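
The gradient imbalance described above can be seen by differentiating the per-step simplified loss in Equation (\ref{eq:dirichletdd}). Writing $\alpha_{l,0} = \sum_k \alpha_{l,k}$ and using $\partial \ln B(\bm\alpha)/\partial\alpha_k = \psi(\alpha_k) - \psi(\alpha_0)$, where $\psi$ is the digamma function, the gradient with respect to a single concentration parameter is, as a sketch ignoring the $1/(LK)$ factor,
\begin{align*}
	\frac{\partial}{\partial \alpha_{l,k}} \Big( \ln B(\bm\alpha_l) - \sum_{j} \alpha_{l, j}\ln\tilde{\pi}_{l, j} \Big) = \psi(\alpha_{l,k}) - \psi(\alpha_{l,0}) - \ln\tilde{\pi}_{l,k}.
\end{align*}
The term $-\ln\tilde{\pi}_{l,k}$ grows without bound as $\tilde{\pi}_{l,k} \to 0$ and vanishes as $\tilde{\pi}_{l,k} \to 1$, which is consistent with low-probability classes dominating the gradient. This covers the gradient with respect to $\bm\alpha_l$ only; the gradient reaching $\bm\phi$ additionally depends on how the student parameterises $\bm\alpha_l$.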


\section{Sequence Logit-based EDD}
\label{sec:ledd}



In sequence tasks with a large number of classes, as is common in speech recognition and machine translation, the output categorical distributions are often very sparse and concentrated, which makes EDD highly challenging to apply. On the other hand, KD has been shown to work well for larger tasks \citep{seq-kd, seq-kd-app1, seq-kd-app2, seq-kd-app3}, but since it only models the average teacher prediction, it cannot estimate the data and knowledge uncertainties that are important for downstream tasks such as out-of-distribution detection.

In this section, we describe a \textit{Logit-based} Ensemble Distribution Distillation (L-EDD) approach for autoregressive models which addresses the drawbacks of both KD and EDD in a single consistent framework and is scalable to sequence problems with a large number of classes. Consider a set of logits produced by an ensemble:
%
\begin{align}
    \label{eq:logits}
    \{\bm z_{l}^{(m)}\}_{m=1}^M, \medspace \bm\pi_{l}^{(m)} = \Softmax(\bm z_{l}^{(m)}).
\end{align}
%
Traditional distillation approaches thereafter use the logits to produce categorical probability distributions by applying the softmax function. However, instead of operating in the probability space, we propose training a student, with model parameters $\bm\phi$, to directly model the logit space by predicting the mean $\bm\mu_l$ and scale $\bm\sigma_l$ parameters of a diagonal Laplace distribution:
%
\begin{align}
    \label{eq:laplace-logit}
    \begin{split}
        \tp(\bm z \vert \bm y_{<l}, \bm x, \bm\phi) 
        & = {\tt Lap}(\bm z \vert \bm\mu_{l}, \bm\sigma_{l}) \\
        & = \prod_{k} \frac{1}{2\sigma_{l,k}} \exp{-\frac{\abs{z_k - \mu_{l,k}}}{\sigma_{l,k}}}.
    \end{split}
\end{align}
%
Because the distribution is diagonal, sampling is parallelisable, efficient and straightforward, and the resulting samples allow uncertainties to be estimated in exactly the same manner as for standard ensembles. Additionally, significantly fewer parameters are required than for a full covariance matrix. Another benefit of the Laplace distribution is that its tails are heavier than those of the Gaussian, which makes it more robust to outliers. This robustness also makes it a natural choice for the early stages of training, when the student model is randomly initialised and its output distribution differs substantially from the ensemble logits.

Furthermore, given the set of logits produced by an ensemble, the student model $\tp(\bm z \vert \bm y_{<l}, \bm x, \bm\phi)$ can be trained by maximum likelihood, i.e.\ by minimising the negative log-likelihood:
%
\begin{align}
    \label{eq:laplacedd}
    \begin{split}
        \mathcal{L}_{\tt NLL}^{\tt L-EDD}(\bm\phi) 
        & = -\frac{1}{MLK} \sum_{m, l} \ln{\tt Lap}(\bm z_l^{(m)} \vert \bm\mu_{l}, \bm\sigma_{l}) \\
        & \equiv \frac{1}{MLK} \sum_{m,l,k} \frac{\abs{z_{l,k}^{(m)} - \mu_{l,k}}}{\sigma_{l,k}} + \ln\sigma_{l,k}.
    \end{split}
\end{align}
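A minimal NumPy sketch of Equation (\ref{eq:laplacedd}) follows, dropping the constant $\ln 2$ as in the second line. It assumes the student outputs $\ln\bm\sigma_l$ rather than $\bm\sigma_l$ so that the scale stays positive; this parameterisation and the name \texttt{laplace\_ledd\_loss} are illustrative assumptions.

\begin{verbatim}
import numpy as np


def laplace_ledd_loss(member_logits, mu, log_sigma):
    """Laplace L-EDD loss, Eq. (laplacedd), without ln 2.

    member_logits: (M, L, K) ensemble logits z_l^(m).
    mu, log_sigma: (L, K) student location and log-scale;
        predicting ln sigma keeps sigma positive (assumption).
    """
    z = np.asarray(member_logits, dtype=float)
    log_sigma = np.asarray(log_sigma, dtype=float)
    sigma = np.exp(log_sigma)
    # |z - mu| / sigma + ln sigma, broadcast over M members;
    # the mean over (m, l, k) gives the 1 / (MLK) factor.
    per_elem = np.abs(z - mu) / sigma + log_sigma
    return float(per_elem.mean())
\end{verbatim}

For a single logit with member values $\{0, 1, 5\}$, the loss is minimised at $\mu = 1$ (the median) with $\sigma = 5/3$ (the mean absolute deviation from the median). In general, the Laplace maximum-likelihood location is the median of the member logits, so a single outlying member moves it far less than it would move a Gaussian mean.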
%

We also perform experiments with a (diagonal) Gaussian student distribution over the logits; variations of this have been explored for static image classification \citep{s2d, mod-edd} but have remained unexplored for autoregressive sequence tasks:
%
\begin{align}
    \label{eq:gaussian-logit}
    \begin{split}
        \tp(\bm z \vert \bm y_{<l}, & \bm x, \bm\phi) 
        = \mathcal{N}(\bm z \vert \bm\mu_{l}, \bm\sigma_{l}^2) \\
        & = \prod_{k} \frac{1}{(2\pi\sigma_{l,k}^2)^{\frac{1}{2}}} \exp{-\frac{(z_k - \mu_{l,k})^2}{2\sigma_{l,k}^2}}.
    \end{split}
\end{align}
%
As with the approaches above, this student is trained using the negative log-likelihood objective:
%
\begin{align}
    \label{eq:gaussiandd}
    \begin{split}
        \mathcal{L}_{\tt NLL}^{\tt L-EDD}(\bm\phi) 
        & = -\frac{1}{MLK} \sum_{m, l} \ln\mathcal{N}(\bm z_l^{(m)} \vert \bm\mu_{l}, \bm\sigma_{l}^2). 
        % \\
        % & \equiv \frac{1}{MLK} \sum_{m,l,k} \frac{(z_{l,k}^{(m)} - \mu_{l, k})^2}{2\sigma_{l,k}^2} + \ln\sigma_{l,k}
    \end{split}
\end{align}
%
The Gaussian distribution, which induces a squared-error (L2) loss, is much more sensitive to outliers in the ensemble logits. This student could therefore be more challenging to train, although it should still be more stable than Dirichlet EDD.
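The difference in outlier sensitivity can be seen from the gradient of each per-element loss with respect to the student location $\mu_{l,k}$: for the Gaussian it is $-(z-\mu)/\sigma^2$, which grows linearly with the distance to a member logit, whereas for the Laplace it is $-\mathrm{sign}(z-\mu)/\sigma$, which is bounded. The short sketch below (function names illustrative) evaluates both for increasingly distant member logits.

\begin{verbatim}
import numpy as np


def grad_mu_gaussian(z, mu, sigma):
    """d/dmu of (z - mu)^2 / (2 sigma^2) + ln sigma."""
    return -(z - mu) / sigma**2


def grad_mu_laplace(z, mu, sigma):
    """d/dmu of |z - mu| / sigma + ln sigma
    (subgradient 0 chosen at z = mu)."""
    return -np.sign(z - mu) / sigma


# A member logit increasingly far from the student location 0:
# the Gaussian gradient grows linearly, the Laplace one does not.
for z in [0.5, 5.0, 50.0]:
    print(z, grad_mu_gaussian(z, 0.0, 1.0),
          grad_mu_laplace(z, 0.0, 1.0))
\end{verbatim}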

\begin{comment}

Traditional distillation approaches thereafter use the logits to produce categorical probability distributions by applying the softmax function. However, instead of operating in the probability space, we train a student, with parameters $\bm\phi$, to directly model the logit space by predicting the mean $\bm\mu_l$ and standard deviation $\bm\sigma_l$ of a diagonal Gaussian distribution:
%
\begin{align}
    \label{eq:gaussian-logit}
    \begin{split}
        \tp(\bm z \vert \bm y_{<l}, & \bm x, \bm\phi) 
        = \mathcal{N}(\bm z \vert \bm\mu_{l}, \bm\sigma_{l}^2) \\
        & = \prod_{k} \frac{1}{\sqrt{2\pi\sigma_{l,k}^2}} \exp{-\frac{(z_k - \mu_{l,k})^2}{2\sigma_{l,k}^2}}
    \end{split}
\end{align}
%
Because we opt for a diagonal distribution, sampling is highly efficient and straightforward and allows for the estimation of uncertainties in exactly the same manner as in standard ensembles.

Furthermore, given the set of logits produced by an ensemble, the student model $\tp(\bm z \vert \bm y_{<l}, \bm x, \bm\phi)$ can be trained by straightforward application of maximum likelihood estimation:
%
\begin{align}
    \label{eq:gaussiandd}
    \begin{split}
        \mathcal{L}_{\tt NLL}^{\tt GDD}(\bm\phi) 
        & = -\frac{1}{MLK} \sum_{m, l} \ln\mathcal{N}(\bm z_l^{(m)} \vert \bm\mu_{l}, \bm\sigma_{l}^2) \\
        & \equiv \frac{1}{MLK} \sum_{m,l,k} \frac{(z_{l,k}^{(m)} - \mu_{l, k})^2}{2\sigma_{l,k}^2} + \ln\sigma_{l,k}
    \end{split}
\end{align}
%
In this case, the teacher ensemble can only induce large gradients when the logits significantly deviate from the coverage of the student Gaussian distribution. However, to further address this problem, we propose using a distribution which has 'long tails' and is naturally robust to outliers. There are several options but restrict ourselves to the diagonal Laplace distribution which can be used as a drop-in replacement for the diagonal Gaussian:
%
\begin{align}
    \label{eq:laplace-logit}
    \begin{split}
        \tp(\bm z \vert \bm y_{<l}, \bm x, \bm\phi) 
        & = {\tt Lap}(\bm z \vert \bm\mu_{l}, \bm\sigma_{l}) \\
        & = \prod_{k} \frac{1}{2\sigma_{l,k}} \exp{-\frac{\abs{z_k - \mu_{l,k}}}{\sigma_{l,k}}}
    \end{split}
\end{align}
%
This robustness also makes the Laplace distribution a natural choice for handling the early stages of training when the model is randomly initialised and outputs distributions wildly different from the ensemble logits. Similarly, this style of student can also be trained by the maximum likelihood principle by minimising:
%
\begin{align}
    \label{eq:laplacedd}
    \begin{split}
        \mathcal{L}_{\tt NLL}^{\tt L-EDD}(\bm\phi) 
        & = -\frac{1}{MLK} \sum_{m, l} \ln{\tt Lap}(\bm z_l^{(m)} \vert \bm\mu_{l}, \bm\sigma_{l}) \\
        & \equiv \frac{1}{MLK} \sum_{m,l,k} \frac{\abs{z_{l,k}^{(m)} - \mu_{l,k}}}{\sigma_{l,k}} + \ln\sigma_{l,k}
    \end{split}
\end{align}
%
\end{comment}

\subsection{Practical Considerations}

Since the softmax activation function is shift invariant,
\begin{align*}
    \Softmax(\bm z - \bm 1 b) = \Softmax(\bm z) \thickspace\thickspace \forall b \in \mathbb{R},
\end{align*}
one has to account for this property when performing distribution distillation. The logits of ensemble members are unconstrained along the direction $\bm 1$, and so can differ substantially in logit space even when they give consistent softmax predictions. Therefore, logits are normalised as $\tilde{\bm z} = \bm z - \bm 1 \LogSumExp(\bm z)$. This particular normalisation scheme is not special: any choice of normalisation constant, such as ${\tt Max}(\bm z)$ or ${\tt Mean}(\bm z)$, would be valid. Next, to ensure that the student can be trained reliably, we combine the knowledge and distribution distillation losses as $\mathcal{L}_{\tt KD}(\bm\phi) + \beta \mathcal{L}_{\tt NLL}^{\tt L-EDD}(\bm\phi)$ (see Eq. \ref{eq:laplacedd} and Sec. \ref{ssec:distillation}).
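A minimal NumPy sketch of this normalisation, assuming logits are stored with the class dimension last (function names illustrative). Since $\tilde{\bm z} = \bm z - \bm 1 \LogSumExp(\bm z)$ is exactly the log-softmax of $\bm z$, two logit vectors that differ only by a constant shift are mapped to the same $\tilde{\bm z}$, and $\exp(\tilde{\bm z})$ recovers the softmax probabilities.

\begin{verbatim}
import numpy as np


def logsumexp(z):
    """Stable LogSumExp over the last axis (kept as a column)."""
    m = z.max(axis=-1, keepdims=True)
    return m + np.log(np.exp(z - m).sum(axis=-1, keepdims=True))


def normalise_logits(z):
    """z~ = z - 1 * LogSumExp(z), i.e. the log-softmax of z."""
    z = np.asarray(z, dtype=float)
    return z - logsumexp(z)


def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Two members whose logits differ only by a shift b = 100
# give the same softmax and the same normalised logits.
z_a = np.array([1.0, 2.0, 3.0])
z_b = z_a + 100.0
\end{verbatim}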

Furthermore, distributions in logit space generally lead to analytically intractable expectations in probability space. The standard workaround is a Monte Carlo approximation based on samples drawn from the distribution. However, when computing the predictive distribution (e.g.\ during decoding), we instead opt for a deterministic approximation:
%
\begin{align*}
    \tP(y_l \vert \bm y_{<l}, \bm x, \bm\phi) 
    & = \mathbb{E}_{\tp(\bm z_l \vert \bm y_{<l}, \bm x, \bm\phi)} \left[ \Softmax(\bm z_l) \right] \\
    & \approx \Softmax(\bm \mu_l),
\end{align*}
%
in which the expectation is approximated by evaluating the softmax at the mean of the logit distribution. For downstream tasks that require uncertainties, we revert to sampling multiple logit vectors from the distribution to generate multiple predictions.
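The two modes of use can be sketched in Python/NumPy as follows, with \texttt{numpy.random.Generator.laplace} drawing the diagonal Laplace samples. The number of samples $S$, the function names and the constant \texttt{eps} are illustrative assumptions; \texttt{uncertainties} applies the ensemble-style total, data and knowledge uncertainty measures to the $S$ sampled softmax outputs as if they were ensemble members.

\begin{verbatim}
import numpy as np


def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def sample_student_probs(mu, sigma, n_samples, rng):
    """S samples z^(s) ~ Lap(mu, sigma), mapped by softmax."""
    shape = (n_samples,) + np.shape(mu)
    z = rng.laplace(loc=mu, scale=sigma, size=shape)
    return softmax(z)


def predictive(mu, sigma=None, n_samples=0, rng=None):
    """Softmax(mu) if n_samples == 0, else a Monte Carlo
    estimate of E[Softmax(z)] under the diagonal Laplace."""
    mu = np.asarray(mu, dtype=float)
    if n_samples == 0:
        return softmax(mu)  # deterministic approximation
    probs = sample_student_probs(mu, sigma, n_samples, rng)
    return probs.mean(axis=0)


def uncertainties(probs, eps=1e-12):
    """Total, data and knowledge uncertainty, treating the
    (S, K) sampled softmax outputs like ensemble members."""
    mean = probs.mean(axis=0)
    total = -np.sum(mean * np.log(mean + eps))
    data = -np.mean(np.sum(probs * np.log(probs + eps), axis=1))
    return float(total), float(data), float(total - data)
\end{verbatim}

With a very small scale $\bm\sigma_l$, the sampled estimate approaches ${\tt Softmax}(\bm\mu_l)$ and the knowledge uncertainty approaches zero; larger scales spread the samples and increase it.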

\section{Experiments on Artificial Data}
\label{sec:toy-example}

\begin{figure*}[t!]
     \centering
             \begin{subfigure}{0.23\textwidth}
                 \centering
                 \vstretch{1.0}{\includegraphics[width=\textwidth]{figures-v2/scatter.pdf}}
                 
                 \caption{Dataset}
                 \label{fig:toy-data}
                 
             \end{subfigure}
             \hfill
             \begin{subfigure}{0.23\textwidth}
                 \centering
                 \vstretch{1.0}{\includegraphics[width=\textwidth]{figures-v2/distillation-loss.pdf}}
                 
                 \caption{Distillation loss}
                 \label{fig:toy-dist-loss}
                 
             \end{subfigure}
             \hfill
             \begin{subfigure}{0.23\textwidth}
                 \centering
                 \vstretch{1.0}{\includegraphics[width=\textwidth]{figures-v2/dirichlet-loss.pdf}}
                 
                 \caption{Dirichlet loss}
                 \label{fig:toy-dir-loss}
                 
             \end{subfigure}
             \hfill
             \begin{subfigure}{0.23\textwidth}
                 \centering
                 \vstretch{1.0}{\includegraphics[width=\textwidth]{figures-v2/laplace-loss.pdf}}
                 
                 \caption{Laplace loss}
                 \label{fig:toy-lap-loss}
                 
             \end{subfigure}
             
             \begin{subfigure}{0.23\textwidth}
                 \centering
                 \vstretch{1.0}{\includegraphics[width=\textwidth]{figures-v2/ensemble-confidence.pdf}}
                 
                 \caption{Ensemble conf.}
                 \label{fig:toy-ens-conf}
             \end{subfigure}
             \hfill
             \begin{subfigure}{0.23\textwidth}
                 \centering
                 \vstretch{1.0}{\includegraphics[width=\textwidth]{figures-v2/distillation-confidence.pdf}}
                 
                 \caption{Distillation conf.}
                 \label{fig:toy-dist-conf}
             \end{subfigure}
             \hfill
             \begin{subfigure}{0.23\textwidth}
                 \centering
                 \vstretch{1.0}{\includegraphics[width=\textwidth]{figures-v2/dirichlet-confidence.pdf}}
                 
                 \caption{Dirichlet conf.}
                 \label{fig:toy-dir-conf}
             \end{subfigure}
             \hfill
             \begin{subfigure}{0.23\textwidth}
                 \centering
                 \vstretch{1.0}{\includegraphics[width=\textwidth]{figures-v2/laplace-confidence.pdf}}
                 
                 \caption{Laplace conf.}
                 \label{fig:toy-lap-conf}
             \end{subfigure}

     \caption{An artificial three-class classification problem with 1000 examples per class. The top row shows the dataset followed by the loss surface contours for the different distillation approaches; darker colours imply lower losses. The bottom row shows the confidence contours of the ensemble and of each student, with confidence values labelled on the contour lines. Dirichlet-based EDD is unable to learn properly, whilst our proposed Laplace L-EDD can imitate the ensemble.}
     \label{fig:toy}
\end{figure*}

This section investigates the proposed Laplace logit-based ensemble distribution distillation (L-EDD) technique on a static artificial dataset (Figure \ref{fig:toy-data}). The dataset was generated by sampling 3000 data points equally from three isotropic Gaussian distributions. The location and standard deviation of the Gaussians were chosen such that there are regions with significant overlap and regions where models can be highly confident.
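The generation procedure can be sketched as below. The two-dimensional inputs, the class centres and the standard deviation are hypothetical placeholders chosen only to produce some overlap; they are not the values used for Figure \ref{fig:toy}.

\begin{verbatim}
import numpy as np


def make_toy_data(n_per_class=1000, seed=0):
    """Three-class toy data: isotropic 2-D Gaussians with
    equal class sizes. Centres and std are hypothetical."""
    rng = np.random.default_rng(seed)
    centres = np.array([[0.0, 0.0],   # hypothetical
                        [3.0, 0.0],
                        [1.5, 2.5]])
    std = 1.0  # hypothetical; shared and isotropic
    x = np.concatenate(
        [rng.normal(loc=c, scale=std, size=(n_per_class, 2))
         for c in centres])
    y = np.repeat(np.arange(3), n_per_class)
    return x, y


x, y = make_toy_data()  # 3000 points, 1000 per class
\end{verbatim}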

In these exploratory experiments, an ensemble of 10 small neural networks is first trained, each member from a different random initialisation. The ensemble is then distilled using KD, EDD and Laplace L-EDD. We compare these methods qualitatively by displaying both the loss surface of each approach (Figures \ref{fig:toy-dist-loss}--\ref{fig:toy-lap-loss}) and the resulting confidence (maximum softmax probability) contours (Figures \ref{fig:toy-ens-conf}--\ref{fig:toy-lap-conf}). 
%
The loss surface shows how the student distillation loss varies over the input space, thus providing useful information about which regions of the data are successfully optimised. 
%
The confidence contours show whether the system can separate the three classes. EDD training had to be terminated early because it diverged due to large gradients originating from the high-confidence regions, as discussed by \citet{ryabin}.

Figure \ref{fig:toy-ens-conf} shows the ensemble confidence contours, which clearly trace out the class boundaries and partially separate the three classes. The confidence also increases further away from regions of overlap, where there is less uncertainty. This is the behaviour expected of a properly trained system, and we expect distilled students to behave similarly. Next, we distil the ensemble into a single student model using KD (Figures \ref{fig:toy-dist-loss} and \ref{fig:toy-dist-conf}). The loss surface shows that the student optimises the distillation objective well over regions of high overlap, and it generates confidence contours that are consistent with the teacher ensemble.

However, Ensemble Distribution Distillation completely fails on this very simple task. The loss surface in Figure \ref{fig:toy-dir-loss} shows that the Dirichlet student is unable to optimise regions with significant data overlap, and instead attains extremely small losses in regions where the teacher ensemble is already confident. This is consistent with \citet{ryabin}, who find that highly confident teachers can induce extremely large gradients. It also results in inaccurate confidence contours that fail to separate the classes, especially in regions of overlap (Figure \ref{fig:toy-dir-conf}).

The Laplace L-EDD approach circumvents these issues by operating in logit space. The resulting loss surface is much more consistent with that of the knowledge-distilled student, successfully optimising regions of high overlap (Figure \ref{fig:toy-lap-loss}). Similarly, the confidence contours trace out boundaries consistent with the ensemble, albeit with lower overall confidence. Since the log-likelihood (like the KL-divergence) is a mode-covering objective \citep{minka2005divergence}, the Laplace student ends up predicting distributions that overestimate the range of the ensemble logits. Combined with the long tails of the Laplace distribution, this often leads the student to overestimate the variance of the logits and to produce lower overall confidence.
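
For reference, assuming a diagonal Laplace student with location $\mu_k$ and scale $\sigma_k > 0$ for logit dimension $k$ (notation introduced here for illustration only), the negative log-likelihood of a single ensemble logit $z_k$ is
\begin{align}
    -\ln{\tt Lap}(z_k \vert \mu_k, \sigma_k) = \ln(2\sigma_k) + \frac{\lvert z_k - \mu_k \rvert}{\sigma_k}.
\end{align}
The corresponding density decays as $\exp(-\lvert z_k - \mu_k \rvert / \sigma_k)$, more slowly than a Gaussian's $\exp(-(z_k - \mu_k)^2 / 2\sigma_k^2)$, and its variance is $2\sigma_k^2$. Logits far from $\mu_k$ are therefore comparatively likely under the student, which is one way to see how the long tails can lower the confidence of the resulting predictions.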


\section{Machine Translation}
\label{sec:machine-translation}

This section reports on the performance of base transformers \citep{aiayn} trained on the En-De WMT'16 dataset, which consists of 4.5 million sentence pairs covering topics such as news and policy.
%
We use newstest-13 for validation and newstest-14 for predictive evaluation.
%
For the main task of investigating out-of-distribution (OOD) detection, we compare the in-distribution (ID) newstest-14 against each of the publicly available Khresmoi-Summary (Khresmoi) \citep{khresmoi}, MTNT \citep{mtnt} and Kyoto Free Translation Task (KFTT) \citep{kftt} datasets. These consist of medical articles, noisy Reddit-based conversational text and specialised Wikipedia articles, respectively.
%
Furthermore, we apply insights from training these systems to big transformers \citep{aiayn} trained on the larger En-Ru WMT'20 dataset, which consists of 58 million sentence pairs after preprocessing. In this case, we use newstest-19 for validation and newstest-20 for evaluation. The OOD detection task uses newstest-20 as ID data and the same OOD datasets as for the base transformer.

Data is tokenised using Moses, following \citet{scaling-nmt}. For WMT'16, a shared dictionary is learnt using Byte Pair Encoding (BPE) with 32,000 merge operations \citep{bpe}. For WMT'20, we learn disjoint dictionaries using BPE with 40,000 merge operations.
%
Predictive performance (translation quality) is evaluated using corpus-level (Sacre)BLEU \citep{sacrebleu}, with no post-processing of outputs before scoring. For the main task of detection, we use the ubiquitous threshold-independent AUROC metric \citep{auroc}, for which a random detector scores 50\%.
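
To recall how this metric behaves, assume (as is standard) that a higher uncertainty score indicates OOD, and let $\{u_i^{\rm ID}\}_{i=1}^{N_{\rm ID}}$ and $\{u_j^{\rm OOD}\}_{j=1}^{N_{\rm OOD}}$ denote the scores assigned to the ID and OOD inputs (notation ours). The AUROC then equals the probability that a randomly chosen OOD input is scored above a randomly chosen ID input, with ties counted as one half:
\begin{align}
    {\rm AUROC} = \frac{1}{N_{\rm ID} N_{\rm OOD}} \sum_{i=1}^{N_{\rm ID}} \sum_{j=1}^{N_{\rm OOD}} \Big( \mathbbm{1}\big[u_j^{\rm OOD} > u_i^{\rm ID}\big] + \tfrac{1}{2}\,\mathbbm{1}\big[u_j^{\rm OOD} = u_i^{\rm ID}\big] \Big).
\end{align}
A detector whose scores carry no information about the domain therefore scores 50\% in expectation, whereas values below 50\% mean that ID inputs tend to receive \emph{higher} uncertainty than OOD inputs.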

All standard transformers are trained using an inverse square root learning rate schedule with a linear warm-up stage. A stronger Deep Ensemble baseline is formed from $M = 5$ such models. To avoid the high training cost of Deep Ensembles, we also train Snapshot Ensembles \citep{temporalensemble, snapshot} with a cyclic learning rate \citep{cyclic}, to show that distribution distillation is achievable with smaller training budgets.
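
For concreteness, representative forms of the two schedules are the inverse square root schedule of \citet{aiayn}, with model dimension $d_{\rm model}$ and $T_{\rm w}$ warm-up steps, and the cyclic cosine schedule of \citet{snapshot}, with initial rate $\alpha_0$, $T$ total steps and $M$ cycles (one snapshot saved at the end of each cycle):
\begin{align}
    \alpha_t &= d_{\rm model}^{-1/2} \min\big(t^{-1/2},\; t\, T_{\rm w}^{-3/2}\big), \\
    \alpha_t &= \frac{\alpha_0}{2} \left( \cos\left( \frac{\pi \operatorname{mod}\big(t - 1, \lceil T/M \rceil\big)}{\lceil T/M \rceil} \right) + 1 \right).
\end{align}
These are shown only as illustrative forms; the exact schedules and hyperparameter values used in our experiments are those given in the supplementary appendix and may differ.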
%
Since building Deep Ensembles is expensive, and the performance gap to Snapshot Ensembles turns out to be small (Tables \ref{tab:wmt16} and \ref{tab:wmt16-detection}), we perform most distillation experiments on Snapshot Ensembles. We compare the proposed L-EDD approaches with KD and EDD, and repeat each experiment 5 times, each time with a different Snapshot Ensemble.
%
Finally, as in a range of prior work on uncertainty estimation \citep{endd, structured, tcp}, we do not aim to achieve state-of-the-art predictive performance, but opt for a simpler setup with a focus on uncertainty estimation. All setup details are provided in the supplementary appendix; hyperparameters were tuned on the ID validation sets.

\begin{table}[h!]
	\centering
	\begin{minipage}[t]{0.48\textwidth}%
		\begin{center}
                \caption{Model parameter size, relative training time and translation performance (BLEU $\pm$ 2 std) on newstest-14 for the base transformer. We include two KD baselines, one for each teacher ensemble. 
                % Snapshot Ensemble distillation approaches achieve competitive BLEU scores at significantly reduced training and parameter costs compared to the Deep Ensemble.
                }

			\def\arraystretch{1.00}
			\makebox[1.0\textwidth][c]{
                    \small
				\begin{tabular}{l|cc|l}
					\toprule
					\multirow{2}{*}{\textbf{Model}} & 
					\multirow{2}{*}{\textbf{Size}} & 
					\textbf{Train} & 
					\multirow{2}{*}{\textbf{BLEU} $\uparrow$} \\ 
					& & \textbf{Time} & \\
					\midrule
					{{Standard}} & 60.9M & 1.0 & 25.85 \tpm 0.17 \\
					% {Self distillation} & 60.9M & 1.8 & 26.48 \tpm 0.15 \\
					% \midrule
					% {Self-distribution dist.} & 60.9M & 1.7 & 25.89 \tpm 0.20 \\
					\midrule
					{{Deep Ensemble}} & 304.5M & 5.0 & 26.72 \\
					{KD (Categorical)} & 60.9M & 5.9 & 26.70 \tpm 0.26 \\
					\midrule
					{{Snapshot Ensemble}} & 304.5M  & 1.5 & 26.54 \tpm 0.16 \\
					{KD (Categorical)} & 60.9M & 1.9 & {27.02} \tpm 0.19 \\
					{EDD (Dirichlet)} & 60.9M & 2.0 & {26.96} \tpm 0.06 \\
					{L-EDD (Gaussian)} & 77.8M & 1.9 & {26.90} \tpm 0.28 \\
					{L-EDD (Laplace)} & 77.8M & 1.9 & {27.08} \tpm 0.20 \\
					\bottomrule
			\end{tabular}}
			\label{tab:wmt16}
		\end{center}
	\end{minipage}
\end{table}

\begin{table*}[b!]
	\centering{}
	\begin{minipage}[t]{1.0\textwidth}%
		\begin{center}
                \caption{OOD detection performance (\%{AUROC} $\uparrow$ $\pm$ 2 std) for the base transformer with ID dataset newstest-14 and OOD datasets Khresmoi, MTNT and KFTT. \textbf{Bold} indicates best in a column, \underline{underline} second best. 
                % We include two different KD baselines, one for each ensemble.
                Laplace L-EDD with knowledge uncertainty (KU) achieves the best performance on every OOD dataset, even compared to the Deep Ensemble.
                }
	
			\def\arraystretch{1.00}
			\small
			\begin{tabular}{l|ll|ll|ll}
				\toprule
				\multirow{2}{*}{\textbf{Model}} & 
				\multicolumn{2}{c|}{\textbf{Khresmoi}} & 
				\multicolumn{2}{c|}{\textbf{MTNT}} & 
				\multicolumn{2}{c}{\textbf{KFTT}}  \\
				& \multicolumn{1}{c}{\textbf{TU}} 
				& \multicolumn{1}{c|}{\textbf{KU}}  
				& \multicolumn{1}{c}{\textbf{TU}}  
				& \multicolumn{1}{c|}{\textbf{KU}}  
				& \multicolumn{1}{c}{\textbf{TU}}  
				& \multicolumn{1}{c}{\textbf{KU}}  \\
				\midrule
                    {Standard} & 47.5 \tpm 0.8 & \xmark & 63.5 \tpm 1.3 & \xmark & 30.6 \tpm 1.2 & \xmark \\
                    \midrule
				% {Self-distribution dist.}    & 48.7 \tpm 2.8 & 54.4 \tpm 3.3 & 63.8 \tpm 1.9 & 58.9 \tpm 2.0 & 31.3 \tpm 2.3 & 31.4 \tpm 3.1 \\
				{Deep Ensemble} & 48.0 & 61.9 & 64.5 & 63.7 & 30.1 & 44.0 \\
                    {KD (Categorical)} & 47.9 \tpm 1.1 & \xmark & 64.5 \tpm 1.3 & \xmark & 29.8 \tpm 0.7 & \xmark \\
				\midrule
				Snapshot Ensemble  & 49.0 \tpm 0.6 & 62.6 \tpm 1.1 & 63.8 \tpm 1.2 & 63.1 \tpm 0.7 & 31.7 \tpm 0.9 & \underline{47.4} \tpm 2.5 \\
				{KD (Categorical)}  & 48.0 \tpm 1.4 & \xmark & 64.6 \tpm 0.9 & \xmark & 31.3 \tpm 0.5 & \xmark \\
				{EDD (Dirichlet)}   & 49.6 \tpm 1.3 & 57.1 \tpm 1.4 & \underline{65.1} \tpm 1.7 & \underline{65.6} \tpm 2.0 & 31.0 \tpm 0.9 & 36.2 \tpm 1.4 \\
				{L-EDD (Gaussian)}  & \underline{59.5} \tpm 1.1 & \underline{71.7} \tpm 1.9 & \textbf{66.3} \tpm 1.6 & 64.0 \tpm 2.1 & \underline{35.8} \tpm 1.2 & 44.0 \tpm 0.2 \\
				{L-EDD (Laplace)}   & \textbf{65.1} \tpm 1.8 & \textbf{73.1} \tpm 1.7 & \underline{65.1} \tpm 1.5 & \textbf{66.8} \tpm 1.8 & \textbf{37.8} \tpm 0.2 & \textbf{48.8} \tpm 1.4 \\
				\bottomrule
			\end{tabular}
			\label{tab:wmt16-detection}
		\end{center}
	\end{minipage}
\end{table*} 

\subsection{Base Transformer Results}
\label{sec:res-base}

Table \ref{tab:wmt16} shows both the efficiency and the translation performance of a wide range of systems on newstest-14. As expected, the Deep Ensemble outperforms both the Snapshot Ensemble and a standard trained system. Surprisingly, the students distilled from the Snapshot Ensemble outperform not only their teacher but also the Deep Ensemble, a pattern also observed in self-distilled systems and explored in more detail by \citet{selfmicro}.

Next, we compare the threshold-independent out-of-distribution detection performance of baseline systems with L-EDD models.
%
From Table \ref{tab:wmt16-detection}, we observe that Snapshot Ensembles are competitive with their Deep counterpart whilst being more than 3 times cheaper to train (a relative training time of 1.5 versus 5.0).
%
Furthermore, the knowledge-distilled students match the detection performance of their Deep and Snapshot Ensemble teachers when using total uncertainty (TU). This is expected, since they are trained specifically to capture the predictive distribution of their teacher ensemble. However, since KD students cannot estimate knowledge uncertainty (KU), they fall short of ensemble-level detection performance on all but the MTNT dataset.
%
Similarly, the modified Dirichlet baseline described by \citet{ryabin} achieves similar detection performance using TU, with the added ability to estimate KU. However, although its KU estimates are better than its TU estimates, they still fall short of the ensembles on Khresmoi and KFTT.
%

On the other hand, the Laplace and Gaussian L-EDD models are (surprisingly) able to outperform both ensembles on all three detection splits, producing either similar or significantly better TU and KU estimates. This may partly be because diagonal Laplace and Gaussian distributions have more parameters and are more flexible than the Dirichlet, and partly because they do not suffer from the same optimisation issues.
% and can thus better extract `dark knowledge' from the ensemble \citep{kd}. 
Nonetheless, neither reason explains why L-EDD models can outperform ensembles in detection. We explore this pattern in Section \ref{sec:analysis}.
%

%
Additionally, many models perform worse than a random detector, particularly on KFTT and, to a lesser extent, on Khresmoi. A partial explanation could be that these datasets contain longer sequences. During decoding, the transformer models produce increasingly confident predictions further along the output sequence, which lowers the uncertainty scores of longer outputs. Section \ref{sec:analysis} investigates this effect and isolates a possible reason behind the Laplace's success in outperforming ensembles.
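
To make this length argument concrete, assume (for illustration only) that a sequence-level score is the average of token-level uncertainties $u_1, \dots, u_L$ over an output of length $L$, and that these are non-increasing in position, $u_{l+1} \le u_l$. Writing $\bar u_L = \frac{1}{L} \sum_{l=1}^{L} u_l$, the condition $u_{L+1} \le u_l$ for every $l \le L$ implies $u_{L+1} \le \bar u_L$, and hence
\begin{align}
    \bar u_{L+1} = \frac{L\, \bar u_L + u_{L+1}}{L + 1} \le \frac{L\, \bar u_L + \bar u_L}{L + 1} = \bar u_L.
\end{align}
Under these assumptions, longer outputs receive lower (or equal) average uncertainty regardless of whether the input is ID or OOD, so an OOD dataset with longer sequences can be scored as \emph{less} uncertain than the ID data.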

As an aside, we observe that estimates of knowledge uncertainty are clearly important for OOD detection in autoregressive sequence tasks, as corroborated by prior work \citep{structured, ryabin}. This contrasts with recent empirical results on image classification, which suggest the opposite: that measures of knowledge uncertainty are not useful for indicating distributional shift \citep{Xia2022OnTU, abe2022deep}.
% this confirmation bias effect could be resolved by properly calibrating the models.
% but analysing the over-confidence and calibration of autoregressive sequence models is an under-explored area and is therefore, out of the scope of this paper.

\begin{table}[h!]
	\centering
	\begin{minipage}[t]{0.48\textwidth}%
		\begin{center}
                \caption{Model parameter size, relative training time and translation performance (BLEU $\pm$ 2 std) on newstest-20 for the big transformer. 
                % Similar to Table \ref{tab:wmt16}, Snapshot Ensemble distillation achieves competitive BLEU scores at reduced training and parameter costs.
                }
        
			\def\arraystretch{1.00}
			\makebox[1.0\textwidth][c]{
                    \small
				\begin{tabular}{l|cc|l}
					\toprule
					\multirow{2}{*}{\textbf{Model}} & 
					\multirow{2}{*}{\textbf{Size}} & 
					\textbf{Train} & 
					\multirow{2}{*}{\textbf{BLEU} $\uparrow$} \\ 
					& & \textbf{Time} & \\
					\midrule
					{{Standard}} & 271M & 1.0 & 26.28 \tpm 0.34 \\
					\midrule
					{{Deep Ensemble}} & 1.35B & 5.0 & 26.81 \\
					\midrule
					{{Snapshot Ensemble}} & 1.35B  & 1.5 & 26.42 \tpm 0.23 \\
                    {KD (Categorical)} & 271M & 2.1 & 26.73 \tpm 0.16 \\
					{EDD (Dirichlet)} & 271M & 2.2 & 26.66 \tpm 0.19 \\
                    {L-EDD (Laplace)} & 320M & 2.2 & 26.71 \tpm 0.18 \\
					\bottomrule
			\end{tabular}}
			\label{tab:wmt20}
		\end{center}
	\end{minipage}
\end{table}

\begin{table*}[b!]
	\centering{}
	\begin{minipage}[t]{1.0\textwidth}%
		\begin{center}
			\caption{OOD detection performance (\%{AUROC} $\uparrow$ $\pm$ 2 std) for the big transformer with ID dataset newstest-20 and OOD datasets Khresmoi, MTNT and KFTT. \textbf{Bold} indicates best in a column, \underline{underline} second best. As in Table \ref{tab:wmt16-detection}, L-EDD (Laplace) with KU shows superior performance on all OOD datasets.}

			\def\arraystretch{1.00}
			\small
			\begin{tabular}{l|ll|ll|ll}
				\toprule
				\multirow{2}{*}{\textbf{Model}} & 
				\multicolumn{2}{c|}{\textbf{Khresmoi}} & 
				\multicolumn{2}{c|}{\textbf{MTNT}} & 
				\multicolumn{2}{c}{\textbf{KFTT}}  \\
				& \multicolumn{1}{c}{\textbf{TU}} 
				& \multicolumn{1}{c|}{\textbf{KU}}  
				& \multicolumn{1}{c}{\textbf{TU}}  
				& \multicolumn{1}{c|}{\textbf{KU}}  
				& \multicolumn{1}{c}{\textbf{TU}}  
				& \multicolumn{1}{c}{\textbf{KU}}  \\
				\midrule
				{Deep Ensemble} & 39.3 & 53.2 & 70.8 & 69.0 & 51.0 & 60.3 \\
				\midrule
                    Snapshot Ensemble  & 40.8 \tpm 0.5 & \underline{55.0} \tpm 0.8 & 70.1 \tpm 0.5 & 69.3 \tpm 0.9 & \underline{51.1} \tpm 0.6 & \underline{60.9} \tpm 1.4 \\
                    {KD (Categorical)}		   & 40.4 \tpm 0.8 & \xmark & 70.9 \tpm 1.0 & \xmark & 50.9 \tpm 0.6 & \xmark \\
                    {EDD (Dirichlet)} & \underline{41.0} \tpm 0.8 & 52.9 \tpm 1.3 & \underline{71.3} \tpm 0.7 & \underline{69.4} \tpm 0.7 & 51.0 \tpm 0.8 & 60.0 \tpm 1.3 \\
                    {L-EDD (Laplace)}	   & \textbf{51.0} \tpm 0.9 & \textbf{63.4} \tpm 1.2 & \textbf{72.6} \tpm 0.8 & \textbf{70.2} \tpm 0.6 & \textbf{63.2} \tpm 1.0 & \textbf{70.2} \tpm 1.1 \\
				\bottomrule
			\end{tabular}
			\label{tab:wmt20-detection}
		\end{center}
	\end{minipage}
\end{table*} 

\subsection{Big Transformer Results}
\label{sec:res-big}

In this section, we take the best-performing systems from the previous section and apply them to the `big transformer' on the larger En-Ru WMT'20 dataset. Table \ref{tab:wmt20} shows the efficiency and predictive performance on newstest-20. Again, the Deep Ensemble outperforms its Snapshot counterpart. Furthermore, the KD and L-EDD students distilled from the Snapshot Ensemble outperform their teacher. However, unlike in the smaller-scale experiments, these students only come within a standard deviation of the Deep Ensemble, although they do so with a single forward pass.

From Table \ref{tab:wmt20-detection}, we observe a similar pattern: Deep and Snapshot Ensembles perform equivalently, whilst L-EDD (Laplace) significantly outperforms both ensembles on all but the MTNT dataset. Interestingly, unlike in Table \ref{tab:wmt16-detection}, where no model was able to beat a random detector on KFTT, the En-Ru WMT'20 models are able to distinguish newstest-20 from KFTT; switching the ID dataset to newstest-14 does not notably affect these results.


\section{Analysis: Ensemble vs Laplace}
\label{sec:analysis}

\subsection{Augmented Ensemble Uncertainties}

Sections \ref{sec:res-base} and \ref{sec:res-big} both found that L-EDD models overall significantly outperform their teacher Snapshot Ensemble as well as a Deep Ensemble. To understand the source of L-EDD's superior performance, we therefore propose an additional experiment: we fit an auxiliary Laplace distribution to a Deep Ensemble at inference time and use samples drawn from this proxy to perform the detection task.

Consider a Deep Ensemble which produces a set of normalised logits $\{\tilde{\bm z}_{l}^{(m)}\}_{m=1}^M$ as in Equation (\ref{eq:logits}). In traditional uncertainty estimation, these logits would be transformed into categorical distributions. However, in this experiment, we estimate an auxiliary Laplace distribution using maximum likelihood (which is the loss-minimising distribution for a Laplace L-EDD student):
\begin{align}
    \tilde{\bm\mu}_l, \tilde{\bm\sigma}_l = \argmax_{\bm\mu, \bm\sigma} \sum_{m} \ln{\tt Lap}(\tilde{\bm z}_l^{(m)} \vert \bm\mu, \bm\sigma).
\end{align}
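Under the diagonal Laplace assumption used for the L-EDD student, this maximisation decouples over logit dimensions and has a closed-form solution. Writing $k$ for the logit dimension (notation introduced here), the location is the per-dimension sample median and the scale is the mean absolute deviation about it:
\begin{align}
    \tilde\mu_{l,k} = \operatorname*{median}_{m}\, \tilde z_{l,k}^{(m)}, \qquad
    \tilde\sigma_{l,k} = \frac{1}{M} \sum_{m=1}^{M} \big\lvert \tilde z_{l,k}^{(m)} - \tilde\mu_{l,k} \big\rvert.
\end{align}
This follows from the per-dimension log-likelihood $-M \ln(2\sigma) - \sigma^{-1} \sum_{m} \lvert \tilde z_{l,k}^{(m)} - \mu \rvert$: for any fixed $\sigma$, the sum of absolute deviations is minimised by a median, and setting the derivative with respect to $\sigma$ to zero then gives the mean absolute deviation. For the odd ensemble size $M = 5$ used here, the median is unique.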
By sampling new points from this auxiliary distribution, we can estimate total and knowledge uncertainty:
\begin{align}
    \tilde{\bm\pi} = \Softmax(\bm z), \thickspace \bm z \sim {\tt Lap}(\tilde{\bm\mu}_l, \tilde{\bm\sigma}_l).
\end{align}
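As one standard choice, assumed here for illustration, the entropy-based measures can then be estimated by Monte Carlo: with $S$ samples $\bm z^{(s)}$ drawn as above and $\tilde{\bm\pi}^{(s)} = \Softmax(\bm z^{(s)})$ (the sample count $S$ is our notation),
\begin{align}
    {\rm TU}_l \approx \mathcal{H}\Big[\frac{1}{S} \sum_{s=1}^{S} \tilde{\bm\pi}^{(s)}\Big], \qquad
    {\rm KU}_l \approx {\rm TU}_l - \frac{1}{S} \sum_{s=1}^{S} \mathcal{H}\big[\tilde{\bm\pi}^{(s)}\big],
\end{align}
where $\mathcal{H}$ denotes the entropy of a categorical distribution, so that ${\rm KU}_l$ estimates the mutual information between the predicted token and the logits.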
The aim of this modified approach to ensemble-based uncertainty estimation is to investigate whether approximating the logits with a Laplace distribution is the reason behind L-EDD's better performance.

\begin{table*}[t!]
	\centering{}
	\begin{minipage}[t]{1.0\textwidth}%
		\begin{center}
			\caption{OOD detection performance (\%{AUROC} $\uparrow$ $\pm$ 2 std) following the same setup as in Table \ref{tab:wmt16-detection}. The Laplace-augmented ensemble performs much better than its standard counterpart in most cases.}
	
			\def\arraystretch{1.00}
			\small
			\begin{tabular}{l|ll|ll|ll}
				\toprule
				\multirow{2}{*}{\textbf{Model}} & 
				\multicolumn{2}{c|}{\textbf{Khresmoi}} & 
				\multicolumn{2}{c|}{\textbf{MTNT}} & 
				\multicolumn{2}{c}{\textbf{KFTT}}  \\
				& \multicolumn{1}{c}{\textbf{TU}} 
				& \multicolumn{1}{c|}{\textbf{KU}}  
				& \multicolumn{1}{c}{\textbf{TU}}  
				& \multicolumn{1}{c|}{\textbf{KU}}  
				& \multicolumn{1}{c}{\textbf{TU}}  
				& \multicolumn{1}{c}{\textbf{KU}}  \\
				\midrule
                    {Deep Ensemble}           & 48.0 & 61.9 & \underline{64.5} & \underline{63.7} & 30.1 & 44.0 \\
                    {Deep Ensemble (Laplace)} & \underline{62.5} & \underline{72.1} & 63.8 & 56.1 & \underline{34.8} & \textbf{55.4} \\
				\midrule
                    {L-EDD (Laplace)}         & \textbf{65.1} \tpm 1.8 & \textbf{73.1} \tpm 1.7 & \textbf{65.1} \tpm 1.5 & \textbf{66.8} \tpm 1.8 & \textbf{37.8} \tpm 0.2 & \underline{48.8} \tpm 1.4 \\
				\bottomrule
			\end{tabular}
			\label{tab:wmt16-detection-laplace}
		\end{center}
	\end{minipage}
\end{table*} 

Table \ref{tab:wmt16-detection-laplace} shows the detection performance of the Laplace-modified Deep Ensemble, following the detection setup in Section \ref{sec:res-base}. 
%
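On Khresmoi and KFTT, the augmented ensemble largely bridges the OOD detection performance gap between the standard Deep Ensemble and L-EDD, even surpassing L-EDD when using KU on KFTT, although it does not help on MTNT.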
%
This suggests that, for autoregressive sequence prediction, uncertainty measures based on a logit-space model fitted directly to an ensemble are better indicators of distributional shift \textit{than those computed directly from the ensemble logits}. Fully understanding why this is the case would be an interesting direction for future research, as it may enable further advances in autoregressive out-of-distribution detection.
% This suggests that the Laplace distribution in logit space attends to different aspects of the ensemble compared to standard model combination. 
%
% A deeper inspection of the uncertainties shows that the augmented ensemble consistently predicts higher uncertainties, meaning that the Laplace weighs the low-probability classes higher, an effect often observed when calibrating models. However, unlike standard calibration methods \citep{calibration} which perform calibration based on some predefined static hyperparameters, the Laplace does so dynamically for each input and token in the sequence.

% Additionally, the Laplace L-EDD captures the best performance of both ensembles in all but one case, using KU to detect KFTT. This could be attributed to the chosen loss function, which combines KD and L-EDD.

\subsection{Model Confidence for Longer Sequences}

Another reason for Laplace L-EDD's superior performance may be found by analysing its behaviour as sequence length increases.
%
\begin{figure*}[b!]
     \centering
     \begin{subfigure}[b]{0.23\textwidth}
         \centering
         \includegraphics[width=\textwidth]{archive/length-vs-total-ensemble-id.pdf}
         
         \caption{PCC = -20\%}
       
     \end{subfigure}
     \hfill
     \begin{subfigure}[b]{0.23\textwidth}
         \centering
         \includegraphics[width=\textwidth]{archive/length-vs-total-ensemble-ood-khresmoi-summary.test.pdf}
         
         \caption{PCC = -29\%}
     
     \end{subfigure}
     \hfill
     \begin{subfigure}[b]{0.23\textwidth}
         \centering
         \includegraphics[width=\textwidth]{archive/length-vs-total-laplace-id.pdf}
        
         \caption{PCC = -3\%}
       
     \end{subfigure}
     \hfill
     \begin{subfigure}[b]{0.23\textwidth}
         \centering
         \includegraphics[width=\textwidth]{archive/length-vs-total-laplace-ood-khresmoi-summary.test.pdf}
    
         \caption{PCC = 1\%}
    
     \end{subfigure}

     % \caption{Length vs uncertainty density plots for in and out-of-domain datasets. The top row corresponds to total uncertainty and the bottom row to knowledge uncertainty. The left half corresponds to the Deep Ensemble and the right half to Laplace L-EDD. Each figure caption also shows the person correlation coefficient (PCC) between the uncertainty metric and sequence length.}
     \caption{Length vs uncertainty density plots for ID (newstest-14) and OOD (Khresmoi) datasets. The left half corresponds to the Deep Ensemble and the right half to Laplace L-EDD. Each subfigure caption shows the Pearson correlation coefficient (PCC) between sequence length and total uncertainty. Total uncertainty and length are negatively correlated for the ensemble, i.e.\ it is more confident on longer sequences, but they are uncorrelated for L-EDD.}
     \label{fig:confirmation-bias}
\end{figure*}
%
Figure \ref{fig:confirmation-bias} shows how Deep Ensemble and L-EDD total uncertainties vary with output sequence length for both ID (newstest-14) and OOD (Khresmoi) data, with the associated Pearson correlation given under each subfigure.
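
For completeness, given $N$ sentences with output lengths $L_n$ and total uncertainties $U_n$ (notation ours), the Pearson correlation coefficient is
\begin{align}
    {\rm PCC} = \frac{\sum_{n=1}^{N} (L_n - \bar L)(U_n - \bar U)}{\sqrt{\sum_{n=1}^{N} (L_n - \bar L)^2}\, \sqrt{\sum_{n=1}^{N} (U_n - \bar U)^2}},
\end{align}
where $\bar L$ and $\bar U$ are the sample means. It lies in $[-1, 1]$ and is reported in the subfigure captions as a percentage, so a value near zero indicates no linear relationship between length and uncertainty.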

Three observations can be made from this data. First, the L-EDD system consistently outputs much higher total uncertainties. Second, the Deep Ensemble displays a negative correlation between total uncertainty and sequence length. This implies that the ensemble is more confident when translating longer sequences, which may also explain why it fails to detect OOD datasets containing longer sequences.
%
The third, and possibly most significant, observation is that the Laplace system shows almost no correlation between total uncertainty and sequence length for either ID or OOD data. This effectively eliminates the length bias and allows it to better differentiate between ID and OOD datasets, even when the OOD inputs differ in length from the data the detection system was trained on.



\section{Conclusion}
\label{sec:conclusion}
In this work, we investigate the efficient estimation of uncertainties for large-scale autoregressive sequence prediction. To this end, we examine Ensemble Distribution Distillation (EDD) in \textit{logit} space in order to bypass the optimisation issues found in softmax-space EDD.
We perform experiments using modern transformer models trained for large-scale machine translation. These show that a student model trained to parameterise a \textit{Laplace} distribution over logits significantly outperforms Deep Ensembles for OOD detection at a fraction of the inference cost, whilst matching the ensemble in translation quality. 
% Further analysis reveals that merely fitting a Laplace distribution post-hoc to an ensemble can boost OOD detection, and that Laplace models are able to decorrelate input sequence from confidence, preventing confirmation bias.
Moreover, we show that Snapshot Ensembling can greatly reduce the training cost of EDD without sacrificing translation performance.
We hope that our work can encourage further investigation into the comparatively less well-explored domain of uncertainty estimation for structured sequence prediction, on tasks such as machine translation, image captioning, and automatic speech recognition. 
\begin{acknowledgements} % will be removed in pdf for initial submission,
						 % (without ‘accepted’ option in \documentclass)
                         % so you can already fill it to test with the
                         % ‘accepted’ class option
    
Guoxuan Xia is funded jointly by Arm Ltd.\ and EPSRC.
\end{acknowledgements}
% References

\bibliography{fathullah_460}
\end{document}
