Adversarial Representation Learning for Canonical Correlation Analysis

ICLR 2023 submission

adCCA is a deep learning approach that learns maximally correlated latent representations from multimodal data. This implementation provides the source code, tools for generating simulated data, and demonstrations.

Package dependencies

The software was tested with the following package versions:

Run adCCA

adCCA requires three steps to learn canonical representations from multimodal data.

Step 1: construct an adCCA object

mv = adCCA(x_dim=x_dim, y_dim=y_dim, z_dim=z_dim, device=device)

where x_dim and y_dim are the input feature dimensions of the two modalities, respectively, and z_dim is the desired dimension of the latent representation.
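
For example, a minimal construction might look like the following (a sketch: the dimensions are placeholder values and the import path is an assumption about this repository's layout):

import torch
from adCCA import adCCA  # assumed import path; adjust to the repository layout

# Pick GPU if available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Illustrative dimensions: 100-d and 80-d inputs, 10-d shared latent space.
x_dim, y_dim, z_dim = 100, 80, 10
mv = adCCA(x_dim=x_dim, y_dim=y_dim, z_dim=z_dim, device=device)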

Step 2: train adCCA with a batch data loader

losses_ax, losses_gx, losses_d = mv.fit(train_dl, 
                                        lr=1e-3,
                                        epochs_ae=500,
                                        epochs_inner=500,
                                        epoch_ad=3)
where train_dl is a standard torch data loader that feeds data in batches; lr is the learning rate; and epochs_ae, epochs_inner, and epoch_ad are the numbers of epochs for the initial autoencoder training, for the inner x-step/y-step training, and for the adversarial iterations alternating between the x-step and y-step, respectively.
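
For example, train_dl can be assembled with standard PyTorch utilities (a sketch that assumes fit consumes batches of paired (x, y) tensors; the batch size is an illustrative choice):

import torch
from torch.utils.data import TensorDataset, DataLoader

# feature_x: (n_samples, x_dim), feature_y: (n_samples, y_dim),
# with rows aligned across the two modalities.
dataset = TensorDataset(feature_x, feature_y)
train_dl = DataLoader(dataset, batch_size=128, shuffle=True)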

Step 3: infer final representations with the trained model

latent_x, latent_y = mv.inference(feature_x.to(device), feature_y.to(device))
where feature_x and feature_y are data from the two modalities used to infer their corresponding latent representations.
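
As a quick sanity check (our illustration, not part of the package), the per-dimension Pearson correlation between the two latent views can be computed; values near 1 indicate strongly correlated representations:

import torch

# Standardize each latent dimension, then average the elementwise products
# to obtain per-dimension Pearson correlations between the two views.
zx = latent_x.detach().cpu()
zy = latent_y.detach().cpu()
zx = (zx - zx.mean(0)) / zx.std(0)
zy = (zy - zy.mean(0)) / zy.std(0)
corr = (zx * zy).sum(0) / (zx.shape[0] - 1)  # shape: (z_dim,)
print(corr)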

Simulation tool and demonstration

The script ./simulation/multimodal_simulation.py provides a function to generate multimodal simulation data with known sample classes.
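
As a rough illustration of what such data looks like (a generic sketch of our own; it does not reproduce the repository's generator), two modalities can be derived from a shared, class-structured latent signal:

import numpy as np

# Generic multimodal simulation sketch: a shared latent signal with known
# class labels is mapped into two modalities through random linear maps
# plus independent noise.
rng = np.random.default_rng(0)
n_per_class, n_classes, latent_dim = 100, 3, 5
labels = np.repeat(np.arange(n_classes), n_per_class)
centers = rng.normal(scale=3.0, size=(n_classes, latent_dim))
shared = centers[labels] + rng.normal(size=(labels.size, latent_dim))

x_dim, y_dim = 100, 80
sim_x = shared @ rng.normal(size=(latent_dim, x_dim)) + 0.5 * rng.normal(size=(labels.size, x_dim))
sim_y = shared @ rng.normal(size=(latent_dim, y_dim)) + 0.5 * rng.normal(size=(labels.size, y_dim))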

demo.ipynb provides a demonstration of 1) generating simulated data; 2) learning representations from the simulated data with adCCA; and 3) visualizing the learned representations.
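
For the visualization step, a minimal approach (our sketch; demo.ipynb may use a different projection) is to embed the learned representation in 2D and color points by the known simulated classes:

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Project the x-modality latent representation to 2D and color by class.
coords = PCA(n_components=2).fit_transform(latent_x.detach().cpu().numpy())
plt.scatter(coords[:, 0], coords[:, 1], c=labels, s=10)
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('adCCA latent representation (modality x)')
plt.show()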