# Towards a Unified Framework of Contrastive Learning for Disentangled Representations [NeurIPS 2023]
Official code to reproduce the results and data presented in the paper [Towards a Unified Framework of Contrastive Learning for Disentangled Representations].

This repo is based on the code of the paper [Contrastive Learning Inverts the Data Generating Process](https://brendel-group.github.io/

## Experiments 
To reproduce the disentanglement results for MLP mixing, use the [main_mlp.py](main_mlp.py) script. For the experiments on KITTI Masks, use [main_kitti.py](main_kitti.py), and for those on 3DIdent, use [main_3dident.py](main_3dident.py). The options of each script, as printed by `--help`, are listed in the subsections below.

### MLP Mixing

```bash
> python main_mlp.py --help
usage: main_mlp.py [-h] [--gpu GPU] [--sphere-r SPHERE_R] [--box-min BOX_MIN] [--box-max BOX_MAX] [--output-norm {None,learnable_box,fixed_box,learnable_sphere,fixed_sphere}] [--only-supervised] [--only-unsupervised ONLY_UNSUPERVISED]
                   [--more-unsupervised MORE_UNSUPERVISED] [--save-dir SAVE_DIR] [--num-eval-batches NUM_EVAL_BATCHES] [--seed SEED] [--act-fct ACT_FCT] [--c-param C_PARAM] [--m-param M_PARAM] [--tau TAU]
                   [--n-mixing-layer N_MIXING_LAYER] [--n N] [--space-type {box,sphere,unbounded,hollow_ball,cube_grid}] [--m-p M_P] [--c-p [C_P ...]] [--lr LR] [--p P] [--loss {ince,nce,nwj,scl,simclr}] [--encoder {mlp,res}]
                   [--margin-mode {first,second,both}] [--center CENTER] [--batch-size BATCH_SIZE] [--n-log-steps N_LOG_STEPS] [--n-steps N_STEPS] [--resume-training] [--early-stopping EARLY_STOPPING]

Disentanglement with Contrastive Learning - MLP Mixing

options:
  -h, --help            show this help message and exit
  --gpu GPU             used GPU if cuda is available and set to True
  --sphere-r SPHERE_R
  --box-min BOX_MIN     For box normalization only. Minimal value of box.
  --box-max BOX_MAX     For box normalization only. Maximal value of box.
  --output-norm {None,learnable_box,fixed_box,learnable_sphere,fixed_sphere}
                        Normalize output.
  --only-supervised     Only train supervised model.
  --only-unsupervised ONLY_UNSUPERVISED
                        Only train unsupervised model.
  --more-unsupervised MORE_UNSUPERVISED
                        How many more steps to do for unsupervised compared to supervised training.
  --save-dir SAVE_DIR
  --num-eval-batches NUM_EVAL_BATCHES
                        Number of batches to average evaluation performance at the end.
  --seed SEED
  --act-fct ACT_FCT     Activation function in mixing network g.
  --c-param C_PARAM     Concentration parameter of the conditional distribution.
  --m-param M_PARAM     Additional parameter for the marginal (only relevant if it is not uniform).
  --tau TAU
  --n-mixing-layer N_MIXING_LAYER
                        Number of layers in nonlinear mixing network g.
  --n N                 Dimensionality of the latents.
  --space-type {box,sphere,unbounded,hollow_ball,cube_grid}
  --m-p M_P             Type of ground-truth marginal distribution. p=0 means uniform; all other p values correspond to (projected) Lp Exponential
  --c-p [C_P ...]       Exponent(s) of ground-truth Lp Exponential distribution. Make sure that len(c-p) in [1, n]
  --lr LR
  --p P                 Exponent of the assumed model Lp Exponential distribution. p=-1 means the exponents are learnable parameters.
  --loss {ince,nce,nwj,scl,simclr}
                        Loss function to minimize (only used if p==-1)
  --encoder {mlp,res}   Encoder architecture
  --margin-mode {first,second,both}
                        Encoder architecture
  --center CENTER       Whether to add additional loss to center the representation. Be careful with space constraint!
  --batch-size BATCH_SIZE
  --n-log-steps N_LOG_STEPS
  --n-steps N_STEPS
  --resume-training
  --early-stopping EARLY_STOPPING
                        Stop early if disentanglement score is high enough.
```
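
As an illustration of how the options above combine, the following sketch assembles one possible `main_mlp.py` invocation from Python. The flag names and the `--space-type` choices are taken from the help output. The concrete values (`n=10`, `p=1`, `seed=0`, the `runs/mlp_box_p1` directory) and the helper `build_mlp_command` are assumptions made for illustration only, not the settings used in the paper.

```python
import shlex

# Choices copied from the --space-type entry of the help output above.
SPACE_TYPES = {"box", "sphere", "unbounded", "hollow_ball", "cube_grid"}


def build_mlp_command(n=10, space_type="box", p=1, seed=0,
                      save_dir="runs/mlp_box_p1"):
    """Return an argv list for main_mlp.py (hypothetical helper).

    All default values are illustrative assumptions, not recommended settings.
    """
    if space_type not in SPACE_TYPES:
        raise ValueError(f"unknown space type: {space_type!r}")
    return [
        "python", "main_mlp.py",
        "--n", str(n),                # dimensionality of the latents
        "--space-type", space_type,   # latent space
        "--p", str(p),                # exponent of the assumed model distribution
        "--seed", str(seed),
        "--save-dir", save_dir,
    ]


cmd = build_mlp_command()
# Print a shell-safe command line; pass `cmd` to subprocess.run to execute it
# from the repository root.
print(shlex.join(cmd))
```

Building the argument list in Python rather than in a shell string keeps quoting explicit and makes it straightforward to vary a single option across runs.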

### KITTI Masks

```bash
> python main_kitti.py --help
usage: main_kitti.py [-h] [--box-norm BOX_NORM] [--p P] [--experiment-dir EXPERIMENT_DIR] [--evaluate] [--specify SPECIFY] [--random-search] [--random-seeds] [--seed SEED] [--beta BETA] [--gamma GAMMA] [--rate-prior RATE_PRIOR]
                     [--data-distribution DATA_DISTRIBUTION] [--rate-data RATE_DATA] [--data-k DATA_K] [--betavae] [--search-beta] [--output-dir OUTPUT_DIR] [--log-dir LOG_DIR] [--ckpt-dir CKPT_DIR] [--max-iter MAX_ITER]
                     [--dataset DATASET] [--batch-size BATCH_SIZE] [--num-workers NUM_WORKERS] [--image-size IMAGE_SIZE] [--use-writer] [--z-dim Z_DIM] [--lr LR] [--beta1 BETA1] [--beta2 BETA2] [--ckpt-name CKPT_NAME]
                     [--log-step LOG_STEP] [--save-step SAVE_STEP] [--kitti-max-delta-t KITTI_MAX_DELTA_T] [--natural-discrete] [--verbose] [--cuda] [--gpu GPU] [--num_runs NUM_RUNS] [--loss {simclr,ince,nce,nwj,scl}] [--center CENTER]

Disentanglement with Contrastive Learning - KITTI Masks

options:
  -h, --help            show this help message and exit
  --box-norm BOX_NORM
  --p P
  --experiment-dir EXPERIMENT_DIR
                        specify path
  --evaluate            evaluate instead of train
  --specify SPECIFY     use argument to only compute a subset of metrics
  --random-search       whether to random search for params
  --random-seeds        whether to go over random seeds with UDR params
  --seed SEED           random seed
  --beta BETA           weight for kl to normal
  --gamma GAMMA         weight for kl to laplace
  --rate-prior RATE_PRIOR
                        rate (or inverse scale) for prior laplace (larger -> sparser).
  --data-distribution DATA_DISTRIBUTION
                        (laplace, uniform)
  --rate-data RATE_DATA
                        rate (or inverse scale) for data laplace (larger -> sparser). (-1 = rand).
  --data-k DATA_K       k for data uniform (-1 = rand).
  --betavae             whether to do standard betavae training (gamma=0)
  --search-beta         whether to do rand search over beta
  --output-dir OUTPUT_DIR
                        output directory
  --log-dir LOG_DIR     log directory
  --ckpt-dir CKPT_DIR   checkpoint directory
  --max-iter MAX_ITER   maximum training iteration
  --dataset DATASET     dataset name (dsprites, cars3d,smallnorb, shapes3d, mpi3d, kittimasks, natural
  --batch-size BATCH_SIZE
                        batch size
  --num-workers NUM_WORKERS
                        dataloader num_workers
  --image-size IMAGE_SIZE
                        image size. now only (64,64) is supported
  --use-writer          whether to use a log writer
  --z-dim Z_DIM         dimension of the representation z
  --lr LR               learning rate
  --beta1 BETA1         Adam optimizer beta1
  --beta2 BETA2         Adam optimizer beta2
  --ckpt-name CKPT_NAME
                        load previous checkpoint. insert checkpoint filename
  --log-step LOG_STEP   numer of iterations after which data is logged
  --save-step SAVE_STEP
                        number of iterations after which a checkpoint is saved
  --kitti-max-delta-t KITTI_MAX_DELTA_T
                        max t difference between frames sampled from kitti data loader.
  --natural-discrete    discretize natural sprites
  --verbose             for evaluation
  --cuda
  --gpu GPU             used GPU if cuda is available and set to True
  --num_runs NUM_RUNS   when searching over seeds, do 10
  --loss {simclr,ince,nce,nwj,scl}
                        Loss function to minimize
  --center CENTER       Whether to add additional loss to center the representation. Be careful with space constraints!
```
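
Since `main_kitti.py` exposes both `--loss` and `--seed`, one way to compare losses across seeds is to generate one command per combination. The sketch below does only that: it builds the command lines and prints them without running anything. The loss names and the `kittimasks` dataset name come from the help output. The helper `kitti_sweep`, the `outputs/kitti` directory layout, and the chosen seeds are hypothetical.

```python
import itertools
import shlex

# Choices copied from the --loss entry of the help output above.
LOSSES = ("simclr", "ince", "nce", "nwj", "scl")


def kitti_sweep(losses, seeds, output_root="outputs/kitti"):
    """Yield one main_kitti.py argv list per (loss, seed) pair.

    Hypothetical helper; the output directory layout is an assumption.
    """
    for loss, seed in itertools.product(losses, seeds):
        if loss not in LOSSES:
            raise ValueError(f"unknown loss: {loss!r}")
        yield [
            "python", "main_kitti.py",
            "--dataset", "kittimasks",
            "--loss", loss,
            "--seed", str(seed),
            "--output-dir", f"{output_root}/{loss}_seed{seed}",
        ]


# Two losses times two seeds gives four command lines.
for cmd in kitti_sweep(["ince", "scl"], seeds=[0, 1]):
    print(shlex.join(cmd))
```

Each printed line can be run as is from the repository root, or submitted as a separate job on a cluster.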

### 3DIdent

```bash
> python main_3dident.py --help
usage: main_3dident.py [-h] [--mode {supervised,unsupervised,test}] [--unsupervised-loss {l1,l2,l3,vmf,ince,nce,nwj,scl}] [--tau TAU] [--center CENTER] [--load-model LOAD_MODEL] [--save-model SAVE_MODEL] [--save-every SAVE_EVERY]
                       [--batch-size BATCH_SIZE] [--n-eval-samples N_EVAL_SAMPLES] [--lr LR] [--optimizer {adam,sgd}] [--iterations ITERATIONS] [--n-log-steps N_LOG_STEPS] [--no-cuda] [--position-only] [--rotation-and-color-only]
                       [--rotation-only] [--color-only] [--no-spotlight-position] [--no-spotlight-color] [--no-spotlight] [--non-periodic-rotation-and-color] [--dummy-mixing] [--identity-solution] [--identity-mixing-and-solution]
                       [--approximate-dataset-nn-search] --offline-dataset OFFLINE_DATASET [--faiss-omp-threads FAISS_OMP_THREADS] [--box-constraint {None,fix,learnable}] [--sphere-constraint {None,fix,learnable}] [--workers WORKERS]
                       [--supervised-loss {mse,r2}] [--non-periodical-conditional {l1,l2,l3}] [--p P] [--sigma SIGMA] [--encoder {rn18,rn50,rn101,rn151}]

Disentanglement with Contrastive Learning - 3DIdent

options:
  -h, --help            show this help message and exit
  --mode {supervised,unsupervised,test}
  --unsupervised-loss {l1,l2,l3,vmf,ince,nce,nwj,scl}
  --tau TAU
  --center CENTER       Whether to add additional loss to center the representation. Be careful with space constraint!
  --load-model LOAD_MODEL
                        Path from where to load the model
  --save-model SAVE_MODEL
                        Path where to save the model
  --save-every SAVE_EVERY
                        After how many steps to save the model (will always be saved at the end)
  --batch-size BATCH_SIZE
  --n-eval-samples N_EVAL_SAMPLES
  --lr LR
  --optimizer {adam,sgd}
  --iterations ITERATIONS
                        How long to train the model
  --n-log-steps N_LOG_STEPS
                        How often to calculate scores and print them
  --no-cuda
  --position-only
  --rotation-and-color-only
  --rotation-only
  --color-only
  --no-spotlight-position
  --no-spotlight-color
  --no-spotlight
  --non-periodic-rotation-and-color
  --dummy-mixing
  --identity-solution
  --identity-mixing-and-solution
  --approximate-dataset-nn-search
  --offline-dataset OFFLINE_DATASET
  --faiss-omp-threads FAISS_OMP_THREADS
  --box-constraint {None,fix,learnable}
  --sphere-constraint {None,fix,learnable}
  --workers WORKERS     Number of workers to use (0=#cpus)
  --supervised-loss {mse,r2}
  --non-periodical-conditional {l1,l2,l3}
  --p P                 Exponent of the assumed model Lp Exponential distribution. p=-1 means the exponents are learnable parameters.
  --sigma SIGMA         Sigma of the conditional distribution (for vMF: 1/kappa)
  --encoder {rn18,rn50,rn101,rn151}
```
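
In the usage line above, `--offline-dataset` is the only option not wrapped in square brackets, which is how `argparse` marks a required argument. The sketch below reproduces a small subset of the options with the standard `argparse` module to show that behavior. It is not the parser from `main_3dident.py`: the default values and the `data/3dident` path are assumptions, and the real definitions may differ.

```python
import argparse


def build_parser():
    """Minimal stand-in for a subset of the main_3dident.py options."""
    parser = argparse.ArgumentParser(
        prog="main_3dident.py",
        description="Disentanglement with Contrastive Learning - 3DIdent",
    )
    # Choices copied from the help output; the defaults are assumptions.
    parser.add_argument("--mode",
                        choices=("supervised", "unsupervised", "test"),
                        default="unsupervised")
    parser.add_argument("--encoder",
                        choices=("rn18", "rn50", "rn101", "rn151"),
                        default="rn18")
    # Shown without brackets in the usage line, so it must be supplied.
    parser.add_argument("--offline-dataset", required=True)
    parser.add_argument("--no-cuda", action="store_true")
    return parser


# Hypothetical dataset path, used only to demonstrate parsing.
args = build_parser().parse_args(
    ["--offline-dataset", "data/3dident", "--mode", "test"]
)
print(args.mode, args.offline_dataset, args.no_cuda)
```

Omitting `--offline-dataset` makes `argparse` print a usage error and exit, so the dataset path has to be passed on every invocation, including with `--mode test`.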
