\documentclass[accepted]{uai2022} % after acceptance, for a revised
                                    % version; also before submission to
                                    % see how the non-anonymous paper
                                    % would look like
%% There is a class option to choose the math font
% \documentclass[mathfont=ptmx]{uai2022} % ptmx math instead of Computer
                                         % Modern (has noticable issues)
% \documentclass[mathfont=newtx]{uai2022} % newtx fonts (improves upon
                                          % ptmx; less tested, no support)
% NOTE: Only keep *one* line above as appropriate, as it will be replaced
%       automatically for papers to be published. Do not make any other
%       change above this note for an accepted version.

%% Choose your variant of English; be consistent
\usepackage[american]{babel}
% \usepackage[british]{babel}

%% Some suggested packages, as needed:
\usepackage{natbib} % has a nice set of citation styles and commands
    \bibliographystyle{plainnat}
    \renewcommand{\bibsection}{\subsubsection*{References}}
\usepackage{mathtools} % amsmath with fixes and additions
% \usepackage{siunitx} % for proper typesetting of numbers and units
%\usepackage{booktabs} % commands to create good-looking tables
%\usepackage{tikz} % nice language for creating drawings and diagrams

%% custom packages
\usepackage[utf8]{inputenc} % allow utf-8 input
\usepackage[T1]{fontenc}    % use 8-bit T1 fonts
\usepackage{hyperref}       % hyperlinks                            
\usepackage{url}            % simple URL typesetting
\usepackage{booktabs}       % professional-quality tables
\usepackage{amsfonts}       % blackboard math symbols
\usepackage{nicefrac}       % compact symbols for 1/2, etc.
\usepackage{microtype}      % microtypography
\usepackage{xcolor}         % colors

% packages we used last year
\usepackage{subfig}      % subfigures
\usepackage{amsthm}         % theorems
\usepackage{amsmath}        % align
\usepackage{amssymb}
\usepackage[capitalise]{cleveref}       % cref
\usepackage{graphicx}       % images
\usepackage{wrapfig}        % wrapping image
\usepackage{makecell}
\usepackage{multirow}
\usepackage{floatrow}
\usepackage{authblk}
\floatsetup[figure]{style=plain,subcapbesideposition=center}
\newcommand\CoAuthorMark{\footnotemark[\arabic{footnote}]}

%% Provided macros
% \smaller: Because the class footnote size is essentially LaTeX's \small,
%           redefining \footnotesize, we provide the original \footnotesize
%           using this macro.
%           (Use only sparingly, e.g., in drawings, as it is quite small.)

%% Self-defined macros
\newtheorem{defi}{Definition}
\newtheorem{thm}{Theorem}
\newtheorem{lem}[thm]{Lemma}
\newtheorem{col}[thm]{Corollary}
\newtheorem*{nothm}{Theorem}
\newtheorem*{nolem}{Lemma}
\newtheorem*{nocol}{Corollary}

\title{Combating the Instability of Mutual Information-based Losses\\via Regularization}

% The standard author block has changed for UAI 2022 to provide
% more space for long author lists and allow for complex affiliations
%
% All author information is automatically removed by the class for the
% anonymous submission version of your paper, so you can already add your
% information below.
%
% Add authors
\author[1]{\href{mailto:<juice500@sogang.ac.kr>}{Kwanghee Choi\thanks{These authors contributed equally to this work.}}}
\author[2]{\href{mailto:<siyeong.lee@naverlabs.com>}{Siyeong Lee\protect\CoAuthorMark}}
% Add affiliations after the authors
\affil[1]{%
Sogang University
}
\affil[2]{%
NAVER LABS
}
\begin{document}
\maketitle

\begin{abstract} \label{sec:abstract}
Notable progress has been made in numerous fields of machine learning based on neural network-driven mutual information (MI) bounds.
However, utilizing the conventional MI-based losses is often challenging due to their practical and mathematical limitations.
In this work, we first identify the symptoms behind their instability: (1) the neural network not converging even after the loss seemed to converge, and (2) saturating neural network outputs causing the loss to diverge.
We mitigate both issues by adding a novel regularization term to the existing losses.
We theoretically and experimentally demonstrate that added regularization stabilizes training.
Finally, we present a novel benchmark that evaluates MI-based losses on both their MI estimation power and their capability on the downstream tasks, closely following the pre-existing supervised and contrastive learning settings.
We evaluate six different MI-based losses and their regularized counterparts on multiple benchmarks to show that our approach is simple yet effective.
\end{abstract}

\section{Introduction} \label{sec:intro}
Identifying a relationship between two variables of interest is one of the key problems in mathematics, statistics, and machine learning \citep{goodfellow2014generative, ren2015faster, he2016deep, vaswani2017attention}.
One of the fundamental approaches is information theory-based measurement, namely the measure of mutual information (MI).
Due to its mathematical soundness and the rise of deep learning, many have designed differentiable MI-based losses for neural networks.
Some utilize the MI-based losses to bridge the gap between latent variables and representations in generative adversarial networks \citep{nowozin2016f, chen2016infogan, belghazi2018mutual, oord2018representation, hjelm2019learning}, where others introduce MI-based methodologies identifying the relationship between input, output, and hidden variables \citep{tishby2015deep, shwartz2017opening, saxe2019information}.
Furthermore, recent self-supervised methods use contrastive losses, whose origin can be traced back to MI-based losses \citep{pmlr-v119-cheng20b, DBLP:conf/icml/Henaff20, DBLP:conf/nips/ChuangRL0J20}.

Although many have shown computational tractability and usefulness of MI-based losses, others still struggle with their instability during optimization.
Contrastive learning literature with MI-based losses such as \cite{chen2020simple, he2020momentum} use huge batch sizes to reduce the variance of losses.
\cite{DBLP:journals/corr/abs-2105-04906} adds a regularization term to the neural network embeddings to stabilize the training.
\cite{mcallester2018formal} and \cite{Song2020Understanding} further provide theoretical limitations of variational MI estimators, arguing that the limited batch size induces a MI estimation variance too large to handle.
We argue that mitigating the variance of MI-based losses is critical for stabilizing training, where it is well known that more stable optimization of neural networks yields better predictive performance on the downstream tasks \citep{DBLP:conf/iclr/RothfussLCAA19, DBLP:conf/cogsci/BearC20, DBLP:conf/nips/ChavdarovaGFL19, DBLP:conf/nips/RichterBNRA20, DBLP:journals/jr/ZengWW20, DBLP:conf/acl/ColomboPC20}.

In this paper, we concentrate on identifying the cause behind the instability of MI-based losses and propose a simple yet effective regularization method that can be applied to various MI-based losses.
We start by analyzing the behaviors of two MI estimators: the MI Neural Estimator (MINE) loss \citep{belghazi2018mutual} and the Nguyen-Wainwright-Jordan (NWJ) loss \citep{nguyen2010estimating}.
We identify two distinctive behaviors that induce instability during training, drifting and exploding neural network outputs.
Based on these observations, we design two novel dual representations of the KL-divergence called Regularized Donsker-Varadhan representation (ReDV) and Regularized NWJ representation (ReNWJ).
We show theoretically and experimentally that adding our regularizer term suppresses two behaviors of drifting and exploding, avoiding instability during training.
Finally, we design a novel benchmark that bridges the gap between variational MI estimators and real-world tasks, whereas previous works either do not directly show the MI estimation performance or evaluate only on toy problems.
We reformulate both the supervised and the contrastive learning problem \citep{chen2020simple, he2020momentum, khosla2020supervised} as MI estimation problems and show that our regularization yields better performance on both perspectives, downstream task and MI estimation performance.

\section{Background \& Related works} \label{sec:related_works}
\paragraph{Definition of MI} The mutual information between two random variables $X$ and $Y$ is defined as
\begin{align}\begin{split}
\label{MIDef}
    I(X,Y) & = D_\text{KL}(\mathbb{P}_{X Y} || \mathbb{P}_X \otimes \mathbb{P}_Y) \\
     & = \mathbb{E}_{\mathbb{P}_{XY}} (\log{\frac{d\mathbb{P}_{XY}}{d(\mathbb{P}_X \otimes \mathbb{P}_Y)}}),
\end{split}\end{align}
where $\mathbb{P}_{X Y}$ and $\mathbb{P}_X \otimes \mathbb{P}_Y$ are the joint distribution and the product of the marginal distributions, respectively.
$D_\text{KL}$ is the Kullback-Leibler (KL) divergence.
Without loss of generality, we consider $\mathbb{P}_{X Y}$ and $\mathbb{P}_X \otimes \mathbb{P}_Y$ as being distributions on a compact domain $\Omega \subset \mathbb{R}^d$.

\paragraph{MI through dual representation of $D_\text{KL}$}
We first introduce two dual representations of $D_\text{KL}$, as MI is defined using $D_\text{KL}$.
The most widely known is the Donsker-Varadhan representation $D_\text{DV}$ \citep{donsker1975asymptotic}.
For two given distributions $\mathbb{P}$ and $\mathbb{Q}$ on some compact domain $\Omega \subset \mathbb{R}^d$,
\begin{equation}\label{DV}
    D_\text{DV}(X,Y) := \sup_{T:\Omega\to\mathbb{R}} \mathbb{E}_{\mathbb{P}}(T) - \log(\mathbb{E}_{\mathbb{Q}}(e^{T})),
\end{equation}
where both the expectations $\mathbb{E}_{\mathbb{P}}(T)$ and $\mathbb{E}_{\mathbb{Q}}(e^{T})$ are finite.
If we substitute $\mathbb{P}_{X Y}$ and $\mathbb{P}_X \otimes \mathbb{P}_Y$ for $\mathbb{P}$ and $\mathbb{Q}$, respectively, $D_\text{DV}$ yields the definition of MI.
The optimal $T^* = \log\frac{d\mathbb{P}}{d\mathbb{Q}} + C$, where $C \in \mathbb{R}$ can be any constant.

In contrast to $D_\text{DV}$, the Nguyen-Wainwright-Jordan representation $D_\text{NWJ}$ \citep{nguyen2010estimating} is induced by the convex conjugate known as Fenchel's inequality \citep{hiriart2004fundamentals}:
\begin{equation}\label{NWJ}
    D_\text{NWJ}(X,Y) := \sup_{T:\Omega\to\mathbb{R}} \mathbb{E}_{\mathbb{P}}(T) - \mathbb{E}_{\mathbb{Q}}(e^{T - 1})
\end{equation}
The optimal $T^* = \log\frac{d\mathbb{P}}{d\mathbb{Q}} + 1$ is unique unlike the optimal $T^*$ of $D_\text{DV}$ due to its self-normalizing property \citep{belghazi2018mutual}.
However, $D_\text{DV}$ guarantees tighter lower bounds than $D_\text{NWJ}$ \citep{Ruderman2012TighterVR, polyanskiy2014lecture}.
These two representations provide the theoretical soundness for numerous variational MI bounds.

\paragraph{Variational MI estimation}
With the increasing success of neural networks, several neural network-driven variational bounds of MI are proposed.
They are widely employed in areas such as contrastive learning \citep{oord2018representation, he2020momentum, chen2020simple} and generative adversarial training \citep{belghazi2018mutual, nowozin2016f}.
Variational bounds of MI commonly focus on estimating $T^*$ via a neural network $T_{\theta}:\Omega\to\mathbb{R}$, called the statistics network \citep{belghazi2018mutual}, which outputs a single real value given the input sample pairs.

$I_\text{MINE}$ \citep{belghazi2018mutual} directly maximizes $D_\text{DV}$ as the objective function by feeding the samples $(x, y)$ of $\mathbb{P}_{X Y}$ and $\mathbb{P}_X \otimes \mathbb{P}_Y$ into $T_{\theta}$:
\begin{multline}\label{eq:MINE}
    I_{\text{MINE}}(X,Y) := \\ \mathbb{E}_{\mathbb{P}_{X Y}^{(n)}}(T_\theta(x, y)) - \log(\mathbb{E}_{\mathbb{P}_X^{(n)} \otimes \mathbb{P}_Y^{(n)}}(e^{T_\theta(x, y)})),
\end{multline}
where $\mathbb{P}^{(n)}$ is the empirical distribution associated to $n$ i.i.d. samples for given distribution $\mathbb{P}$.
\cite{belghazi2018mutual} also utilizes moving averages of mini-batches to reduce the MI estimation variance caused by the limited batch size.

$I_\text{InfoNCE}$ \citep{oord2018representation} is also commonly used due to its stability and decent performance:
\begin{equation}
I_{\text{InfoNCE}}(X, Y) = \frac{1}{N}\sum_{i=1}^N\log\frac{e^{T_\theta(x_i, y_i)}}{\frac{1}{N}\sum_{j=1}^N e^{T_\theta(x_i, y_j)}},
\end{equation}
where the $N$ samples $(x_i, y_i)_{i=1}^N$ are drawn from $\mathbb{P}_{XY}$, which becomes equivalent to using the Softmax function with the negative log loss.
$I_\text{InfoNCE}$ is also equivalent to $I_\text{MINE}$ up to a constant, but upper bounded by $\log N$, hence not able to estimate large MI values \citep{oord2018representation}.

\cite{poole2019variational} introduced $I_\text{TUBA}$, a unified lower bound, by expanding $D_\text{NWJ}$ \citep{barber2004information,nguyen2010estimating}.
\begin{multline}\label{eq:NWJLOSS}
I_\text{NWJ}(X, Y) := \\ \mathbb{E}_{\mathbb{P}_{XY}^{(n)}}(T_{\theta}(x, y)) -  \mathbb{E}_{\mathbb{P}_X^{(n)} \otimes \mathbb{P}_Y^{(n)}} (e^{T_{\theta}(x, y)-1}),
\end{multline}
\begin{multline}\label{eq:TUBA}
I_\text{TUBA}(X, Y) :=  \mathbb{E}_{\mathbb{P}_{XY}^{(n)}}(T_{\theta}(x, y)) \\ -  \mathbb{E}_{\mathbb{P}^{(n)}_{Y}}\left({\mathbb{E}_{\mathbb{P}^{(n)}_{X}}(e^{T_{\theta}(x, y)})}/{a(y)} + \log(a(y))- 1\right),
\end{multline}
where $a(y)$ is the variational parameter.
However, unlike $I_\text{MINE}$ or $I_\text{InfoNCE}$, directly using the exponential term often causes numerical instability.
Even if $T_\theta$ outputs a moderately sized value, $e^{T_\theta}$ can easily exceed the floating-point range.

To avoid this problem, \cite{poole2019variational} introduce $D_\text{NWJ}$-based lower bound $I_\text{JS}$ by using a softplus-activated neural network as $T_\theta$,
\begin{multline}\label{eq:JS}
I_\text{JS}(X, Y) :=  1 + \mathbb{E}_{\mathbb{P}_{XY}^{(n)}}(T_{\theta}(x, y)) \\ -  \mathbb{E}_{\mathbb{P}^{(n)}_{Y} \otimes \mathbb{P}^{(n)}_{X}}((e^{T_{\theta}(x, y)})).
\end{multline}

\paragraph{Variance problem of MI estimators}
Despite the variety of bounds proposed, many still suffer from the bias-variance trade-off \citep{poole2019variational}.
\cite{mcallester2018formal} and \cite{Song2020Understanding} prove that the $I_\text{MINE}$ estimator must have a batch size proportional to the exponential of true MI to control the variance of the estimation.

Many bounds try to mitigate this problem by reducing the variance of low-biased estimators, such as by interpolating with a low variance bound \citep{poole2019variational} or dropping the formal theoretical guarantees \citep{mcallester2018formal}.
 \cite{Song2020Understanding} proposed $I_\text{SMILE}$ to clip the range of $T_\theta$ trained with $I_\text{MINE}$, sacrificing the estimation quality to reduce the variance.
\begin{multline}\label{eq:SMILE}
    I_{\text{SMILE}}(X,Y) := \mathbb{E}_{\mathbb{P}_{X Y}^{(n)}}(T_\theta(x, y)) \\- \log(\mathbb{E}_{\mathbb{P}_X^{(n)} \otimes \mathbb{P}_Y^{(n)}}(\text{clip}(e^{T_\theta(x, y)}, e^{-\tau}, e^\tau))),
\end{multline}
where $\text{clip}(v, l, u) = \max(\min(v, u), l) $ for $v, u, l \in \mathbb{R}$.

\paragraph{Practical usages of MI}
MI-based losses are often applied in generative modeling, such as for better mode coverage \citep{belghazi2018mutual} or learning disentangled representations without supervision \citep{chen2016infogan, DBLP:conf/nips/OjhaSHL20, DBLP:journals/kbs/LiLNCC21, jeon2021ib}.
Representation learning employs MI-based losses \citep{DBLP:conf/nips/Tian0PKSI20, hjelm2019learning, DBLP:conf/iclr/TschannenDRGL20, pmlr-v119-cheng20b, DBLP:journals/corr/abs-2005-13149,DBLP:conf/iclr/WenZHZX20, DBLP:conf/eccv/BoudiafRZGPPA20, DBLP:conf/eccv/TianKI20, DBLP:journals/corr/abs-2005-04966} to yield feature extractors that reflect their downstream tasks well.
We emphasize that these approaches can be further utilized to measure the performance of MI estimators.

\paragraph{Comparing between MI estimators}
Toy datasets such as correlated multivariate Gaussian distributions have been widely accepted for the evaluation of MI estimation \citep{belghazi2018mutual, poole2019variational, Song2020Understanding, pmlr-v119-cheng20b, lin2019data}.
However, we emphasize that using synthetic data as a definitive benchmark will end up in a disparity with real-world tasks.
There have been some approaches that compared different MI estimators on generative modeling \citep{belghazi2018mutual, hjelm2019learning} or representational learning \citep{DBLP:conf/nips/Tian0PKSI20}.
However, finding the ideal MI for each downstream task is not trivial, making it impossible to directly assess the MI estimation quality.
Moreover, \cite{DBLP:conf/iclr/TschannenDRGL20} and \cite{DBLP:conf/nips/Tian0PKSI20} showed the gap between MI estimation quality and downstream performance on specific tasks.
Hence, it is crucial to evaluate both perspectives.
The closest work to our benchmark is the consistency test of \cite{Song2020Understanding} using CIFAR-10 \citep{krizhevsky2009learning} and MNIST \citep{lecun1998gradient}.
However, the test only offered to assess the ratio of two separate MI estimations, making it difficult to separately measure the quality of each estimation.

\section{Instability of MI Bounds} \label{sec:motivating_experiment}
To demonstrate and analyze the instability of variational MI bounds, we design a synthetic problem with the One-hot dataset.
We then solve the task via $I_\text{MINE}$ and $I_\text{NWJ}$, which are the losses derived from the two most commonly used representations of KL-divergence, $D_\text{DV}$ and $D_\text{NWJ}$, respectively.
Both losses consist of two terms, each derived from the statistics of joint distribution $\mathbb{E}_{\mathbb{P}_{XY}}$ and the product of marginal distributions $\mathbb{E}_{\mathbb{P}_{X} \otimes \mathbb{P}_Y}$.
Hence, to observe the behavior of each loss during training, we plot the two terms separately.
Also, to observe how the two distributions differ in terms of the statistics network outputs $T_\theta(x, y)$, we plot each output for $(x, y) \in Supp(\mathbb{P}_{XY})$ and $(x, y) \in Supp(\mathbb{P}_{X} \otimes \mathbb{P}_Y) \setminus Supp(\mathbb{P}_{XY})$, where we denote the support of $\mathbb{P}$ as $Supp(\mathbb{P})$.
Support is the set of values that the random variable can take \citep{def_support}.

\paragraph{One-hot Dataset}
We design a one-hot discrete problem with uniform distribution $X \sim U(1,N)$ to estimate $I(X, X)=\log N$ for a given integer $N$.
This task is intentionally created to easily discern samples $(x, x) \sim \mathbb{P}_{X X}$ from $(x, x) \sim \mathbb{P}_X \otimes \mathbb{P}_X$, so that we can directly observe its network outputs $T_\theta(x, x)$.

\begin{figure}[t]
\centering
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/mine-batch100-plot.pdf}}
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/nwj-batch100-plot.pdf}}
\caption{
    Training $T_\theta$ using $I_\text{MINE}$ and $I_\text{NWJ}$ with batch size $100$ for $3000$ iterations.
    We break down the MI loss into two components.
    We split $I_\text{MINE}$ into first term \textcolor{blue}{\textbf{$\mathbb{E}_{\mathbb{P}_{X X}}(T)$}} and second term \textcolor{orange}{\textbf{$\log\mathbb{E}_{\mathbb{P}_X \otimes \mathbb{P}_X}(e^T)$}}.
    Similarly, we split $I_\text{NWJ}$ into first term \textcolor{blue}{\textbf{$\mathbb{E}_{\mathbb{P}_{X X}}(T)$}} and second term \textcolor{orange}{\textbf{$\mathbb{E}_{\mathbb{P}_X \otimes \mathbb{P}_X}(e^{T-1})$}}.
}
\label{figs:onehot_toy_stable_loss}
\end{figure}

\begin{figure}[t]
\centering
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/mine-batch100-scatter.pdf}}
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/nwj-batch100-scatter.pdf}}
\caption{
    Training $T_\theta$ using $I_\text{MINE}$ and $I_\text{NWJ}$ with batch size $100$ for $3000$ iterations.
    We observe the statistics network outputs $T_\theta(x_1, x_2)$, where we split the outputs into two: \textcolor{blue}{$(x_1, x_2) \in Supp(\mathbb{P}_{X X})$} and \textcolor{orange}{$(x'_1, x'_2) \in Supp(\mathbb{P}_X \otimes \mathbb{P}_X) \setminus Supp(\mathbb{P}_{X X})$}.
}
\label{figs:onehot_toy_stable_scatter}
\end{figure}

\paragraph{Seemingly Stable Case}
We first observe the behaviors of the statistics network $T_\theta$ when the losses are seemingly stable, producing a successful MI estimate.
\cref{figs:onehot_toy_stable_loss} shows the MI estimates and the two terms that construct each MI estimate per batch.
We observe that the first and the second term estimates of $I_\text{MINE}$, unlike those of $I_\text{NWJ}$, drift in parallel even after the MI estimate converges.
This is due to the free constant term $C$ in the optimal $T^*$ of $D_\text{DV}$, where the self-normalizing $D_\text{NWJ}$ avoids this problem.
This drifting phenomenon implies that $T_\theta$ is not stable even after the loss seems to have converged, as shown in \cref{figs:onehot_toy_stable_scatter}.
Also, the plot demonstrates how $T_\theta$ is trained; it isolates the samples $(x, y) \sim \mathbb{P}_{X Y}$ from the samples $(x, y) \sim \mathbb{P}_X \otimes \mathbb{P}_Y$.

\paragraph{Unstable Case}
We also demonstrate the behaviors of $T_\theta$ when the losses get unstable in \cref{figs:onehot_toy_unstable}.
We reduce the batch size to make the optimization unstable, where this behavior is often reported in multiple works \citep{oord2018representation, he2020momentum, chen2020simple}.
However, even though the losses seem unstable, $T_\theta$ successfully discerns the samples before the outputs explode.
We believe that this is because of how $T_\theta$ is optimized during training.
The statistics network outputs $T_\theta(x_1, x_2)$ of $(x_1, x_2) \in Supp(\mathbb{P}_{X X})$ get increased by the first term but occasionally decreased by the second term.
However, $T_\theta(x'_1, x'_2)$ of $(x'_1, x'_2) \in Supp(\mathbb{P}_X \otimes \mathbb{P}_X) \setminus Supp(\mathbb{P}_{X X})$ is always decreased, as $(x'_1, x'_2)$ is used only for the second term.
This makes the second term more unstable and motivates us to regularize it for better numerical stability during optimization.

\begin{figure}[t]
\centering
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/mine-batch32-plot.pdf}}
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/nwj-batch32-plot.pdf}}\\
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/mine-batch32-scatter.pdf}}
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/nwj-batch32-scatter.pdf}}
\caption{
    Training $T_\theta$ using $I_\text{MINE}$ and $I_\text{NWJ}$ with a reduced batch size of $32$ for $3000$ iterations.
    MI estimate diverges for both cases.
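    % NOTE(review): the caption states 3000 iterations above but ``23k'' here -- confirm which figure is correct.
    Also, $I_\text{NWJ}$ incurs exploding $T_\theta$ outputs, hence the empty plot after $\sim$23k iterations.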
}
\label{figs:onehot_toy_unstable}
\end{figure}

To summarize, we suspect the instability of variational bounds comes from two reasons. 
Firstly, the statistics network did not converge even after the loss seemingly converged.
We argue that this is due to the unnormalized constant term in the optimal $T^*$ of $D_\text{DV}$, which $D_\text{NWJ}$ successfully avoids via self-normalization.
Secondly, the loss gets unstable as $T_\theta(x'_1, x'_2)$ endlessly decreases due to the second term.
This observation is also consistent with the theoretical findings of \cite{Song2020Understanding,mcallester2018formal}, where they show that large variance of the second term leads to failed MI estimation.
We claim that the outputs have to be regularized in some form to avoid the instability.

\begin{figure}[t]
\centering
\includegraphics[width=0.95\columnwidth]{asset/shifting_c.pdf}
\caption{
    Training $T_\theta$ with batch size $100$ for $1500$ iterations using $I_\text{ReMINE}$ with different $C^*$ (orange dotted line).
}
\label{figs:remine_different_c}
\end{figure}

\section{Stabilizing the MI Bounds} \label{sec:method}
In this section, we introduce two novel regularized representations and their corresponding losses to tackle the instability during optimization.
We show both theoretically and experimentally that adding regularization mitigates the unstable behavior of the statistics network $T_\theta$.
We also describe a simple windowing method that can sidestep the batch size limitation problem of the MI estimation problem.
We defer all the proofs to the Appendix.

\paragraph{Regularized representations}
We stabilize the two existing representations $D_\text{DV}$ and $D_\text{NWJ}$ by regularizing the second term.
We introduce two novel representations: Regularized DV ($D_\text{ReDV}$) and Regularized NWJ ($D_\text{ReNWJ}$),
\begin{align}\begin{split}
    D_\text{ReDV}(X,Y) := \sup_{T:\Omega \to \mathbb{R}} & \mathbb{E}_\mathbb{P}(T) - \log(\mathbb{E}_\mathbb{Q}(e^T)) \\ 
& - d(\log(\mathbb{E}_\mathbb{Q}(e^T)), C^*),
\end{split}\end{align}
\begin{align}\begin{split}
    D_\text{ReNWJ}(X,Y) := \sup_{T:\Omega \to \mathbb{R}} & \mathbb{E}_\mathbb{P}(T) - \mathbb{E}_\mathbb{Q}(e^{T-1}) \\ 
& - d(\mathbb{E}_\mathbb{Q}(e^{T-1}), 1),
\end{split}\end{align}
where $C^* \in \mathbb{R}$ is any constant and $d(\cdot,\cdot)$ is a distance function on $\mathbb{R}$.

\begin{thm} \label{thm:redv_renwj} $D_\text{ReDV}$ and $D_\text{ReNWJ}$ are dual representations of $D_\text{KL}$ such that 
\begin{align}
D_\text{KL} (\mathbb{P} || \mathbb{Q}) &= D_\text{ReDV}(X, Y), \\
D_\text{KL} (\mathbb{P} || \mathbb{Q}) &= D_\text{ReNWJ}(X, Y).    
\end{align}
\end{thm}
We emphasize that both representations are not MI-specific but dual representations of $D_\text{KL}$, which can be easily extended to numerous variational MI bounds based on $D_\text{DV}$ and $D_\text{NWJ}$.
Especially, the newly added regularizer grants $D_\text{ReDV}$ the normalizing property, effectively solving the drifting problem of $D_\text{DV}$.

\paragraph{Regularizing $I_\text{MINE}$ and $I_\text{NWJ}$}
Based on $D_\text{ReDV}$ and $D_\text{ReNWJ}$, we propose two novel neural network-driven variational MI bounds, $I_\text{ReMINE}$ and $I_\text{ReNWJ}$, by choosing the Euclidean distance $d(x, y) = (x-y)^2$ and the log-Euclidean distance $d(x, y) = (\log x - \log y)^2$, respectively.
\begin{align}\begin{split}
    I_{\text{ReMINE}}(X, & Y) := \mathbb{E}_{\mathbb{P}_{X Y}^{(n)}}(T_\theta(x, y)) \\
     & - \log(\mathbb{E}_{\mathbb{P}_X^{(n)} \otimes \mathbb{P}_Y^{(n)}}(e^{T_\theta(x, y)})) \\
     & - \lambda (\log(\mathbb{E}_{\mathbb{P}_X^{(n)} \otimes \mathbb{P}_Y^{(n)}}(e^{T_\theta(x, y)})) - C^*)^2,
\end{split}\end{align}
\begin{align}\begin{split}
    I_{\text{ReNWJ}}(X, & Y) := \mathbb{E}_{\mathbb{P}_{X Y}^{(n)}}(T_\theta(x, y)) \\
     & - \mathbb{E}_{\mathbb{P}_X^{(n)} \otimes \mathbb{P}_Y^{(n)}}(e^{T_\theta(x, y)-1}) \\
     & - \lambda (\log(\mathbb{E}_{\mathbb{P}_X^{(n)} \otimes \mathbb{P}_Y^{(n)}}(e^{T_\theta(x, y)-1})))^2,
\end{split}\end{align}
where $C^* \in \mathbb{R}$ is any constant and $\lambda$ is a hyperparameter that controls the degree of regularization.
We can also easily regularize other losses such as $I_\text{InfoNCE}$, $I_\text{SMILE}$, $I_\text{TUBA}$, and $I_\text{JS}$ in a plug-and-play manner.
See \cref{table:loss_setting} for more details on their regularized counterparts.

\paragraph{Solving the drifting problem}
Due to the self-normalizing nature of $D_\text{NWJ}$, we must fix $C^*=1$ for $I_\text{ReNWJ}$.
We also set $C^*=0$ for $I_\text{ReMINE}$ on future experiments, but to demonstrate the ability of the regularizer term to stop the drifting, we experiment with various $C^*$ in \cref{figs:remine_different_c}.
Comparing to $I_\text{MINE}$ in \cref{figs:onehot_toy_stable_loss}, we can observe that $I_\text{ReMINE}$ successfully solves the drifting problem by regularizing the second term to have a single value.

\begin{figure}[t] %%% t: top, b: bottom, h: here
\centering
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/remine-batch32-plot.pdf}}
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/renwj-batch32-plot.pdf}}\\
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/remine-batch32-scatter.pdf}}
\subfloat{\includegraphics[width=0.49\columnwidth]{asset/renwj-batch32-scatter.pdf}}
\caption{
    Training $T_\theta$ using the regularized counterparts, $I_\text{ReMINE}$ and $I_\text{ReNWJ}$, with the same small batch settings from \cref{figs:onehot_toy_unstable}.
    Regularization effectively mitigates both instability symptoms, drifting and exploding.
}
\label{figs:onehot_toy_remine}
\end{figure}

\paragraph{Solving the explosion problem}
We previously observed the instability of $I_\text{MINE}$ and $I_\text{NWJ}$ when using a small batch in \cref{figs:onehot_toy_unstable}.
We apply the same setting to $I_\text{ReMINE}$ and $I_\text{ReNWJ}$ to observe if the regularizer mitigates the instability problem.
Both regularized losses successfully avoid the explosion problem and limit the statistics network outputs $T_\theta(x_1, x_2)$ within a certain boundary.
As discussed in \cref{sec:motivating_experiment}, the second term was the culprit of the variance in MI estimation.
The newly added term directly regularizes it to stabilize training, giving the statistics network $T_\theta$ additional hints for the second term to converge to a specific value $C^*$ successfully.
Furthermore, we empirically found that our regularization works well with $I_\text{SMILE}$’s strategy of clipping $T_\theta$.
Gradient zeros out for the original $I_\text{SMILE}$ if $T_\theta(x, y)$ exceeds a certain threshold.
This behavior makes $T_\theta$ act as if it were frozen, failing to further optimize during training.
However, with the regularizer term, we can clip $T_\theta(x, y)$ only on the first and second terms, i.e., on the original loss.
Now, clipping filters out the noisy gradients while the gradients calculated from the regularizer avoid freezing $T_\theta$ entirely.

\paragraph{Mathematical properties of $I_\text{ReMINE}$ and $I_\text{ReNWJ}$}
Following \cite{belghazi2018mutual}, we show the soundness of $I_\text{ReMINE}$ and $I_\text{ReNWJ}$ in two perspectives, strong consistency and sample complexity.
These properties relate to whether the trained $T_\theta$ can be sufficiently similar to the optimal $T^*$.

\begin{thm}\label{thm:remine_strong_consistency} $I_\text{ReMINE}$ and $I_\text{ReNWJ}$ are strongly consistent. \end{thm}

For the two losses, we also provide the mathematical bound on the number of samples required for the empirical MI estimation at a given accuracy and with high confidence.
Similar to \cite{belghazi2018mutual}, let $T_\theta$ be $L$-Lipschitz with respect to the parameter $\theta$, where $\theta$ satisfies $|\theta|<K$ and $d$ is the dimension of the parameter space of $T_\theta$.
\begin{thm}\label{thm:sample_complexity} Assume that $T_\theta$ is bounded above by $M$. Let $k$ be the number of sample means. Given any $\epsilon, \delta$ of the desired accuracy and confidence parameters, we have
\begin{equation}
\mathcal{P}(|I_{\text{ReMINE}}(X, Y) - I(X, Y)| \leq \epsilon) \geq 1 - \delta,
\end{equation}
whenever the number $n$ of samples satisfies
\begin{equation}
n \geq \frac{d\log(24KL\sqrt{d}/\epsilon) + 2dM + \log(2/\delta)}{\epsilon^2k/(2M^2)}.
\end{equation}
\end{thm}

\begin{thm}\label{thm:nwj_sample_complexity} Assume that $1 \leq |T_\theta| < M$ and $d(x, 1) \leq |x - 1|$. Let $k$ be the number of sample means. Given any $\epsilon, \delta$ of the desired accuracy and confidence parameters, we have 
\begin{equation}
\mathcal{P}(|I_{\text{ReNWJ}}(X, Y) - I(X, Y)| \leq \epsilon) \geq 1 - \delta,
\end{equation}
whenever the number $n$ of samples satisfies
\begin{equation}
n \geq \frac{d\log(24KL\sqrt{d}/\epsilon) + 2dM + \log(2/\delta)}{\epsilon^2k/(2M^2)}.
\end{equation}
\end{thm}

\paragraph{Drifting may lead to noisy MI estimate}
We prove that the variance of the second term on the empirical distributions is affected by the constant term $C^*$.
\begin{thm}\label{thm:remine_variance} 
Let $\mathbb{Q}^{(n)}$ be the empirical distribution of $n$ i.i.d. samples from $\mathbb{Q}$. For the optimal $T_{1} = \log\frac{d\mathbb{P}}{d\mathbb{Q}} + C_{1}$ and $T_{2} = \log\frac{d\mathbb{P}}{d\mathbb{Q}} + C_{2}$ where $C_1 \geq C_2$,
\begin{equation}
\text{Var}_{\mathbb{Q}}(\mathbb{E}_{\mathbb{Q}^{(n)}}(e^{T_1})) \geq \text{Var}_{\mathbb{Q}}(\mathbb{E}_{\mathbb{Q}^{(n)}}(e^{T_2})) 
\end{equation}
\end{thm}
This implies that an unregulated $C^*$ may lead to worse MI estimation quality, as the variance of the estimate is mainly due to the second term.

\paragraph{Increasing the effective sample size for MI estimation}
The drifting problem caused by the unnormalized constant term $C^*$ raises more issues when estimating MI.
\cite{poole2019variational} use a simple macro-averaging technique, i.e., averaging the estimated MI from each batch.
We can also consider a slight modification to the technique, where we call it the micro-averaging technique, by saving all the statistics network outputs $T_\theta(x, y)$ for each batch and producing a single estimate based on all the outputs.
However, we proved that both averaging techniques yield wrong final estimates for biased estimators like $I_\text{MINE}$ \citep{belghazi2018mutual}, $I_\text{SMILE}$ \citep{Song2020Understanding}, $I_\text{CLUB}$ \citep{pmlr-v119-cheng20b}, and $I_\text{InfoNCE}$ \citep{oord2018representation}.
\begin{thm}(Estimation bias caused by drifting) \label{thm-estim-bias}
Both macro- and micro-averaging strategies produce a biased MI estimate when the drifting problem occurs.
\end{thm}
In contrast, self-normalizing or regularized MI estimators have the upper hand in this respect.
By utilizing all the samples from multiple batches, they can effectively sidestep the batch size limitation problem \citep{mcallester2018formal, Song2020Understanding}.

\begin{table*}[ht]
\centering
\begin{tabular}{c|cc}
\Xhline{1.2pt}
Loss & Loss settings & Regularizer settings \\ \hline
\hline
$I_{\text{MINE}}$    & No gradient moving average   & Euclidean distance \\
$I_{\text{SMILE}}$   & Clipping $(\tau = 10)$        & Euclidean distance \\
$I_{\text{InfoNCE}}$ & -                            & Euclidean distance \\
$I_{\text{NWJ}}$     & -                            & Log-Euclidean distance, Clipping $(\tau = 10)$ \\
$I_{\text{TUBA}}$    & $a(y) = 1$                   & Log-Euclidean distance, Clipping $(\tau = 10)$ \\
$I_{\text{JS}}$      & Estimate with $I_\text{NWJ}$ & Euclidean distance \\ \hline
\Xhline{1.2pt}
\end{tabular}
\caption{List of MI estimators with their hyperparameters}
\label{table:loss_setting}
\end{table*}

\section{Experiments} \label{sec}
\subsection{MI Estimation vs. Downstream Task Performance} \label{sec:benchmark}
\subparagraph{Benchmark Design}
To measure the performance of the MI estimators, one must design the target task to have the ground truth MI.
This constraint led previous works to evaluate the estimators only on artificial toy problems \citep{belghazi2018mutual, poole2019variational}, whose connection to actual problems is fairly limited.
We design two types of MI estimation tasks on de facto standard image datasets to improve on the existing benchmarks and better reflect real-world tasks.
We defer all the proofs to the Appendix.

\begin{thm}\label{thm_slb} (Supervised learning) Given a dataset $D = (X, Y)$ where $X$ is a sample, $Y$ is the label for $X$, and $H(Y)$ is the entropy of the label set, $I(X, Y) = H(Y)$.
\end{thm}

Similarly, the true MI between images from the same class is also tractable based on the same assumption.
\begin{thm}\label{thm_clb} (Contrastive learning) Consider the dataset $D = (X, Y)$. Let $X_1$ be a sample drawn from the dataset and $X_2$ be another sample drawn from the subset with the same label $Y$ as $X_1$.
Then, $I(X_1, X_2) = I(X_1, Y) = I(X_2, Y) = H(Y)$.
\end{thm}
Note that we assume statistical dependence between the image $X$ and label $Y$ from the point of view of information bottleneck \citep{tishby2015deep}.
We derive the theorems above based on the assumption, where $Y$ implicitly determines $X$.

Based on the above theorems, we use the two MI estimation problems as benchmarks that evaluate the performance of estimators.
We intentionally design \cref{thm_slb} and \cref{thm_clb} to mimic the existing tasks closely, namely, supervised and contrastive learning.
For \cref{thm_slb}, we can set the statistics network $T_\theta(X, Y) = f_\theta(X) \cdot o(Y)$ where $f_\theta(X)$ is the logits obtained from feeding the image $X$ to the classification neural networks and $o(Y)$ is the one-hot representation of the label $Y$.
If we use the InfoNCE estimator, this formulation becomes identical to solving the classification problem using negative log loss with the Softmax function, hence the name being the supervised learning benchmark (SLB).
Similarly, for \cref{thm_clb}, we can set $T_\theta(X_1, X_2)=f_\theta(X_1) \cdot f_\theta(X_2)$ and use the InfoNCE estimator to yield a commonly used contrastive loss \citep{oord2018representation, chen2020simple}.

Due to the strict assumption of statistical dependence, the theorems above cannot be used on standard datasets like ImageNet \citep{DBLP:conf/cvpr/DengDSLL009}, as its samples often violate the single-label assumption.
However, we can still empirically compare the MI estimators by the relative size of their final MI estimation.
We conduct a demo experiment on ImageNet in the Appendix.

\paragraph{Evaluation}
\begin{table*}[t]
\centering

\begin{tabular}{c|c|c|cc|cc}
\Xhline{1.2pt}
\multicolumn{2}{c|}{\multirow{2}{*}{Task}}  & \multirow{2}{*}{Loss} & \multicolumn{2}{c|}{MI Estimation} & \multicolumn{2}{c}{Test Accuracy} \\ 
\multicolumn{2}{c|}{}& & Original & Regularized & Original & Regularized \\
\hline

\multicolumn{1}{c|}{\multirow{12}{*}{
\rotatebox[origin=c]{90}{Supervised Learning Benchmark}
}} & \multirow{6}{*}{
\rotatebox[origin=c]{90}{CIFAR-10}}
 & MINE & \textbf{2.300 \footnotesize $\pm$ 0.003} & 2.298 \footnotesize $\pm$ 0.005 & 0.850 \footnotesize $\pm$ 0.009 & \textbf{0.856 \footnotesize $\pm$ 0.004} \\
& \multicolumn{1}{c|}{} & SMILE & 2.297 \footnotesize $\pm$ 0.009 & \textbf{2.300 \footnotesize $\pm$ 0.003} & \textbf{0.854 \footnotesize $\pm$ 0.008} & 0.853 \footnotesize $\pm$ 0.009 \\
& \multicolumn{1}{c|}{} & InfoNCE & 2.301 \footnotesize $\pm$ 0.002 & \textbf{2.302 \footnotesize $\pm$ 0.001} & 0.845 \footnotesize $\pm$ 0.006 & 0.845 \footnotesize $\pm$ 0.005 \\
& \multicolumn{1}{c|}{} & NWJ & \textbf{2.297 \footnotesize $\pm$ 0.009} & 2.294 \footnotesize $\pm$ 0.013 & 0.859 \footnotesize $\pm$ 0.003 & \textbf{0.862 \footnotesize $\pm$ 0.004} \\
& \multicolumn{1}{c|}{} & TUBA & 2.297 \footnotesize $\pm$ 0.008 & \textbf{2.300 \footnotesize $\pm$ 0.003} & \textbf{0.862 \footnotesize $\pm$ 0.008} & 0.859 \footnotesize $\pm$ 0.003 \\
& \multicolumn{1}{c|}{} & JS & 1.944 \footnotesize $\pm$ 0.039 & \textbf{2.000 \footnotesize $\pm$ 0.049} & 0.838 \footnotesize $\pm$ 0.012 & \textbf{0.842 \footnotesize $\pm$ 0.004} \\
\cline{2-7}

& \multicolumn{1}{c|}{\multirow{6}{*}{
\rotatebox[origin=c]{90}{CIFAR-100}
}} & MINE & 4.597 \footnotesize $\pm$ 0.011 & \textbf{4.603 \footnotesize $\pm$ 0.001} & 0.610 \footnotesize $\pm$ 0.007 & 0.610 \footnotesize $\pm$ 0.006 \\
& \multicolumn{1}{c|}{} & SMILE & 4.595 \footnotesize $\pm$ 0.015 & \textbf{4.602 \footnotesize $\pm$ 0.002} & 0.601 \footnotesize $\pm$ 0.015 & \textbf{0.606 \footnotesize $\pm$ 0.007} \\
& \multicolumn{1}{c|}{} & InfoNCE & 4.594 \footnotesize $\pm$ 0.017 & \textbf{4.599 \footnotesize $\pm$ 0.005} & 0.589 \footnotesize $\pm$ 0.010 & \textbf{0.593 \footnotesize $\pm$ 0.005} \\
& \multicolumn{1}{c|}{} & NWJ & 4.572 \footnotesize $\pm$ 0.055 & \textbf{4.586 \footnotesize $\pm$ 0.034} & 0.558 \footnotesize $\pm$ 0.042 & \textbf{0.599 \footnotesize $\pm$ 0.009} \\
& \multicolumn{1}{c|}{} & TUBA & 4.495 \footnotesize $\pm$ 0.207 & \textbf{4.603 \footnotesize $\pm$ 0.002} & 0.543 \footnotesize $\pm$ 0.055 & \textcolor{blue}{\textbf{0.611 \footnotesize $\pm$ 0.007}} \\
& \multicolumn{1}{c|}{} & JS & 4.088 \footnotesize $\pm$ 0.430 & \textbf{4.240 \footnotesize $\pm$ 0.116} & 0.591 \footnotesize $\pm$ 0.026 & \textbf{0.598 \footnotesize $\pm$ 0.010} \\

\hline
\multicolumn{1}{c|}{\multirow{12}{*}{
\rotatebox[origin=c]{90}{Contrastive Learning Benchmark}
}} & \multicolumn{1}{c|}{\multirow{6}{*}{
\rotatebox[origin=c]{90}{CIFAR-10}
}} & MINE & 2.233 \footnotesize $\pm$ 0.674 & \textbf{2.240 \footnotesize $\pm$ 0.657} & 0.812 \footnotesize $\pm$ 0.026 & \textbf{0.823 \footnotesize $\pm$ 0.012} \\
& \multicolumn{1}{c|}{} & SMILE & 0.000 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{2.065 \footnotesize $\pm$ 0.842}} & 0.100 \footnotesize $\pm$ 0.001 & \textcolor{blue}{\textbf{0.830 \footnotesize $\pm$ 0.008}} \\
& \multicolumn{1}{c|}{} & InfoNCE & 1.705 \footnotesize $\pm$ 0.462 & \textbf{1.739 \footnotesize $\pm$ 0.431} & \textbf{0.830 \footnotesize $\pm$ 0.008} & 0.826 \footnotesize $\pm$ 0.006 \\
& \multicolumn{1}{c|}{} & NWJ & 0.000 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{1.910 \footnotesize $\pm$ 0.662}} & 0.100 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{0.831 \footnotesize $\pm$ 0.005}} \\
& \multicolumn{1}{c|}{} & TUBA & 0.000 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{1.358 \footnotesize $\pm$ 0.590}} & 0.100 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{0.830 \footnotesize $\pm$ 0.009}} \\
& \multicolumn{1}{c|}{} & JS & 1.552 \footnotesize $\pm$ 0.485 & \textbf{1.556 \footnotesize $\pm$ 0.546} & \textbf{0.837 \footnotesize $\pm$ 0.003} & 0.832 \footnotesize $\pm$ 0.009 \\
\cline{2-7}
& \multicolumn{1}{c|}{\multirow{6}{*}{
\rotatebox[origin=c]{90}{CIFAR-100}}
} & MINE & \textbf{4.634 \footnotesize $\pm$ 0.186} & 4.563 \footnotesize $\pm$ 0.162 & 0.522 \footnotesize $\pm$ 0.026 & \textbf{0.540 \footnotesize $\pm$ 0.020} \\
& \multicolumn{1}{c|}{} & SMILE & 0.000 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{4.677 \footnotesize $\pm$ 0.162}} & 0.012 \footnotesize $\pm$ 0.003 & \textcolor{blue}{\textbf{0.585 \footnotesize $\pm$ 0.007}} \\
& \multicolumn{1}{c|}{} & InfoNCE & 4.112 \footnotesize $\pm$ 0.147 & \textbf{4.115 \footnotesize $\pm$ 0.145} & 0.576 \footnotesize $\pm$ 0.019 & \textcolor{blue}{\textbf{0.585 \footnotesize $\pm$ 0.014}} \\
& \multicolumn{1}{c|}{} & NWJ & 0.000 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{4.065 \footnotesize $\pm$ 0.255}} & 0.010 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{0.521 \footnotesize $\pm$ 0.025}} \\
& \multicolumn{1}{c|}{} & TUBA & 0.000 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{2.731 \footnotesize $\pm$ 0.786}} & 0.010 \footnotesize $\pm$ 0.000 & \textcolor{blue}{\textbf{0.490 \footnotesize $\pm$ 0.023}} \\
& \multicolumn{1}{c|}{} & JS & 3.253 \footnotesize $\pm$ 0.368 & \textbf{3.393 \footnotesize $\pm$ 0.124} & 0.451 \footnotesize $\pm$ 0.020 & \textbf{0.463 \footnotesize $\pm$ 0.031} \\
\hline\Xhline{1.2pt}
\end{tabular}
\caption{Our supervised and contrastive learning benchmark results.
We provide the 95\% confidence interval of 5 runs for both MI estimation and test accuracy, where we clip the negative MI estimations to 0.
We compare the performance of original and regularized loss.
\textbf{Bold text} and \textcolor{blue}{\textbf{blue text}} indicate the better performance with overlapping and non-overlapping confidence intervals, respectively.
}
\label{table:sota_benchmark}
\end{table*}
To verify the performance of MI estimators, we perform our benchmark tasks on the CIFAR10 and CIFAR100 datasets \citep{krizhevsky2009learning}. As both CIFAR10 and CIFAR100 have a uniform label distribution, the ideal MI is $\log 10$ and $\log 100$, respectively.
In addition, to check whether this MI estimate task is actually helpful for downstream tasks, we evaluate each estimator on both dimensions: MI estimation and test set accuracy.
Similar to the existing settings in the contrastive learning literature \citep{chen2020simple, he2020momentum}, we design the test accuracy of CLB by defining the label estimate $\hat{y}$ of each test set sample $x_\text{Test}$ to be the label of $x = \operatorname*{argmax}_{x\in X_{\text{Train}}} f(x) \cdot f(x_{\text{Test}})$ of the train dataset $X_\text{Train}$.
Similarly for SLB, we choose $\hat{y} = \operatorname*{argmax}_y f(x_\text{Test}) \cdot o(y)$ where $o(y)$ is the one-hot encoding of $y$.
We ran the same experiment $5$ times with different seeds to yield a 95\% confidence interval.

\subsection{Comparison with Our Benchmark}\label{subsec:comparison}
To demonstrate the effectiveness of our novel regularization term, we regularize the two representations, $D_\text{DV}$ and $D_\text{NWJ}$.
We test three realizations for each representation, $I_\text{MINE}$ \citep{belghazi2018mutual}, $I_\text{SMILE}$ \citep{Song2020Understanding}, and $I_\text{InfoNCE}$ \citep{oord2018representation} for $D_\text{DV}$, $I_\text{NWJ}$ \citep{nguyen2010estimating}, $I_\text{TUBA}$ \citep{poole2019variational}, and $I_\text{JS}$ \citep{hjelm2019learning} for $D_\text{NWJ}$.
We compare the original losses with their regularized counterparts, for a total of $6\times2=12$ losses.
We do not apply any averaging scheme to the losses, and choose the regularization weight $\lambda \in \{0.1, 0.01, 0.001\}$ that shows the best MI estimation results.
See \cref{table:loss_setting} for more details.


We observe in \cref{table:sota_benchmark} that additional regularization generally induces better performance on both the MI estimation task and the downstream task (test accuracy).
Hence, adding the regularizer to a pre-existing supervised or contrastive learning loss seems to be a viable option to increase the performance further.
Even when the performance of the regularized loss slightly degrades, its negative impact is minimal.
This implies that even for the case where the regularizer is not greatly helpful, it does not greatly hinder optimization.
In particular, it is intriguing that several losses, namely $I_\text{MINE}$, $I_\text{ReMINE}$, and $I_\text{ReTUBA}$, perform better in SLB than $I_\text{InfoNCE}$, which is the de facto standard in classification.
Also, $I_\text{SMILE}$, $I_\text{NWJ}$, and $I_\text{TUBA}$ fail to converge in CLB, where simply adding a regularization term solves the issue altogether, yielding a competitive or even better performance than all the other losses.
Given that many contrastive learning methods suffer from instability \citep{caron2021emerging,DBLP:journals/corr/abs-2105-04906,chen2020simple,he2020momentum}, we emphasize that adding our regularization term can be a simple yet effective method to stabilize training.

\begin{figure}[ht] %%% t: top, b: bottom, h: here
\centering
    \includegraphics[width=0.95\columnwidth]{asset/benchmark_ablation.pdf} 
\caption{
    Ablation study on different $\lambda$s with CLB CIFAR-10 and CIFAR-100.
}
\label{figs:our_benchmark_ablation}
\end{figure}
Additionally, to observe the impact of regularization strength $\lambda$, we plot the benchmark performance for each $\lambda$ in \cref{figs:our_benchmark_ablation}.
We compare the losses on CLB as experimental results suggest that CLB is a more difficult task than SLB, showing significant performance differences between various losses.
On CIFAR-10, $\lambda$ acts as a trade-off parameter between test accuracy and MI estimation quality.
Performance trade-off has also been reported in other literature, where better MI estimation does not necessarily deliver better downstream performance \citep{DBLP:conf/iclr/TschannenDRGL20,DBLP:conf/nips/Tian0PKSI20}.
However, compared to CIFAR-100, test accuracy differences are minimal, whereas MI differences are apparent.
$I_\text{ReMINE}$ and $I_\text{ReSMILE}$ show excellent MI estimation quality in CIFAR-10 compared to other losses.
In contrast, test accuracy and MI estimation quality align well in the CIFAR-100 case.
$I_\text{ReSMILE}$ shows good overall performance, despite its sensitivity to the regularization strength.
$I_\text{ReInfoNCE}$, on the other hand, shows stable performance in the downstream task, sacrificing the MI estimation quality.
This result is further supported by the prominence of $I_\text{InfoNCE}$ in the contrastive learning domain.
It is yet unclear where the difference between CIFAR-10 and CIFAR-100 comes from, whether it is due to the difference in the level of difficulty of the dataset or the batch size used throughout the training.
We leave further analysis as future work.

\subsection{Comparison with the Standard Toy Problem}\label{section5-sub2}
We provide the quality of MI-based losses on the 20D Correlated Gaussian task \citep{belghazi2018mutual, poole2019variational} where the true MI is increased $5$ times during optimization in \cref{figs:gaussian_benchmark}.
This experiment demonstrates how stably the MI-based losses estimate MI in a dynamically changing environment.
We apply the same settings from \cref{table:loss_setting}, where we fix the regularization strength $\lambda = 1.0$ for all the losses.
With the exception of $I_\text{InfoNCE}$, regularized losses show clear superiority over the original losses.
Regularization helps $I_\text{MINE}$ and $I_\text{SMILE}$ avoid the instability mentioned in \cref{sec:motivating_experiment}.
Also, regularization greatly enhances the MI estimation quality of $I_\text{JS}$ and lessens the variance of both $I_\text{NWJ}$ and $I_\text{TUBA}$.

\section{Conclusion} \label{sec:conclusion}
In this paper, we identify the two symptoms behind the instability: the statistics network does not converge even after the loss seems to converge, and its outputs on the product of the marginal distributions explode during training.
We propose a novel regularization term that mitigates the instability during training when added to various existing MI-based losses.
We theoretically and experimentally demonstrate that the added regularizer directly alleviates the two instability symptoms.
Finally, we present a benchmark that evaluates both the MI estimation power and its capability on the downstream tasks by imitating the supervised or contrastive learning settings.
We compare six different losses and their regularized counterparts on various benchmarks to show the method's effectiveness and broad applicability.
\begin{figure}[H] 
\centering
\includegraphics[width=0.95\columnwidth]{asset/gaussian.pdf}
% %\includegraphics[height=8.7cm]{asset/fig-benchmark.png}
\caption{
Estimation performance on 20-D Gaussian.
The estimated MI (light) and the smoothed estimation with exponential moving average (dark) are plotted for each method and its regularized counterpart.
Black line represents the true MI.
Dotted line shows the bound of $I_\text{InfoNCE}$ due to the limited batch size of 64.
}
\label{figs:gaussian_benchmark}
\end{figure}

\section*{Limitations and Future Work}\label{sec:limitations}
We suspect that the instability of MI estimators can also be related to the collapse problem \citep{DBLP:journals/corr/abs-2105-04906,caron2021emerging}.
Further loss-based approaches to combat this problem by regularizing the network outputs may be helpful.
We expect that extending our methods to various contrastive learning losses may yield fruitful results for self-supervised learning, notably for other domains such as text or audio.
Also, our mathematical analysis is mainly focused on the drifting problem of $I_\text{MINE}$, not the explosion problem of $I_\text{NWJ}$.
For $I_\text{NWJ}$, we suspect that the absence of the log function wrapping the exponential values makes the second term much more susceptible to output explosion due to its numerical instability.
The added regularizer gives additional hints for the second term to converge to a specific value.
However, we did not expand the discussion further in this paper.

\bibliography{choi_500}

\end{document}
