Cross-modal Representation Flattening for Multi-modal Domain Generalization

Published: 25 Sept 2024 · Last Modified: 06 Nov 2024 · NeurIPS 2024 poster · CC BY 4.0
Keywords: Multimodal Learning, Domain Generalization, Sharpness-aware Minimization, Representation Flattening
Abstract: Multi-modal domain generalization (MMDG) requires that models trained on multi-modal source domains generalize to unseen target distributions with the same modality set. Sharpness-aware minimization (SAM) is an effective technique for traditional uni-modal domain generalization (DG); however, it brings only limited improvement in MMDG. In this paper, we identify modality competition and discrepant uni-modal flatness as two main factors that restrict multi-modal generalization. To overcome these challenges, we propose to construct consistent flat loss regions and enhance knowledge exploitation for each modality via cross-modal knowledge transfer. First, we turn to optimization over representation-space loss landscapes instead of the traditional parameter space, which allows us to build connections between modalities directly. Then, we introduce a novel method that flattens the high-loss region between minima of different modalities by interpolating mixed multi-modal representations. We implement this method by distilling and optimizing generalizable interpolated representations and assigning distinct weights to each modality in light of their divergent generalization capabilities. Extensive experiments on two benchmark datasets, EPIC-Kitchens and Human-Animal-Cartoon (HAC), with various modality combinations demonstrate the effectiveness of our method under both multi-source and single-source settings. Our code is open-sourced.
Supplementary Material: zip
Primary Area: Machine vision
Submission Number: 9260
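
The abstract describes flattening the high-loss region between modality-specific minima by interpolating mixed multi-modal representations and distilling them with per-modality weights. The sketch below shows one plausible form such a representation-space loss could take; it is a minimal illustration under stated assumptions, not the authors' released implementation. The encoder outputs z_video and z_audio, the shared classification head, the Beta-mixup scheme, and the weights w_video/w_audio are all hypothetical names chosen for this example.

```python
# Minimal sketch of cross-modal representation interpolation with weighted
# distillation. Assumption-laden illustration, NOT the paper's released code:
# the mixing scheme, head, and per-modality weights are illustrative choices.
import torch
import torch.nn.functional as F


def interpolate_representations(z_a, z_b, alpha=1.0):
    """Mix two modalities' representations with a Beta-sampled coefficient."""
    lam = torch.distributions.Beta(alpha, alpha).sample().to(z_a.device)
    return lam * z_a + (1.0 - lam) * z_b, lam


def flattening_loss(head, z_video, z_audio, labels, w_video=0.6, w_audio=0.4):
    """Encourage low loss along the path between two modalities' minima.

    The interpolated representation is pushed toward correct predictions
    (cross-entropy) and toward each uni-modal prediction (distillation),
    with per-modality weights standing in for their divergent
    generalization capabilities.
    """
    z_mix, lam = interpolate_representations(z_video, z_audio)
    logits_mix = head(z_mix)
    # Supervised term: the mixed representation should still classify correctly.
    ce = F.cross_entropy(logits_mix, labels)
    # Distillation term: pull the mixed prediction toward each uni-modal
    # teacher, weighted per modality (teachers are detached from the graph).
    with torch.no_grad():
        p_video = F.softmax(head(z_video), dim=-1)
        p_audio = F.softmax(head(z_audio), dim=-1)
    log_p_mix = F.log_softmax(logits_mix, dim=-1)
    kd = (w_video * F.kl_div(log_p_mix, p_video, reduction="batchmean")
          + w_audio * F.kl_div(log_p_mix, p_audio, reduction="batchmean"))
    return ce + kd
```

In a full training loop, a term like this would presumably be added to each modality's own supervised loss; the fixed weights used here could instead be set adaptively from each modality's estimated flatness or generalization gap, in the spirit of the weighting scheme the abstract mentions.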