# Beyond Instance-Level Alignment: Dual-Level Optimal Transport for Audio-Text Retrieval


This repository contains implementations and baseline comparisons for our proposed method, **DART**, which introduces a feature-level optimal transport loss to improve audio-text retrieval.

## Structure Overview

```
.
├── dart/               # Our method: single-GPU version
├── dart_distributed/   # Our method: multi-GPU (distributed) version
├── vast/               # Baseline method: VAST
└── one-peace/          # Baseline method: One-PEACE
```

## Our Method

### `dart/` – DART (Single GPU)

This folder contains the single-GPU implementation of our method, **DART** (Feature-Level Optimal Transport). It incorporates a feature-level alignment loss based on entropic optimal transport to enhance audio-text retrieval performance.
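
To make the idea concrete, here is a small NumPy sketch of an entropic OT loss computed at the feature level. It is not the repository's `Entropic_OT_Loss`. It rests on our own assumptions: each embedding dimension, standardized and viewed as a column over the mini-batch, is one point; the cost between an audio dimension and a text dimension is one minus their cosine similarity; both marginals are uniform; and the loss is the transport cost under the plan returned by log-domain Sinkhorn iterations. The function names and the `eps`/`n_iters` defaults are hypothetical. NumPy provides no gradients, so an actual training loss would be written in the training framework.

```python
import numpy as np

# Hypothetical illustration; not the repository's Entropic_OT_Loss.


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sinkhorn_log(cost, eps=0.05, n_iters=200):
    """Entropic OT plan between uniform marginals (log-domain Sinkhorn)."""
    n, m = cost.shape
    log_a = np.full(n, -np.log(n))  # log of uniform row marginal 1/n
    log_b = np.full(m, -np.log(m))  # log of uniform column marginal 1/m
    f, g = np.zeros(n), np.zeros(m)  # dual potentials
    for _ in range(n_iters):
        # Alternate updates: fit row sums to 1/n, then column sums to 1/m.
        f = eps * (log_a - logsumexp((g[None, :] - cost) / eps, axis=1))
        g = eps * (log_b - logsumexp((f[:, None] - cost) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / eps)


def feature_level_ot_loss(audio_emb, text_emb, eps=0.05, n_iters=200):
    """Hypothetical feature-level OT loss (forward pass only, no gradients).

    audio_emb, text_emb: (batch, dim) embeddings of paired audio/text clips.
    Each feature dimension, viewed as a column over the batch, is one point.
    """
    def unit_columns(x):
        x = (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-8)  # standardize dims
        cols = x.T  # (dim, batch)
        return cols / (np.linalg.norm(cols, axis=1, keepdims=True) + 1e-8)

    a, t = unit_columns(audio_emb), unit_columns(text_emb)
    cost = 1.0 - a @ t.T  # 1 - cosine similarity between dimensions, in [0, 2]
    plan = sinkhorn_log(cost, eps=eps, n_iters=n_iters)
    return float(np.sum(plan * cost))  # transport cost <P, C>
```

Because the plan may match any audio dimension with any text dimension, this particular cost is invariant to permuting dimensions; a loss meant to pair dimension *i* specifically with dimension *i* would need a different cost or target plan.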

We use training code from two open-source projects with minor modifications:

- [On Metric Learning for Audio-Text Cross-Modal Retrieval](https://github.com/XinhaoMei/audio-text_retrieval)
- [m-LTM-Audio-Text-Retrieval](https://github.com/v-manhlt3/m-LTM-Audio-Text-Retrieval)

#### Dataset
Please follow the dataset setup instructions provided in the original [m-LTM-Audio-Text-Retrieval](https://github.com/v-manhlt3/m-LTM-Audio-Text-Retrieval) repository.  

#### Training
The training configuration is `settings/m-ltm-settings.yaml`.

Run experiments with: `CUDA_VISIBLE_DEVICES=0 python train.py -n [exp_name] -c m-ltm-settings`

Replace `[exp_name]` with your desired experiment name.

#### Switching to Baseline Training
By default, `train.py` trains our method, FLOAT, through the following import:
```python
from trainer.trainer_minibatch_Semi import train
```
To run a baseline method instead (triplet loss, contrastive loss, etc.), comment out the line above in `train.py` and uncomment the following one:
```python
from trainer.trainer_minibatch_Baselines import train
```
Once switched, you can launch the experiment using the same command:
`CUDA_VISIBLE_DEVICES=0 python train.py -n [exp_name] -c m-ltm-settings`

### `float_distributed/` – FLOAT (Multi GPU / Distributed)
This folder contains the distributed (multi-GPU) implementation of our method **FLOAT**, designed for large-scale training with improved efficiency and scalability.

The overall structure and distributed training strategy are inspired by the [open_clip](https://github.com/mlfoundations/open_clip) framework.

#### Dataset
Download the datasets and place them in `data/`:
```
├── float_distributed/
│   ├── data/
│   │   ├── AudioCaps/
│   │   └── Clotho/
│   └── ...
```
#### Pretrained Models
We use `ResNet38` as the audio encoder and `bert-base-uncased` as the text encoder. Please download the pretrained weights and place them in the following directories:
```
├── float_distributed/
│   ├── pretrained_models/
│   │   ├── audio_encoder/
│   │   │   └── ResNet38.pth
│   │   └── text_encoder/
│   │       ├── bert-base-uncased/
│   │       └── ...
```
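
Before launching, a quick check that the layout above is in place can save a failed run. The sketch below lists which expected paths are missing. The paths come from the two trees above (relative to the distributed-training folder); the `missing_paths` helper and the `EXPECTED` list are hypothetical and not part of the repository.

```python
from pathlib import Path

# Hypothetical helper; paths are relative to the distributed-training folder
# and mirror the directory trees shown above.
EXPECTED = [
    "data/AudioCaps",
    "data/Clotho",
    "pretrained_models/audio_encoder/ResNet38.pth",
    "pretrained_models/text_encoder/bert-base-uncased",
]


def missing_paths(root="."):
    """Return the expected paths that do not exist under `root`."""
    root = Path(root)
    return [p for p in EXPECTED if not (root / p).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:", *missing, sep="\n  ")
    else:
        print("All expected paths are present.")
```
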
#### Launch Training
Use the following script to launch training:

```bash
sh run.sh audiocaps [exp_name]
# or
sh run.sh clotho [exp_name]
```
Replace `[exp_name]` with your desired experiment name.

The FLOAT loss is implemented in the `Entropic_OT_Loss` class in `loss.py`.

## Baselines

### `vast/` – Baseline: VAST

This directory contains files modified from the official implementation of **VAST: A Vision-Audio-Subtitle-Text Omni-Modality Foundation Model and Dataset** ([VAST GitHub Repository](https://github.com/TXH-mercury/VAST/tree/master)).

We adapt this baseline to incorporate our method, **FLOAT**, by modifying its loss function components.

To run FLOAT within the VAST framework, follow the steps below.

#### Integration Instructions

- Copy or replace the files under the `vast/` directory in this repo into the corresponding locations in the official VAST codebase.
- All filenames are aligned with the original structure for seamless integration.

#### Key Modifications
- Our FLOAT loss is implemented in `model/vast.py`:
  - `forward_ref_float()` defines the computation of the FLOAT-based feature alignment.
  - `forward_ref_float2()` contains the full forward pass including the overall loss used for training.
- To enable FLOAT during training, modify the `forward()` function in `model/vast.py`:

Replace the original logic:
```python
if task.startswith('ret'):
    ret_dict = self.forward_ret(batch, task, compute_loss=compute_loss)
```
with either:
```python
ret_dict = self.forward_ref_float(batch, task, compute_loss=compute_loss)
```
or
```python
ret_dict = self.forward_ref_float2(batch, task, compute_loss=compute_loss)
```
In addition, because FLOAT uses the Sinkhorn algorithm during retrieval, we also modified the evaluation function `evaluate_ret()` in `evaluation/evaluation_mm.py` accordingly.
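
The sketch below illustrates one common way Sinkhorn can be used at retrieval time; it is our assumption, not a copy of the modified `evaluate_ret()`. The query-by-item similarity matrix is turned into a cost (one minus similarity), an entropic transport plan with uniform marginals is computed, and items are ranked by plan mass instead of raw similarity. The helper names and the `eps` value are hypothetical, and `recall_at_k` assumes a one-to-one ground truth (query *i* matches item *i*), which simplifies datasets such as Clotho where each audio clip has several captions.

```python
import numpy as np

# Hypothetical OT-based re-ranking; not the repository's evaluate_ret().


def _logsumexp(x, axis):
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sinkhorn_scores(sim, eps=0.05, n_iters=100):
    """Entropic OT plan over a (num_queries, num_items) similarity matrix.

    Cost is 1 - similarity and marginals are uniform, so every item receives
    the same total mass, which damps "hub" items similar to many queries.
    """
    cost = 1.0 - sim
    n, m = cost.shape
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iters):
        f = eps * (-np.log(n) - _logsumexp((g[None, :] - cost) / eps, axis=1))
        g = eps * (-np.log(m) - _logsumexp((f[:, None] - cost) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / eps)


def recall_at_k(scores, k):
    """Share of queries whose ground-truth item (same index) is in the top k."""
    topk = np.argsort(-scores, axis=1)[:, :k]
    hits = (topk == np.arange(scores.shape[0])[:, None]).any(axis=1)
    return float(hits.mean())


# Toy example: item 0 is a "hub" that every query finds similar.
sim = np.array([[0.90, 0.50, 0.10],
                [0.95, 0.80, 0.20],
                [0.90, 0.30, 0.70]])
print("raw R@1:", recall_at_k(sim, 1))
print("OT  R@1:", recall_at_k(sinkhorn_scores(sim), 1))
```

In this toy matrix, raw similarity ranks the hub item first for two of the three queries, while the uniform column marginal forces each item to receive the same total mass, so the plan recovers the diagonal matches.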

#### Dataset
Following [VAST](https://github.com/TXH-mercury/VAST/tree/master?tab=readme-ov-file#download--vast-models--and-captioners-for-labeling-your-own-data), download the downstream dataset annotations for fine-tuning.

Raw data (`srcdata`):
- AudioCaps: download from [audiocaps](https://github.com/cdjkim/audiocaps/tree/master/dataset2.0).
- Clotho: download from [clothov2](https://zenodo.org/records/4783391).

#### Pretrained Model
Download the basic encoders' pretrained checkpoints as described in the "Download basic encoder's pretrained checkpoints" section of the [VAST repository](https://github.com/TXH-mercury/VAST/tree/master).

Download the pretrained VAST model from [VAST](https://github.com/TXH-mercury/VAST/tree/master?tab=readme-ov-file#download--vast-models--and-captioners-for-labeling-your-own-data) for fine-tuning.

#### Launch Training
```bash
cd vast
sh scripts/vast/finetune_ret.sh
```

### `one-peace/` – Baseline: One-PEACE
This directory contains files modified from the official implementation of **One-PEACE: One general representation model across modality and task** ([One-PEACE GitHub Repository](https://github.com/OFA-Sys/ONE-PEACE)).

We extend the One-PEACE framework by integrating our **FLOAT** loss to enhance audio-text retrieval performance.

As with VAST, copy or replace the files under the `one-peace/` directory in this repo into the corresponding locations in the official ONE-PEACE codebase.

#### Download Pretrained Models
Download the pretrained ONE-PEACE models from [ONE-PEACE](https://github.com/OFA-Sys/ONE-PEACE?tab=readme-ov-file#model-card).

#### Download datasets
Download AudioCaps and Clotho following the [ONE-PEACE dataset instructions](https://github.com/OFA-Sys/ONE-PEACE/blob/main/datasets.md#audio).

#### Fine-tuning with FLOAT

To apply FLOAT during fine-tuning, set `config_name=finetune_audiocaps` or `config_name=finetune_clotho` in `one-peace/run_scripts/audio_text_retrieval/finetune.sh`, depending on your target dataset.

Then launch training with:
```bash
cd one_peace/run_scripts/audio_text_retrieval/
sh finetune.sh [exp_name]
```

## Acknowledgments

- [open_clip](https://github.com/mlfoundations/open_clip)
- [m-LTM-Audio-Text-Retrieval](https://github.com/v-manhlt3/m-LTM-Audio-Text-Retrieval)
- [VAST](https://github.com/TXH-mercury/VAST)
- [ONE-PEACE](https://github.com/OFA-Sys/ONE-PEACE)
- [audio-text_retrieval](https://github.com/XinhaoMei/audio-text_retrieval)




