\begin{table*}[t]
%
    \centering
    \caption{Linear-probing test accuracy (\%) with standard error on three datasets. Each experiment is repeated over five seeds.}
\vspace{-0.5em}
    \centering
    \resizebox{0.8\textwidth}{!}{%
        \begin{tabular}{@{}lcccccc@{}}
            \toprule
            \textbf{Method} & \multicolumn{2}{c}{\textbf{Waterbirds (CLIP ViT-B/32)}} & \multicolumn{2}{c}{\textbf{CelebA (CLIP ViT-B/32)}} & \multicolumn{2}{c}{\textbf{MultiNLI (BERT)}}                                                                         \\
            \cmidrule(l){2-3} \cmidrule(l){4-5} \cmidrule(l){6-7}
                            & \textbf{Average}                                        & \textbf{Worst-Group}                                & \textbf{Average}                             & \textbf{Worst-Group}  & \textbf{Average}      & \textbf{Worst-Group}  \\
            \midrule
            ERM & 89.52 $\pm$ 0.10 & 84.58 $\pm$ 0.20 & 78.76 $\pm$ 0.03 & 72.22 $\pm$ 0.39 & 81.15 $\pm$ 0.30 & 68.82 $\pm$ 0.64 \\
            CORAL & 89.67 $\pm$ 0.14 & 84.85 $\pm$ 0.22 & \textbf{78.81 $\pm$ 0.03} & 73.00 $\pm$ 0.22 & 81.22 $\pm$ 0.21 & 68.71 $\pm$ 0.52 \\
            Fishr & 89.79 $\pm$ 0.10 & 86.08 $\pm$ 0.10 & 73.95 $\pm$ 0.86 & 69.63 $\pm$ 1.20 & \textbf{81.35 $\pm$ 0.16} & \textbf{71.55 $\pm$ 1.20} \\
            CMA & \textbf{90.11 $\pm$ 0.17} & \textbf{86.16 $\pm$ 0.29} & 77.87 $\pm$ 0.04 & \textbf{74.16 $\pm$ 0.10} & 81.30 $\pm$ 0.25 & 69.72 $\pm$ 0.66 \\
        \bottomrule
    \end{tabular}\label{tab:real-datasets-transposed}}
\vspace{-1em}
\end{table*}

\begin{table*}[t]
    \centering
    \caption{DomainBed test accuracy (\%) under full fine-tuning, with test-domain validation for model selection.}
    \vspace{-0.5em}
    \centering
    \resizebox{0.8\textwidth}{!}{%
\begin{tabular}{lcccccc}
\toprule
\textbf{Algorithm}        & \textbf{ColoredMNIST}     & \textbf{RotatedMNIST}     & \textbf{VLCS}             & \textbf{PACS}             & \textbf{TerraIncognita}   & \textbf{Avg}              \\
\midrule
ERM                       & 54.5 $\pm$ 0.2            & 97.8 $\pm$ 0.1            & 76.9 $\pm$ 0.3            & 80.2 $\pm$ 0.5            & 36.5 $\pm$ 0.5            & 69.2                      \\
CORAL                     & 55.7 $\pm$ 0.5            & \textbf{98.0 $\pm$ 0.0}            & 75.9 $\pm$ 0.2            & 80.2 $\pm$ 0.2            & 33.6 $\pm$ 0.5            & 68.7                      \\
Fishr                     & 62.0 $\pm$ 1.7            & 97.9 $\pm$ 0.0            & \textbf{77.5 $\pm$ 0.5}            & 81.5 $\pm$ 0.2            & 37.3 $\pm$ 1.1            & 71.2                      \\
CMA          & \textbf{62.5 $\pm$ 0.9}            & 97.9 $\pm$ 0.1            & 77.4 $\pm$ 0.8            & \textbf{81.6 $\pm$ 0.3}            & \textbf{38.4 $\pm$ 1.2}            & \textbf{71.5}                      \\
\bottomrule
\end{tabular}
}
\label{tab:domainbed}
\vspace{-1em}
\end{table*}

\vspace{-0.5em}
\section{Experiments}\label{sec:experiment}
\vspace{-0.5em}
We validate CMA through both quantitative and qualitative analyses. First, we describe the experimental setup, including dataset details and model training procedures. We then present quantitative results, evaluating CMA's performance both under the \hyperref[assumption:IRM]{IRM assumption} and in scenarios where it does not hold. Next, we conduct qualitative analyses to better understand CMA's effect on worst-group performance and feature moment alignment. Finally, we compare the runtime and memory costs of CMA with those of other methods.

\vspace{-1em}
\subsection{Implementation}\label{sec:implementation}
\vspace{-1em}
 %
\textbf{Linear Probing (IRM)}~~We evaluate linear probing performance on Waterbirds~\citep{sagawa_distributionally_2020}, CelebA~\citep{liu_deep_2015}, and MultiNLI~\citep{williams_broad-coverage_2018}. To enforce the \hyperref[assumption:IRM]{IRM assumption}, we apply the Invariant-feature Subspace Recovery (ISR) algorithm~\citep{wang_provable_2022, wang_invariant-feature_2023}, which provably yields features that induce an optimal invariant predictor. For Waterbirds and CelebA, we extract features from a CLIP-pretrained Vision Transformer (ViT-B/32). For MultiNLI, we fine-tune a BERT model using the code and hyperparameters from \citet{sagawa_distributionally_2020} and then extract features from the fine-tuned model. We then transform these features with the ISR-mean algorithm~\citep{wang_provable_2022,wang_invariant-feature_2023}. Finally, we train a linear classifier on the transformed features using the ERM, CORAL~\citep{sun_deep_2016}, Fishr~\citep{rame_fishr_2022}, and CMA objectives.

\textbf{Full Fine-Tuning (Non-IRM)}~~We run end-to-end fine-tuning on a subset of DomainBed~\citep{gulrajani_search_2020}. We apply the gradient and Hessian regularization from \Cref{eq:cma_loss} to the classifier head, while back-propagating the loss through both the linear classifier and the encoder. Specifically, penalizing large gradient variance aligns gradients across domains, while the ERM loss drives gradients toward zero. Together, the two mechanisms promote a small gradient norm for each domain, in line with the theory in \Cref{sec:theory_no_irm} and \Cref{sec:unif_feature}. Given recent empirical evidence of the strong DG capabilities of Vision Transformers~\citep{ghosal_are_2022, zheng_prompt_2022, sultana_self-distilled_2022}, we select ViT-S as the backbone for the DomainBed experiments. Using the DomainBed codebase~\citep{gulrajani_search_2020}, we compare ERM~\citep{vapnik_overview_1999}, CORAL~\citep{sun_deep_2016}, Fishr~\citep{rame_fishr_2022}, and CMA by fine-tuning this small Vision Transformer~\citep{steiner_how_2022, dosovitskiy_image_2021, wightman_pytorch_2019}. For more implementation details, please refer to \Cref{app:irm_exp} and \Cref{app:non_irm_exp}.
\vspace{-1em}
\subsection{Quantitative Results}
\vspace{-1em}

Our goal is not to claim that CMA surpasses existing algorithms, but to demonstrate that our framework encompasses gradient matching (e.g., \citet{koyama_out--distribution_2020}) and Hessian matching (e.g., \citet{sun_deep_2016,rame_fishr_2022,hemati_understanding_2023}). Consistent with this goal, our experimental results show that CMA achieves performance comparable to state-of-the-art moment matching methods.

\textbf{Linear Probing (IRM)}~~From \Cref{tab:real-datasets-transposed}, we observe that CMA consistently outperforms ERM on worst-group accuracy while maintaining comparable average accuracy across all datasets. Compared to Fishr, CMA achieves higher worst-group and average accuracy on two out of three datasets. In contrast, Fishr's performance is less stable: on CelebA, it underperforms ERM in both average and worst-group accuracy. Compared to CORAL, CMA achieves better worst-group performance across all datasets, while maintaining similar average accuracy.

\textbf{Full Fine-Tuning (Non-IRM)}~~Following \citet{rame_fishr_2022}, we use test-domain validation for model selection, where the validation set is held out from the test domain. As shown in \Cref{tab:domainbed}, CMA achieves performance comparable to Fishr, and both methods consistently outperform ERM. This result supports the performance guarantee in \Cref{cor:hessian_alignment_non_irm} and validates our unified framework. Please refer to \Cref{app:per_data_result} for per-dataset and training-domain validation performance.
\begin{figure}[t]
    \centering
    %
    \includegraphics[width=\linewidth]{img/fig_hess_acc_uai.pdf}
    \vspace{-2.5em}
    \caption{Hessian penalty and worst-group accuracy during CMA linear probing on CelebA. Both curves show the mean over five runs, and shaded areas indicate $\pm$ one standard deviation.}
    \label{fig:hess_acc}
    \vspace{-1em}
\end{figure}
\vspace{-1em}
\subsection{Qualitative Results}\label{sec:exp_qual}
\vspace{-1em}
\textbf{Effect of Hessian Matching} 
%
%
We analyze CMA's training dynamics and their impact on worst-group performance by plotting the Hessian penalty
\begin{equation*}
    \frac{\beta}{K} \sum_{i = 1}^K \|\mathbf{H}_{\mu_i}(\boldsymbol{\theta}) - \overline{\mathbf{H}(\boldsymbol{\theta})}\|_F^2
\end{equation*}
for linear probing on the CelebA dataset. We use the same hyperparameters as those reported for accuracy in \Cref{tab:real-datasets-transposed}: $\alpha = 5000$, $\beta = 100$, and 4000 penalty annealing iterations.
As shown in \Cref{fig:hess_acc}, near step 4000, when the gradient and Hessian matching terms take effect, the Hessian penalty drops sharply. This drop is accompanied by a noticeable increase in worst-group accuracy, which is consistent with our theory that aligning Hessians across domains improves worst-case performance.

\begin{figure}[t]
    \centering
    \includegraphics[width=0.9\linewidth]{img/VLCS_comparison.pdf}
    \caption{Comparison of first and second-moment differences across test domains for the VLCS dataset. The plots show the progression of moment differences over training steps for ERM, Fishr, and CMA. ERM fails to align the feature moments while CMA achieves the most effective alignment. The shaded regions represent one standard deviation above and below the mean across test domains.}
    \label{fig:cmnist_mom}
\vspace{-1em}
\end{figure}

\textbf{Feature Moment Matching}~~Following the analysis in \Cref{sec:unif_feature}, we illustrate how CMA matches the moments of features across domains. \Cref{fig:cmnist_mom} presents the moment differences between domains on the VLCS dataset, averaged over all test domains. ERM shows significant discrepancies in feature moments between domains, whereas both Fishr and CMA successfully reduce these differences. Notably, CMA is more effective at reducing both first- and second-moment disparities.
\vspace{-1em}
\subsection{Runtime and Memory Comparison}\label{app:time_mem}
\vspace{-1em}
We report the average time per step (in seconds) and memory usage (in GB) for each (algorithm, dataset) pair in \Cref{tab:wallclock} and \Cref{tab:mem}. Note that, beyond the efficiency of the algorithms themselves, wall-clock time also depends on the state of the hardware at training time. In addition to ERM, CORAL, and Fishr, we compare against HGP and Hutchinson. We also report two versions of CMA: ``CMA (Speed)'' uses the time-efficient Hessian computation, while ``CMA (Memory)'' uses the memory-efficient Hessian computation.


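Among the methods compared, only CMA and Hutchinson compute full Hessian matrices. Despite this, ``CMA (Speed)'' has a per-step time comparable to that of Fishr, CORAL, and HGP, which rely on approximate second-order information. It is also substantially faster than Hutchinson on every dataset on which it runs. ``CMA (Memory)'' trades some speed for lower memory usage, but its per-step time also remains well below that of Hutchinson.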


To highlight the scalability of ``CMA (Memory)'', we run small-scale experiments on OfficeHome, a dataset with 65 classes. In this setting, ``CMA (Speed)'' requires more than 75\,GB of memory and cannot run on a single GPU. In contrast, ``CMA (Memory)'' runs successfully with a peak memory usage below 13.7\,GB.

%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%


%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%


\begin{table*}[ht]
\centering
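\caption{Average wall-clock time per training step across datasets (in seconds); the number of classes in each dataset is given in parentheses. Algorithms are grouped by the type of moment matching. For each dataset, we bold the most time-efficient algorithm within each category. ``--'' indicates that the run did not fit in the memory of a single GPU.}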
\resizebox{\textwidth}{!}{%
\begin{tabular}{lcccccc}
\toprule
\textbf{Algorithm} & \textbf{ColoredMNIST (2)} & \textbf{RotatedMNIST (10)} & \textbf{VLCS (5)} & \textbf{PACS (7)} & \textbf{TerraIncognita (10)} & \textbf{OfficeHome (65)} \\
\midrule
\multicolumn{7}{l}{\textit{No Moment Matching}} \\
\hspace{1em}ERM          & 0.0278 & 0.0403 & 0.4019 & 0.3620 & 0.4216 & 0.4064 \\

\midrule
\multicolumn{7}{l}{\textit{Approximate Second-Order}} \\
\hspace{1em}CORAL        & \textbf{0.0457} & \textbf{0.1003} & 0.6241 & \textbf{0.5244} & 0.7697 & 0.5279 \\
\hspace{1em}Fishr        & 0.0925 & 0.1331 & 0.7472 & 0.6757 & 0.6057 & 0.6600 \\
\hspace{1em}HGP          & 0.0657 & 0.1292 & \textbf{0.6048} & 0.6729 & \textbf{0.6045} & \textbf{0.4977} \\

\midrule
\multicolumn{7}{l}{\textit{Exact Second-Order}} \\
\hspace{1em}Hutchinson   & 4.1663 & 9.7935 & 7.7604 & 7.3284 & 7.7270 & 7.8446 \\
\hspace{1em}CMA (Speed)  & \textbf{0.0676} & \textbf{0.1326} & \textbf{0.7354} & \textbf{0.7266} & \textbf{0.7421} & --     \\
\hspace{1em}CMA (Memory) & 0.1226 & 0.4723 & 0.8874 & 1.1699 & 1.0685 & \textbf{0.8495} \\
\bottomrule
\end{tabular}\label{tab:wallclock}
}
\end{table*}


\begin{table*}[ht]
\centering
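\caption{Memory usage across datasets (in GB); the number of classes in each dataset is given in parentheses. Algorithms are grouped by the type of moment matching. For each dataset, we bold the most memory-efficient algorithm within each category.}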
\resizebox{\textwidth}{!}{%
\begin{tabular}{lcccccc}
\toprule
\textbf{Algorithm} & \textbf{ColoredMNIST (2)} & \textbf{RotatedMNIST (10)} & \textbf{VLCS (5)} & \textbf{PACS (7)} & \textbf{TerraIncognita (10)} & \textbf{OfficeHome (65)} \\
\midrule
\multicolumn{7}{l}{\textit{No Moment Matching}} \\
\hspace{1em}ERM          & 0.1550 & 0.3728 & 6.8865 & 6.8865 & 6.8865 & 6.8868 \\

\midrule
\multicolumn{7}{l}{\textit{Approximate Second-Order}} \\
\hspace{1em}CORAL        & \textbf{0.1391} & 0.3190 & 6.7093 & 6.7093 & 6.7093 & 6.7097 \\
\hspace{1em}Fishr        & 0.3192 & 0.7936 & 14.1433 & 14.1436 & 14.1441 & 14.1522 \\
\hspace{1em}HGP          & 0.1477 & \textbf{0.3099} & \textbf{5.6835} & \textbf{5.6835} & \textbf{5.6836} & \textbf{5.6843} \\
\midrule
\multicolumn{7}{l}{\textit{Exact Second-Order}} \\
\hspace{1em}Hutchinson   & \textbf{0.1502} & \textbf{0.3496} & \textbf{5.7047} & \textbf{5.7125} & \textbf{5.7284} & \textbf{6.0029} \\
\hspace{1em}CMA (Speed)  & 0.2867 & 1.4537 & 13.9511 & 14.8391 & 16.7272 & $>$75 \\
\hspace{1em}CMA (Memory) & 0.3914 & 0.7776 & 13.6447 & 13.6448 & 13.6448 & 13.6474 \\
\bottomrule
\end{tabular}\label{tab:mem}
}
\end{table*}

%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%




%
%
%
%
%
%
%
%
%
%
%
%
%
%
%
%



%
%
%
%
%
%
%
%
%
%
%
%
%
