# EQUALS: An Audio-Visual LLM with One-Stage Question-Guided Alignment and Flexible Fusion

![EQUALS framework overview](assets/framework.png)

## Model Weights

Please make sure the following pretrained models are available at these local paths:

1. **Whisper**: `./models/Whisper`
2. **BEATs**: `./models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt`
3. **Vicuna**: `./models/vicuna-7b-v1.5`
4. **CLIP-L**: `./models/clip-vit-large-patch14`
5. **InternVideo2**: `./models/InternVideo2-Stage2_1B-224p-f4/InternVideo2-stage2_1b-224p-f4.pt`
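You can confirm the weights are in place before launching a run with a small check like the one below. This is a minimal sketch and is not part of the repository. `REQUIRED_WEIGHTS` and `missing_weights` are hypothetical names, and the helper only checks that each path exists. It does not verify that the files are complete.

```python
from pathlib import Path

# Hypothetical helper: paths copied from the list above, relative to the repo root.
REQUIRED_WEIGHTS = {
    "Whisper": "./models/Whisper",
    "BEATs": "./models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
    "Vicuna": "./models/vicuna-7b-v1.5",
    "CLIP-L": "./models/clip-vit-large-patch14",
    "InternVideo2": "./models/InternVideo2-Stage2_1B-224p-f4/InternVideo2-stage2_1b-224p-f4.pt",
}


def missing_weights(paths=REQUIRED_WEIGHTS, root="."):
    """Return the names of weights whose path does not exist under `root`."""
    root = Path(root)
    return [name for name, rel in paths.items() if not (root / rel).exists()]
```

Call `missing_weights()` from the repository root. An empty list means every path was found.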

## Environment Check


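Activate your conda environment (replace `${env_name}` with its name) and run the test script:

```bash
conda activate ${env_name}
bash test.sh
```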


## Training and Evaluation

### Command Line Example


```bash
deepspeed --num_gpus=2 train.py \
    --audio_encoder whisper_beats \
    --user_id name-run \
    --job_id v12345 \
    --options optimizer.type=AdamW \
        gradient_clipping=3.0 \
        gradient_accumulation_steps=4 \
        run.max_epoch=1 \
        run.updates_per_valid=250 \
        run.optims.warmup_step=600 \
        run.optims.warmup_start_lr=1.0e-6 \
        run.optims.init_lr=3.0e-5 \
        run.optims.min_lr=3.0e-6 \
        run.optims.weight_decay=0.01 \
        llm.use_flash_attention=true
```


### Arguments

- `audio_encoder`: Type of audio encoder. Default: `"whisper_beats"`.
- `peft_ckpt`: Optional. Path to a PEFT-trained checkpoint directory.
- `eval_only`: Optional flag. If set, the script runs evaluation only, using the annotations at `datasets.test_ann_path` in the config file.
- `user_id`: User-defined identifier.
- `job_id`: Job identifier.
- `options`: Configuration overrides given as space-separated `key=value` pairs, e.g. `run.max_epoch=1` or `llm.use_flash_attention=true`.
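The sketch below shows one way to parse these arguments and the dotted `--options` overrides. It is a hypothetical reconstruction for illustration and is not the actual `train.py` parser. `build_parser` and `options_to_dict` are invented names, and override values stay as strings. `--local_rank` is included only because the `deepspeed` launcher passes it to the training script.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented train.py arguments."""
    parser = argparse.ArgumentParser(description="EQUALS training (sketch)")
    parser.add_argument("--audio_encoder", default="whisper_beats")
    parser.add_argument("--peft_ckpt", default=None)
    parser.add_argument("--eval_only", action="store_true")
    parser.add_argument("--user_id")
    parser.add_argument("--job_id")
    # Every token after --options is collected as a list of key=value strings.
    parser.add_argument("--options", nargs="+", default=[])
    # Injected by the deepspeed launcher.
    parser.add_argument("--local_rank", type=int, default=-1)
    return parser


def options_to_dict(options):
    """Turn ["run.max_epoch=1", ...] into a nested dict with string values."""
    cfg = {}
    for item in options:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = cfg
        for part in parents:
            node = node.setdefault(part, {})  # walk or create nested sections
        node[leaf] = value
    return cfg
```

Because dotted keys map onto nested sections, `run.optims.init_lr=3.0e-5` updates `cfg["run"]["optims"]["init_lr"]`. A real implementation would also need to cast the strings to the types declared in the config.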

## `config.yaml` Parameters

### LLM & PEFT

- `llm`: Parameters related to the language model.
- `peft`: Parameter-efficient fine-tuning configuration.

### Pooling

- `tokenizer_name_or_path`: Tokenizer used to compute text embeddings.
- `model_name_or_path`: Text encoder model used for alignment.

#### Audio Alignment Pooler

- `audio_projection_dim`: Projection dimension for computing audio-text similarity.
- `llama_embeds_dim`: Output feature dimension, i.e. the embedding size expected by the LLM.
- `kernel`: Pooling kernel size.
- `stride`: Pooling stride.
- `global_act`: Activation applied to the similarity matrix. Options: `softmax`, `sigmoid`, `no_operation`.
- `pooling_temperature`: Temperature used to smooth the pooling weights.
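The following simplified NumPy sketch shows how these parameters could fit together in question-guided pooling. It is an illustrative assumption and is not the EQUALS implementation. The steps are as follows:

- Tokens and question embeddings are projected into a shared `audio_projection_dim` space.
- Each token is scored by its mean dot-product similarity to the question tokens.
- The scores are divided by `pooling_temperature` and passed through `global_act`.
- Each `kernel`-sized window, moved by `stride`, is averaged with those weights and then projected to `llama_embeds_dim`.

The function names are hypothetical, and random matrices stand in for learned weights.

```python
import numpy as np


def global_activation(scores, act, temperature):
    """Apply the `global_act` option to temperature-scaled similarity scores."""
    scaled = scores / temperature
    if act == "softmax":
        e = np.exp(scaled - scaled.max())  # subtract max for numerical stability
        return e / e.sum()
    if act == "sigmoid":
        return 1.0 / (1.0 + np.exp(-scaled))
    if act == "no_operation":
        return scaled
    raise ValueError(f"unknown global_act: {act}")


def question_guided_pool(tokens, question, w_tok, w_txt, w_out,
                         kernel, stride, act="softmax", temperature=1.0):
    """tokens: (T, D_in), question: (L, D_txt) -> (num_windows, llama_embeds_dim)."""
    if len(tokens) < kernel:
        raise ValueError("need at least `kernel` tokens")
    a = tokens @ w_tok                  # (T, P): project tokens to the shared space
    q = question @ w_txt                # (L, P): project question tokens
    scores = (a @ q.T).mean(axis=1)     # (T,): relevance of each token to the question
    weights = global_activation(scores, act, temperature)
    pooled = []
    for start in range(0, len(tokens) - kernel + 1, stride):
        window = tokens[start:start + kernel]      # (kernel, D_in)
        w = weights[start:start + kernel, None]    # (kernel, 1)
        pooled.append((w * window).mean(axis=0))   # question-weighted window average
    return np.stack(pooled) @ w_out                # project to the LLM width


rng = np.random.default_rng(0)
out = question_guided_pool(
    rng.normal(size=(8, 32)), rng.normal(size=(5, 24)),
    rng.normal(size=(32, 16)), rng.normal(size=(24, 16)), rng.normal(size=(32, 64)),
    kernel=2, stride=2,
)
print(out.shape)  # (4, 64)
```

With `kernel=2` and `stride=2`, eight input tokens become four output tokens. This is how pooling shortens the sequence passed to the LLM.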

#### Visual Alignment Pooler

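- Same parameters as the audio alignment pooler (config key `pooling.visual_align_pooler`), with `visual_projection_dim` in place of `audio_projection_dim`:
  - `visual_projection_dim`
  - `llama_embeds_dim`
  - `kernel`, `stride`
  - `global_act`, `pooling_temperature`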

### Optimal Transport (OT) Loss

#### Audio-Video OT

- `OT_AV.coeff_after`: Weight of the audio-video OT loss computed after pooling.

#### Audio-Text OT

- `OT_AT.coeff`: Loss weight.
- `OT_AT.use_text_mask`: Whether to mask padding tokens.

#### Video-Text OT

- `OT_VT.coeff`: Loss weight.
- `OT_VT.use_text_mask`: Whether to mask padding tokens.
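The sketch below computes an entropy-regularized OT distance with Sinkhorn iterations. It uses cosine distance as the transport cost. When a mask is supplied, padded text tokens get zero mass, which is the role `use_text_mask` suggests. It is not confirmed that EQUALS uses Sinkhorn, this cost, or uniform marginals. `sinkhorn_ot_loss`, `eps`, and `n_iters` are hypothetical. In this reading, the returned value would then be multiplied by the matching `coeff`.

```python
import numpy as np


def sinkhorn_ot_loss(x, y, y_mask=None, eps=0.1, n_iters=200):
    """Entropic OT distance between feature sets x (N, D) and y (M, D).

    y_mask (M,) bool marks real (non-padding) tokens of y; masked tokens get zero mass.
    Returns (loss, transport_plan).
    """
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    y = y / np.linalg.norm(y, axis=1, keepdims=True)
    cost = 1.0 - x @ y.T                      # cosine distance, shape (N, M)
    a = np.full(len(x), 1.0 / len(x))         # uniform mass on x
    if y_mask is None:
        y_mask = np.ones(len(y), dtype=bool)
    b = y_mask.astype(float) / y_mask.sum()   # uniform mass on unmasked y tokens only
    K = np.exp(-cost / eps)                   # Gibbs kernel
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iters):                  # alternate row and column scaling
        u = a / (K @ v + 1e-12)
        v = b / (K.T @ u + 1e-12)
    plan = u[:, None] * K * v[None, :]        # transport plan with marginals ~(a, b)
    return float((plan * cost).sum()), plan
```

Aligned feature sets give a small loss, while dissimilar sets force mass across high-cost pairs and give a larger one. This is the signal the coefficients above scale into the training objective.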

### Encoders

- `audio_encoders`: Configuration for audio backbone encoders.
- `video_encoders`: Configuration for video backbone encoders.

### Connectors

- `connectors`: Feature projection modules between modalities and the language model.

### Datasets

- `datasets`: Data paths, preprocessing, annotations, and batch sampling strategy.

### Runner

- `run.log_main_loss`: Whether to log losses only on the main process (recommended `True` for speed).
- `run.output_dir`: Output directory for checkpoints and logs.

### Module Freezing

- `freeze_modules`: List of modules to freeze during training. Example: `["visual_pooler", "audio_pooler"]`
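A common way to apply such a list is to match parameter names by module prefix. The helper below is a hypothetical sketch. It assumes each entry in `freeze_modules` is the attribute name of a submodule, so that submodule's parameters are named `<module>.<...>`. The repository may resolve names differently.

```python
def is_frozen(param_name, freeze_modules):
    """True if the parameter belongs to one of the listed submodules."""
    return any(param_name == m or param_name.startswith(m + ".") for m in freeze_modules)


def split_params(param_names, freeze_modules):
    """Partition parameter names into (trainable, frozen) lists."""
    trainable, frozen = [], []
    for name in param_names:
        (frozen if is_frozen(name, freeze_modules) else trainable).append(name)
    return trainable, frozen
```

With PyTorch, the flag could be applied as `for name, p in model.named_parameters(): p.requires_grad = not is_frozen(name, freeze_modules)`. Matching on `m + "."` prevents `visual_pooler` from also freezing a differently named module such as `visual_pooler_extra`.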

## `ds_config.yaml` Parameters

- `gradient_accumulation_steps`: Number of micro-batches whose gradients are accumulated before each optimizer step.
- `gradient_clipping`: Maximum gradient norm.
- `train_micro_batch_size_per_gpu`: Micro-batch size per GPU. Recommended: `2`.
- `optimizer`: Optimizer configuration. See the [DeepSpeed docs](https://deepspeed.readthedocs.io/en/latest/optimizers.html).
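Together, these settings determine the global batch size. DeepSpeed's `train_batch_size` equals `train_micro_batch_size_per_gpu` times `gradient_accumulation_steps` times the number of GPUs. Take the recommended micro-batch of 2, `gradient_accumulation_steps=4` from the example commands, and `--num_gpus=2`. Each optimizer step then sees 16 samples. The helper names below are hypothetical. The steps-per-epoch estimate assumes that a final partial batch still triggers an optimizer step.

```python
import math


def effective_batch_size(micro_batch_per_gpu, grad_accum_steps, num_gpus):
    """Samples consumed per optimizer step (DeepSpeed's train_batch_size)."""
    return micro_batch_per_gpu * grad_accum_steps * num_gpus


def optimizer_steps_per_epoch(num_samples, batch_size):
    """Optimizer steps in one epoch, counting a final partial batch as a step."""
    return math.ceil(num_samples / batch_size)


batch = effective_batch_size(2, 4, 2)  # recommended micro-batch, example accumulation and GPU count
print(batch)  # 16
```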

## Example Commands

### Training

```bash
deepspeed --num_gpus=2 train.py \
    --audio_encoder whisper_beats \
    --user_id htw-train \
    --job_id v0 \
    --options optimizer.type=AdamW \
        gradient_clipping=3.0 \
        gradient_accumulation_steps=4 \
        run.max_epoch=1 \
        run.updates_per_valid=1 \
        run.optims.warmup_step=600 \
        run.optims.warmup_start_lr=1.0e-6 \
        run.optims.init_lr=3.0e-5 \
        run.optims.min_lr=3.0e-6 \
        run.optims.weight_decay=0.01 \
        llm.use_flash_attention=true
```


### Evaluation


```bash
deepspeed --num_gpus=2 train.py \
    --audio_encoder whisper_beats \
    --user_id htw-eval \
    --job_id v0 \
    --eval_only \
    --peft_ckpt ../checkpoints/global_step_1 \
    --options llm.use_flash_attention=true
```


## Heatmap Visualization


`--save_ori_img` and `--save_heat_map` are optional flags. Include them only if you want those outputs.

```bash
python attn_map.py \
    --json_path ./MUSIC-AVQA/annotations/avllm_json/avqa-test_real.json \
    --peft_ckpt path/to/peft/ckpt \
    --save_dir path/to/save/directory \
    --samples 100 \
    --alpha 0.5 \
    --save_ori_img \
    --save_heat_map \
    --options pooling.visual_align_pooler.visual_projection_dim=1408
```


### Visualization Parameters

- `json_path`: Path to AVQA annotation JSON file.
- `peft_ckpt`: Directory containing the trained weights. It must include `visual_pooler.pt`.
- `save_dir`: Directory for saving visual outputs:
  - `heatmap/`: Raw heatmaps (if `--save_heat_map` specified).
  - `blend/`: Blended overlay images.
  - `img/`: Original images (if `--save_ori_img` specified).
- `samples`: Number of examples to visualize.
- `alpha`: Heatmap opacity in blended images (range: 0–1).
- `options`: Configuration overrides. They must match the model configuration used during training (e.g. `pooling.visual_align_pooler.visual_projection_dim`).
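The NumPy sketch below illustrates what `alpha` controls. It is an assumption about the general approach and is not the code in `attn_map.py`. `blend_heatmap` is a hypothetical name. The sketch works as follows:

- The attention map is min-max normalized.
- It is upsampled by nearest-neighbour repetition.
- It is drawn into the red channel.
- It is mixed with the frame as `alpha * heatmap + (1 - alpha) * image`.

```python
import numpy as np


def blend_heatmap(image, attn, alpha=0.5):
    """image: (H, W, 3) floats in [0, 1]; attn: (h, w) patch scores with H % h == W % w == 0."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    H, W, _ = image.shape
    h, w = attn.shape
    if H % h or W % w:
        raise ValueError("image size must be a multiple of the attention grid")
    span = attn.max() - attn.min()
    norm = (attn - attn.min()) / span if span > 0 else np.zeros_like(attn, dtype=float)
    up = np.kron(norm, np.ones((H // h, W // w)))  # nearest-neighbour upsampling to (H, W)
    heat = np.zeros_like(image, dtype=float)
    heat[..., 0] = up                               # show attention in the red channel
    return alpha * heat + (1.0 - alpha) * image    # alpha = heatmap opacity
```

`alpha=0` returns the original frame unchanged. `alpha=1` shows only the heatmap. The default of `0.5` weights the two equally.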

## Alignment Pooling

![Question-guided alignment and pooling](assets/question_guided_alingment_and_pooling.png)