# Escaping Plato's Cave: JAM for Aligning Independently Trained Vision and Language Models

This repository is the official implementation of [Escaping Plato's Cave: JAM for Aligning Independently Trained Vision and Language Models]. 

## Requirements

We recommend using **Python 3.8+** for this project.
To set up your environment, create a virtual environment and install the required packages with the following commands:

```bash
python -m venv jam
source jam/bin/activate 
pip install -r requirements.txt
```

## Data Set-up

### Extracting Features for SugarCrepe Benchmark Data
First, download the SugarCrepe benchmark data by following the instructions at https://github.com/RAIVNLab/sugar-crepe?tab=readme-ov-file (MIT license).

To extract language and vision features, use the `features_extraction.py` script. It supports both "base" and "large" model scales and extracts text embeddings (from the Gemma 2, OLMo 2, and Llama 3 language models) and vision features (from DINOv2) for all SugarCrepe captions and images.

#### **Example Command**

```bash
python features_extraction.py \
  --scale base \
  --hf_token YOUR_HF_TOKEN \
  --json_dir /path/to/sugarcrepe_captions/ \
  --img_dir /path/to/sugarcrepe_images/
```

#### **Arguments**
- `--scale`: Model scale to use (`base` or `large`). Default: `base`.
  - **base**:
    - Language models: `google/gemma-2-2b`, `allenai/OLMo-2-1124-7B`, `meta-llama/Llama-3.2-1B`
    - Vision model: `dinov2_vitb14`
  - **large**:
    - Language models: `google/gemma-2-9b`, `allenai/OLMo-2-1124-13B`, `meta-llama/Llama-3.2-3B`
    - Vision model: `dinov2_vitl14`
- `--hf_token`: Your HuggingFace token for model access (**required**).
- `--json_dir`: Directory containing the SugarCrepe JSON caption files; the text embeddings are also saved here in `.pkl` format (**required**). Two `.pkl` files are saved: one for the positive captions (i.e., the Match case) and one for the negative captions.
- `--img_dir`: Directory containing the SugarCrepe image files; the vision features are also saved here in `.pkl` format (**required**). Two `.pkl` files are saved: one containing only the final output CLS token, and one containing the CLS tokens from all hidden layers.
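
The internal layout of these `.pkl` files is not documented here, so it can help to look inside one before passing it to the later scripts. The sketch below is a generic, hypothetical helper (`describe` and `inspect_pkl` are not part of this repository) that outlines any pickled object without assuming its structure. Only unpickle files you trust, since `pickle.load` can execute arbitrary code; if the features were saved as torch tensors, `torch` must be installed to load them.

```python
import pickle
from typing import Any, List


def describe(obj: Any, depth: int = 0, max_items: int = 3) -> List[str]:
    """Return an indented outline of a (possibly nested) feature object."""
    pad = "  " * depth
    shape = getattr(obj, "shape", None)
    if shape is not None:
        # NumPy arrays and torch tensors both expose .shape
        return [f"{pad}{type(obj).__name__} shape={tuple(shape)}"]
    if isinstance(obj, dict):
        lines = [f"{pad}dict len={len(obj)}"]
        for key in list(obj)[:max_items]:  # show only the first few keys
            lines.append(f"{pad}  key={key!r}:")
            lines.extend(describe(obj[key], depth + 2, max_items))
        return lines
    if isinstance(obj, (list, tuple)):
        lines = [f"{pad}{type(obj).__name__} len={len(obj)}"]
        for item in obj[:max_items]:
            lines.extend(describe(item, depth + 1, max_items))
        return lines
    return [f"{pad}{type(obj).__name__}: {obj!r}"[:120]]  # truncate long leaves


def inspect_pkl(path: str) -> List[str]:
    """Load a trusted .pkl file and return an outline of its contents."""
    with open(path, "rb") as f:
        return describe(pickle.load(f))
```

For example, `print("\n".join(inspect_pkl("/path/to/dinov2_vitb14_backbone_output.pkl")))` prints the top-level container, a few keys or items, and the shape of any arrays.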


### Extracting Features for Random Data (`random_data_easynonmatch` directory)
These features are required to replicate our platonic alignment metrics experiment for the Easy Non-match case.
The `random_data_easynonmatch` directory provides `random_text_gatsby.json` and `random_text_wiki.json`, which contain text extracted from the public-domain novel "The Great Gatsby" and from Wikipedia data.
To prepare the random white-noise image data, run `random_data_img_gen.py`. This creates white-noise images in the `random_data_easynonmatch/white_noise_images/` folder.
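
For reference, white-noise images of this kind can be produced with a few lines of NumPy. This is a sketch, not the contents of `random_data_img_gen.py`: the 224×224 resolution, the RGB `uint8` format, the seed, and the helper name `white_noise_images` are all assumptions.

```python
from typing import Iterator

import numpy as np


def white_noise_images(n: int, size: int = 224, seed: int = 0) -> Iterator[np.ndarray]:
    """Yield n RGB white-noise images, each a uint8 array of shape (size, size, 3).

    Every pixel channel is drawn independently and uniformly from 0..255.
    Images are yielded one at a time so thousands of them never have to
    sit in memory at once.
    """
    rng = np.random.default_rng(seed)  # fixed seed -> reproducible noise
    for _ in range(n):
        yield rng.integers(0, 256, size=(size, size, 3), dtype=np.uint8)
```

Each array can then be written to disk with Pillow, e.g. `Image.fromarray(img).save(f"noise_{i:05d}.png")` inside a loop over `enumerate(white_noise_images(5000))`.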

Once all the random data for the Easy Non-match case are prepared, run the `features_extraction_random_data.py` script to extract language and vision features.

#### **Example Command**

```bash
python features_extraction_random_data.py \
  --hf_token YOUR_HF_TOKEN \
  --json_dir jamming_with_plato_repo/random_data_easynonmatch \
  --img_dir jamming_with_plato_repo/random_data_easynonmatch/white_noise_images
```

#### **Arguments**
- `--hf_token`: Your HuggingFace token for model access (**required**).
- `--json_dir`: Directory containing your random text JSON files (we provide `random_text_gatsby.json` and `random_text_wiki.json`, which contain text extracted from "The Great Gatsby" and from Wikipedia data). Text embeddings for each text file are saved in `.pkl` format.
- `--img_dir`: Directory containing your image files (we provide 5000 randomly generated images in `white_noise_images`). Image embeddings for the white-noise images are saved in `.pkl` format.


## Platonic Alignment Metrics Experiment

The script `run_platonic_alignment_metrics.py` computes a suite of alignment metrics between language and vision features, which is useful for evaluating how well the representations from your language and vision models are aligned. We experiment with three data cases:
- Match: Text and image features that capture the same reality (e.g., an image and its positive caption).
- Easy Non-match: Alignment between random text and the images used in the Match case, OR between the texts used in the Match case and white-noise images. The text and image data clearly capture different realities (i.e., an easily verifiable non-match).
- Hard Non-match: Alignment between hard negative texts and the images used in the Match case. The hard negative texts share the context of the positive (i.e., Match) texts but differ in a fine-grained detail, which is why they count as negatives yet are hard to discern.

The supported metrics are CKA, CKNNA, SVCCA, CCA-linear-PCA, and CCA-kernel-PCA.

#### **Example Command**

To test the Match case:
```bash
# e.g., text: gemma2_replace_rel.pkl, image: dinov2_vitb14_backbone_output.pkl
python run_platonic_alignment_metrics.py \
  --text_files /path/to/text_features.pkl \
  --img_file /path/to/image_features.pkl \
  --output_dir match_alignment_metrics
```

To test the Easy Non-match case:
```bash
# e.g., text: /random_data_easynonmatch/gemma2_random_text_gatsby.pkl
#       image: /random_data_easynonmatch/white_noise_images/dinov2_vitb14_backbone_output.pkl
python run_platonic_alignment_metrics.py \
  --text_files /path/to/text_features.pkl \
  --img_file /path/to/image_features.pkl \
  --direct_img_features True \
  --output_dir easynonmatch_results_alignment_metrics
```

To test the Hard Non-match case:
```bash
# e.g., text: gemma2_replace_rel_neg.pkl, image: dinov2_vitb14_backbone_output.pkl
python run_platonic_alignment_metrics.py \
  --text_files /path/to/text_features.pkl \
  --img_file /path/to/image_features.pkl \
  --output_dir hardnonmatch_results_alignment_metrics
```

#### **Arguments**
- `--text_files`: One or more paths to text feature `.pkl` files (can be from different models or datasets).
- `--img_file`: Path to the image feature `.pkl` file.
- `--output_dir`: Directory to save the results CSV (default: `results`).
- `--text_agg`: Text feature aggregation method (`mean` or `max`, default: `mean`).
- `--direct_img_features`: If set, treat the image features as direct arrays (no filename matching).
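
Assuming `--text_agg` pools per-token language-model embeddings into one vector per caption (the usual meaning of mean and max pooling), the two options can be sketched as below. The `(batch, tokens, dim)` layout, the padding mask, and the name `aggregate_tokens` are illustrative assumptions rather than this repository's code.

```python
import numpy as np


def aggregate_tokens(hidden: np.ndarray, mask: np.ndarray, method: str = "mean") -> np.ndarray:
    """Pool per-token embeddings of shape (batch, tokens, dim) to (batch, dim).

    mask has shape (batch, tokens): 1 for real tokens, 0 for padding, so that
    padding positions never contribute to the pooled vector.
    """
    keep = mask.astype(bool)[..., None]  # (batch, tokens, 1), broadcasts over dim
    if method == "mean":
        summed = np.where(keep, hidden, 0.0).sum(axis=1)
        counts = np.maximum(keep.sum(axis=1), 1)  # avoid dividing by zero
        return summed / counts
    if method == "max":
        return np.where(keep, hidden, -np.inf).max(axis=1)
    raise ValueError(f"unknown aggregation method: {method!r}")
```

Masking matters for batched captions of different lengths: without it, padding vectors would pull the mean toward zero and could dominate the max.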

#### **Output**
- Computes the following alignment metrics between the text and image features and saves them as a CSV file in the specified `--output_dir`:
  - **CKA** (Centered Kernel Alignment)
  - **CKNNA** (Centered Kernel Nearest-Neighbor Alignment)
  - **SVCCA** (Singular Vector Canonical Correlation Analysis)
  - **CCA-linear-PCA** (Canonical Correlation Analysis after linear PCA)
  - **CCA-kernel-PCA** (Canonical Correlation Analysis after kernel PCA)
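
For intuition about the first metric, linear CKA (Kornblith et al., 2019) compares two sets of paired features through their centered cross-covariance and is invariant to orthogonal transforms and isotropic scaling of either side. The sketch below covers only the linear form with a hypothetical `linear_cka` helper; the script's own implementation, kernel choice, and the other four metrics may differ.

```python
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between paired feature matrices x (n, p) and y (n, q).

    Row i of x and row i of y must describe the same sample (e.g. a caption
    and its image); the feature widths p and q may differ.
    CKA = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F) on column-centered X, Y.
    """
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y need the same number of rows (samples)")
    x = x - x.mean(axis=0, keepdims=True)  # center each feature dimension
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, "fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, "fro")
    norm_y = np.linalg.norm(y.T @ y, "fro")
    return float(cross / (norm_x * norm_y))
```

A value of 1 means the two feature sets have identical similarity structure up to rotation and scale; independent random features give values well below 1.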

## JAM Training & Evaluation 

### JAM Training with Final Output Layer Embeddings of Language & Vision Models: `train_output.py`

The script `train_output.py` provides a unified interface for training different variants of the joint autoencoder (AE) bridge model. You can select the training mode and configure all relevant hyperparameters and data paths via command-line arguments.

#### **Example Commands**

Train with the spread loss:
```bash
python train_output.py --mode spread --image_path /path/to/image.pkl --positive_text_path /path/to/pos_text.pkl --negative_text_path /path/to/neg_text.pkl --csv_out results_spread.csv
```

Train with the baseline contrastive (i.e., L_con) loss:
```bash
python train_output.py --mode baseline_con --image_path /path/to/image.pkl --positive_text_path /path/to/pos_text.pkl --negative_text_path /path/to/neg_text.pkl --csv_out results_baseline_con.csv
```

Train with the baseline negative contrastive (i.e., L_negcon) loss:
```bash
python train_output.py --mode baseline_negcon --image_path /path/to/image.pkl --positive_text_path /path/to/pos_text.pkl --negative_text_path /path/to/neg_text.pkl --csv_out results_baseline_negcon.csv
```

#### **Modes**
- `spread`: Uses the spread loss 
- `baseline_con`: Uses the contrastive loss (align_type="lcon")
- `baseline_negcon`: Uses the negative contrastive loss (align_type="lconneg")
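
The exact definitions of L_con and L_negcon are in the training code and the paper. As a rough, hypothetical illustration only, a standard image-to-text InfoNCE loss, and a variant that adds each image's hard-negative caption as one extra candidate, could look like this. Cosine similarity, the temperature of 0.07, and the single image-to-text direction are assumptions, and the spread loss is not shown.

```python
from typing import Optional

import numpy as np


def _normalize(v: np.ndarray) -> np.ndarray:
    """Scale each row to unit L2 norm so dot products are cosine similarities."""
    return v / np.linalg.norm(v, axis=1, keepdims=True)


def _log_softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise log-softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def contrastive_loss(img: np.ndarray, pos_txt: np.ndarray,
                     neg_txt: Optional[np.ndarray] = None,
                     temperature: float = 0.07) -> float:
    """Image-to-text InfoNCE over a batch of latent codes of shape (batch, dim).

    Row i of img is paired with row i of pos_txt, so positives sit on the
    diagonal and the other captions in the batch act as negatives. If neg_txt
    is given, each image's own hard-negative caption is appended as one extra
    candidate column.
    """
    img, pos_txt = _normalize(img), _normalize(pos_txt)
    logits = img @ pos_txt.T / temperature  # (batch, batch)
    if neg_txt is not None:
        hard = np.sum(img * _normalize(neg_txt), axis=1, keepdims=True) / temperature
        logits = np.concatenate([logits, hard], axis=1)  # (batch, batch + 1)
    rows = np.arange(img.shape[0])
    return float(-_log_softmax(logits)[rows, rows].mean())
```

Adding the hard-negative column can only raise the loss for a given batch, which is what pushes the model to separate a positive caption from its minimally edited negative.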

#### **Key Arguments**
- `--mode`: Training mode (`spread`, `baseline_con`, or `baseline_negcon`) 
- `--image_path`: Path to image embeddings `.pkl` file
- `--positive_text_path`: Path to positive text embeddings `.pkl` file 
- `--negative_text_path`: Path to negative text embeddings `.pkl` file 
- `--text_agg`: Text aggregation method (`mean` or `max`, default: `mean`)

- `--batch_size`: Batch size (default: 32)
- `--test_ratio`: Fraction of data for test split (default: 0.1)
- `--val_ratio`: Fraction of data for validation split (default: 0.2)
- `--shuffle_seed`: Random seed for shuffling (default: 55)

- `--latent_dim`: Latent dimension for autoencoders (default: 256)
- `--hidden_dim`: Hidden dimensions for autoencoders (default: (512, 512, 512))
- `--epochs`: Number of training epochs (default: 200)
- `--csv_out`: Output CSV file for results (default: ./output.csv)
- `--max_alpha`: Alpha value for spread loss (used in `spread` mode, default: 0.5)
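
`--test_ratio`, `--val_ratio`, and `--shuffle_seed` together fix the train/validation/test split. A minimal sketch of a seeded split is shown below; it assumes both ratios are fractions of the full dataset and uses Python's `random` module, which may not match how the scripts actually shuffle (`split_indices` is a hypothetical name).

```python
import random
from typing import List, Tuple


def split_indices(n: int, test_ratio: float = 0.1, val_ratio: float = 0.2,
                  seed: int = 55) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle range(n) with a fixed seed and cut it into train/val/test.

    Both ratios are taken as fractions of the full dataset (an assumption).
    """
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # private RNG: global random state untouched
    n_test = int(n * test_ratio)
    n_val = int(n * val_ratio)
    test = idx[:n_test]
    val = idx[n_test:n_test + n_val]
    train = idx[n_test + n_val:]
    return train, val, test
```

Under this sketch, keeping the seed fixed across the `spread`, `baseline_con`, and `baseline_negcon` runs keeps the test pairs identical, so their results stay directly comparable.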

#### **Output**
- Model checkpoints are saved as `best_val_recall.pt` during training.
- Training logs and metrics are tracked with Weights & Biases (wandb) and are also saved to the specified CSV file.
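
The checkpoint name suggests that model selection uses a validation recall score. Assuming this is a standard image-to-text retrieval Recall@K (an assumption; the logged metric may use a different direction, K, or similarity), it can be computed from the latent codes as follows, with `recall_at_k` as a hypothetical helper.

```python
import numpy as np


def recall_at_k(img: np.ndarray, txt: np.ndarray, k: int = 1) -> float:
    """Image-to-text Recall@K: share of images whose own caption ranks in the top k.

    img and txt are (n, dim) latent codes where row i of each forms a true pair;
    similarity is cosine.
    """
    img = img / np.linalg.norm(img, axis=1, keepdims=True)
    txt = txt / np.linalg.norm(txt, axis=1, keepdims=True)
    sims = img @ txt.T  # row i: image i against every caption
    topk = np.argsort(-sims, axis=1)[:, :k]  # indices of the k most similar captions
    hits = (topk == np.arange(len(img))[:, None]).any(axis=1)
    return float(hits.mean())
```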


### JAM Training with Hidden Layer Selection of Language & Vision Models: `train_hidden.py`

The script `train_hidden.py` trains JAM using features from specific hidden layers of the language and vision models. The training uses the spread loss with curriculum alpha scheduling. The purpose of this experiment is to explore the effect of intermediate representations on alignment performance and how much supervision (i.e., the alpha weight) is required to achieve the best alignment performance for the selected intermediate layers. 

#### **Example Command**

```bash
python train_hidden.py \
  --image_path /path/to/image.pkl \
  --positive_text_path /path/to/pos_text.pkl \
  --negative_text_path /path/to/neg_text.pkl \
  --text_layer 5 \
  --img_layer 5 \
  --csv_out results_hidden.csv
```
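
In the command above, `--text_layer 5` and `--img_layer 5` each pick a single hidden layer. Assuming the per-layer features are stacked into an array of shape `(layers, samples, dim)` (the real `.pkl` layout may differ), the selection with the `-1` = last-layer convention can be sketched as follows; `select_layer` is hypothetical. With Hugging Face language models and `output_hidden_states=True`, index 0 is the embedding output rather than the first transformer block, so it is worth checking which convention the saved features follow.

```python
import numpy as np


def select_layer(hidden_states: np.ndarray, layer: int) -> np.ndarray:
    """Pick one layer from stacked per-layer features of shape (layers, n, dim).

    Python-style indexing makes -1 the last layer; any other out-of-range
    index raises, which catches a layer index too large for a smaller model.
    """
    num_layers = hidden_states.shape[0]
    if not -num_layers <= layer < num_layers:
        raise IndexError(f"layer {layer} out of range for {num_layers} layers")
    return hidden_states[layer]
```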

#### **Key Arguments**
- `--image_path`: Path to image embeddings `.pkl` file 
- `--positive_text_path`: Path to positive text embeddings `.pkl` file 
- `--negative_text_path`: Path to negative text embeddings `.pkl` file 
- `--text_agg`: Text aggregation method (`mean` or `max`, default: `mean`)
- `--text_layer`: Language-model layer to use (default: 24, `-1` = last layer)
- `--img_layer`: Vision-model layer to use (default: 9, `-1` = last layer)

- `--batch_size`: Batch size (default: 32)
- `--test_ratio`: Fraction of data for test split (default: 0.1)
- `--val_ratio`: Fraction of data for validation split (default: 0.2)
- `--shuffle_seed`: Random seed for shuffling (default: 55)

- `--latent_dim`: Latent dimension for autoencoders (default: 256)
- `--hidden_dim`: Hidden dimensions for autoencoders (default: (512, 512, 512))
- `--epochs`: Number of training epochs (default: 200)
- `--max_alpha`: Final alpha value after ramp (default: 0.7)
- `--warmup_frac`: Fraction of epochs with alpha=0 (default: 0.0)
- `--ramp_frac`: Fraction of epochs for linear ramp (default: 1.0)
- `--csv_out`: Output CSV file for results (default: ./hidden_spread_results.csv)
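
`--max_alpha`, `--warmup_frac`, and `--ramp_frac` describe a curriculum for alpha. One reading of these arguments, sketched below as a hypothetical `alpha_at` helper, keeps alpha at 0 during warm-up, ramps it linearly to `max_alpha`, and then holds it there; the exact formula in `train_hidden.py` (for example, whether the ramp reaches `max_alpha` on the final epoch) may differ.

```python
def alpha_at(epoch: int, epochs: int = 200, max_alpha: float = 0.7,
             warmup_frac: float = 0.0, ramp_frac: float = 1.0) -> float:
    """Curriculum alpha for a 0-based epoch.

    alpha is 0 for the first warmup_frac of training, rises linearly to
    max_alpha over the next ramp_frac of training, and stays there afterwards.
    """
    warmup = warmup_frac * epochs
    ramp = ramp_frac * epochs
    if epoch < warmup:
        return 0.0
    if ramp <= 0:
        return max_alpha  # no ramp: jump straight to the final value
    progress = min((epoch - warmup) / ramp, 1.0)
    return max_alpha * progress
```

With the defaults (no warm-up, ramp over all 200 epochs), alpha is 0.35 at epoch 100 under this sketch.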

#### **Output**
- Model checkpoints are saved as `best_val_recall.pt` during training.
- Training logs and metrics are tracked with Weights & Biases (wandb) and are also saved to the specified CSV file.


