%!TEX root = ../sublime-text.tex

\label{sec:experiments}

\begin{table*}[ht!]
\begin{center}
\caption{Top-1 test accuracy of ResNet-18 with timings on MNISTCIFAR. Our method, IRMv1+IHT (in bold), incurs negligible time overhead and achieves the best test accuracy among the non-oracle methods.
}
\label{tab:resnet18}
\begin{tabular}{lllll}
\\
\toprule
Method      & Test Accuracy & Train Time (s)     & \% time/ERM &  $L_1$ norm of last layer \\
\midrule
Oracle      & 77.85 $\pm$  0.14 & 36.38 $\pm$  0.26 & 99\% & 19.72 $\pm$ 3.88 \\
\midrule
ERM         & 44.93 $\pm$ 0.49    & 36.65 $\pm$  0.25 &  - \% &25.05 $\pm$ 2.32   \\
Sparse ERM  & 44.82 $\pm$ 0.42   & 37.31 $\pm$  0.37 & 102\%  & 18.31 $\pm$ 2.46        \\
\midrule
IRMv1       & 52.86 $\pm$ 0.53  & 36.51 $\pm$  0.17 & 100\% &  18.65 $\pm$ 0.93 \\
IRMv1+PM   & 57.30 $\pm$ 0.45  & 44.98 $\pm$  1.02 & 123\%   &  21.59 $\pm$ 1.02    \\
% IRMv1+L1 & 59.01 $\pm$ 1.18 & 40.29 $\pm$  0.42 & 110\% & 13.97 $\pm$ 1.30 \\
\textbf{IRMv1+IHT}  & \textbf{62.44 $\pm$ 0.96}  & \textbf{38.03 $\pm$  0.51} & \textbf{104\% }    & \textbf{9.10 $\pm$ 1.78}       \\ \bottomrule
\end{tabular}
\end{center}
\end{table*}
% \jdcomment{IRM-PGD for L1 didn't have good test results; I more or less switched to IHT, but it does constrain the L1 norm as well.}

\textbf{Algorithms:}
We compare our approach, IRM with iterative hard thresholding (IRMv1 + IHT), with relevant baselines ERM, sparse ERM, the oracle, and IRM-based methods.
For IRM-based methods, we use IRMv1 \citep{arjovsky2020invariant}, and we provide \Cref{prop:lossdiff} to prove it is an acceptable proxy for the minimax formulation in \Cref{eqn:irm-minimax-empirical}.
ERM is the standard training loop on the mixture of all environments, and sparse ERM adds IHT \citep{jain_iterativehardthreshold_2014} to it.
The oracle trains ERM with spurious features zeroed, upper bounding accuracies for other methods. 
For the IRM-based methods, we compare with the original IRMv1 \citep{arjovsky2020invariant}, and IRMv1 with ProbMask (IRMv1+PM) \citep{zhouSparseInvariantRisk2022, zhou2021effective}.
When comparing sparsity-based methods, we fix the target density of the feature representation to be the same across methods.
% The oracle baseline is constructed per dataset, and is only given access to the designated invariant features. In 2-CMNIST and 10-CMNIST, this is the original MNIST classification task, removing the spurious colors. In MNISTCIFAR, this whites out the MNIST image, leaving the invariant CIFAR component.

\textbf{Datasets:} 
We use common invariant representation learning benchmarks.
% and standard experimental configurations for these datasets. 
ColoredMNIST (2-CMNIST) is the original binary dataset introduced by \citet{arjovsky2020invariant}, and FullColoredMNIST (10-CMNIST) \citep{ahmed2021-predictivegroupinvariance} is also generated from MNIST, with two environments, 10 labels, and 10 colors.
MNISTCIFAR concatenates MNIST digits and CIFAR-10 images \citep{shah2020-simplicitybias}. 
% In each dataset, the label is generated from the invariant features before being corrupted with some label noise. The spurious features are correlated strongly with the label in both training environments, and this correlation is flipped in the test environment. 
The oracle baseline is constructed per dataset and only has the designated invariant features: the grayscale MNIST for 2- and 10-CMNIST, and the CIFAR image for MNISTCIFAR.
Parameters for the dataset configurations, including label noise and environmental correlation, are in \Cref{appx:experiments}.
% For both CMNIST variants, we train a MLP with two hidden layers of dimension 390, the median configuration of the model used by \citep{zhouSparseInvariantRisk2022} on these datasets. We show results for ResNet-18 on the MNISTCIFAR dataset, for which we provide additional timing results.
\begin{table}[t]
\begin{center}
\caption{Top-1 train and test accuracy of MLP390.\\}
\label{tab:mlp390}
\begin{tabular}{lllll}
\\
\toprule
               &  \multicolumn{2}{l}{10-CMNIST Accuracy (\%)} \\
\midrule
Method         &Train & Test  \\
\midrule
Oracle         & 73.06 $\pm$ 0.21    &  71.36 $\pm$ 0.44 \\
% \midrule
ERM            &  90.00 $\pm$ 0.29  & 28.32 $\pm$ 0.10   \\
Sparse ERM     & 87.17 $\pm$ 1.16  & 29.15 $\pm$ 2.14  \\
\midrule
IRMv1          & 70.77 $\pm$ 0.27 & 58.88 $\pm$ 0.14         \\
\textbf{IRMv1+PM} & \textbf{92.20 $\pm$ 0.10 }    & \textbf{65.16 $\pm$ 0.09}           \\
\textbf{IRMv1+IHT }     &     \textbf{80.83 $\pm$ 0.10 }  & \textbf{63.03 $\pm$ 0.51}  \\
\bottomrule
\end{tabular}
\end{center}
\end{table}

{
% \color{blue}
\textbf{Hyperparameter selection:} Because we do not know $d_\inv$ at train time, it is common to treat $s$ in \cref{algorithm} as a hyperparameter \citep[see, e.g.,][]{Wainwright2019-tb}. Specifically, we select $s$ with a uniform grid search for each dataset. We also find that accuracy is not significantly affected by small perturbations in $s$, as shown by additional experiments on MNISTCIFAR in \Cref{tab:hparams}.
% We also set the IRM penalty weight the same hyperparameter for the IRM penalty as used in \citep{arjovsky2020invariant,zhouSparseInvariantRisk2022}.  
}

\textbf{Evaluation metrics:}
Top-1 test accuracy is compared for the three tasks. For ResNet-18 on MNISTCIFAR, we also provide training time results, and the relative timing in comparison to standard ERM. 
% \jdcomment{Shoud I take the aboslute timing values out of the table and leave just the percentages?}



% \begin{table}[h]
% \label{tab:mlp390}
% \begin{center}
% \begin{tabular}{lllll}
% \toprule
%                & \multicolumn{2}{l}{2-CMNIST Accuracy (\%)} & \multicolumn{2}{l}{10-CMNIST Accuracy (\%)} \\
% \midrule
% Method         & Train       & Test                     & Train & Test  \\
% \midrule
% Oracle         &                   &                   &             &            \\
% % \midrule
% ERM            &  &  &   90.00 $\pm$ 0.03  & 28.32 $\pm$ 0.10   \\
% Sparse ERM     &                   &                   &           &           \\
% \midrule
% IRMv1          & 81.03 $\pm$ 0.05  & 57.09 $\pm$ 0.07 &  70.77 $\pm$ 0.27 & 58.88 $\pm$ 0.14         \\
% IRMv1 + ProbMask & 75.81 $\pm$ 0.22  & 57.52 $\pm$ 0.38 &                       & \textbf{65.16 $\pm$ 0.09}           \\
% IRMv1 + IHT      &                   &                    & 80.83 $\pm$ 0.01   & 63.03 $\pm$ 0.05  \\
% \bottomrule
% \end{tabular}
% \caption{Top-1 train and test accuracy of MLP390.}
% \end{center}
% \end{table}

% \jdcomment{2-CMNIST numbers are very unusual. Both baselines and my results are two low; will hide those columns}


\textbf{Discussion:}
We observe that IRM with IHT can match or exceed the performance of competing methods, including IRM with ProbMask sparsity, for larger models and datasets. 
Sparse ERM, IRMv1+PM, and IRMv1+IHT were computed with 88\% weight density in \Cref{tab:resnet18}; this corresponds to 12\% of the weights being zeroed out by the sparsification methods.
The $L_1$ norms of the layer also reflect the sparsification. 
ProbMask incurs a noticeable computational overhead---an additional 23\% over IRMv1---whereas IHT adds only a 4\% cost. We expect these time savings to scale up with larger models.
Additionally, in \Cref{tab:mlp390} we provide results for an MLP with two hidden layers of dimension 390, the median configuration of the model used by \citet{zhouSparseInvariantRisk2022} on these datasets.

% For both CMNIST variants, we train a MLP with two hidden layers of dimension 390, the median configuration of the model used by \citep{zhouSparseInvariantRisk2022} on these datasets. We show results for ResNet-18 on the MNISTCIFAR dataset, for which we provide additional timing results.


% The tests are quite quick, taking advantage of the smaller size of both dataset and model.
% We have likewise optimized the training to best isolate the effect of the sparsification method. 
% \jdcomment{we expect these time savings to scale up with larger models.}
% The CMNIST datasets, while explored extensively in the IRM literature, are too small to note a change in computation time. 

% \begin{figure}[h]
% \begin{center}
% \caption{Test accuracy for IRMv1+PM and IRMv1+IHT trained on MNISTCIFAR with ResNet-18, across different global sparsity constraints. 
% \\}
% %\framebox[4.0in]{$\;$}
% % \fbox{\rule[-.5cm]{0cm}{4cm} \rule[-.5cm]{4cm}{0cm}}
% \includegraphics[width=6cm]{images/densityvsacc.png}
% \end{center}
% \label{fig:sparsitycompare}
% \end{figure}

% In \Cref{fig:sparsitycompare}, we observe that IRM+ProbMask is insensitive to changes in the size of the subnet selected, while our IRMv1+IHT trends upwards with the dimensionality allowed in the last layer. For IRM+ProbMask, more than a small number of zeroed weights leads to overfitting on training sets for standard IRM datasets like CMNIST and MNISTCIFAR. 
% % \jdcomment{They peak at about 5\% sparsity. Unfortunately, mine peaks at only about 10-15\% sparsity, which is not a lot better. } 
% This suggests that sparse representation is helpful to combat overparameterization, and that a sparse model for dense representation is only reflecting the usual benefits to robustness provided by any sparse model. 
% Comparatively, for IRMv1+IHT, we see that 
% % \jdcomment{??? where am I going with this}



