# GRAM-M: General-Purpose Audio Representation Model-Mamba

This repository contains the implementation of GRAM-M, a general-purpose audio representation model based on Mamba and trained on AudioSet with naturalistic scenes. It also includes listen-eval-kit, a fork of hear-eval-kit extended to support sound localization tasks. The framework uses PyTorch Lightning for training, Hydra for configuration and hyperparameter overrides, and TensorBoard for logging.

## System Requirements

This repository has been validated with:
- Python 3.10
- PyTorch 2.1.2 (CUDA 11.8)

## Installation

### Training Environment

```bash
# Create and activate conda environment
conda create -n gram-m python=3.10 -y
conda activate gram-m
conda install nvidia/label/cuda-11.8.0::cuda-toolkit

# Install PyTorch with CUDA support
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118

# Install GRAM-M specific dependencies
pip install -r requirements.txt

# Install additional required packages
pip install git+https://github.com/Dao-AILab/causal-conv1d.git@v1.1.3.post1
pip install git+https://github.com/state-spaces/mamba.git@v1.1.3.post1
```

### Evaluation Environment

```bash
# Create and activate conda environment
conda create -n gram-m-eval python=3.10 -y
conda activate gram-m-eval
conda install nvidia/label/cuda-11.8.0::cuda-toolkit

# Install PyTorch with CUDA support
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118

# Install evaluation-specific dependencies
pip install -r requirements_eval.txt

# Install additional required packages
pip install git+https://github.com/Dao-AILab/causal-conv1d.git@v1.1.3.post1
pip install git+https://github.com/state-spaces/mamba.git@v1.1.3.post1
```
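After setting up either environment, a quick sanity check can confirm that the key packages are importable before starting a long run. The script below is a minimal sketch, not part of the repository: the helper names `find_missing` and `report` are hypothetical, and the module list assumes that the two Git installs expose the import names `causal_conv1d` and `mamba_ssm`.

```python
import importlib.util
import sys

# Top-level import names for the packages installed above.
# "causal_conv1d" and "mamba_ssm" are assumed import names for the two Git installs.
REQUIRED_MODULES = ["torch", "torchaudio", "torchvision", "causal_conv1d", "mamba_ssm"]


def find_missing(modules):
    """Return the module names that cannot be located in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


def report():
    """Print the Python version, any missing modules, and basic CUDA information."""
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}")
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
        return
    import torch  # imported lazily so the script still runs without PyTorch

    print("PyTorch", torch.__version__, "| CUDA build:", torch.version.cuda)
    print("CUDA available:", torch.cuda.is_available())


if __name__ == "__main__":
    report()
```

With the versions listed under System Requirements, the script should report Python 3.10 and a PyTorch 2.1.2 build for CUDA 11.8.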

## Model Training

### GRAM-M-Time Model Training

```bash
python3 train.py data=audioset data.sr=32000 patching=time data.mask_patch=80 trainer.batch_size=32 trainer.steps=200000
```

### GRAM-M-Patch Model Training

```bash
python3 train.py data=audioset data.sr=32000 patching=frame data.mask_patch=100 trainer.batch_size=32 trainer.steps=200000
```

**Note:** Training requires approximately 72 hours on an A100 GPU. Models and TensorBoard logs are saved to the directory specified in `cfg.save_dir`.
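The training commands above use Hydra's `key=value` override syntax, where dotted keys such as `trainer.steps` address nested configuration fields. The sketch below only illustrates how those arguments group together. Hydra performs the real parsing, and the helpers `split_overrides` and `samples_seen` are hypothetical. The sample count assumes one batch per step on a single device with no gradient accumulation, which the source does not state.

```python
def split_overrides(args):
    """Separate undotted keys from dotted (nested) value overrides.

    Undotted keys (e.g. data=audioset) are either config-group choices or
    top-level values, depending on the repository's config layout.
    """
    groups, values = {}, {}
    for arg in args:
        key, _, raw = arg.partition("=")
        if "." in key:
            section, _, field = key.partition(".")
            values.setdefault(section, {})[field] = raw
        else:
            groups[key] = raw
    return groups, values


def samples_seen(values):
    """Estimate training samples as batch_size * steps (assumes one batch per step)."""
    trainer = values["trainer"]
    return int(trainer["batch_size"]) * int(trainer["steps"])


if __name__ == "__main__":
    cli = [
        "data=audioset", "data.sr=32000", "patching=time",
        "data.mask_patch=80", "trainer.batch_size=32", "trainer.steps=200000",
    ]
    groups, values = split_overrides(cli)
    print("undotted keys:", groups)
    print("nested overrides:", values)
    print("samples seen (estimate):", samples_seen(values))
```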

## Downstream Evaluation

### HEAR Benchmark Evaluation

#### Prerequisites

1. **Dataset Preparation:**
   - Follow the instructions at https://hearbenchmark.com/hear-tasks.html to acquire the data
   - For convenience, download pre-processed 32000 Hz data directly from HEAR's GCS bucket at `gs://hear2021-archive/tasks/`
   - Extract all files to a designated directory (`$TASKS_DIR`)

2. **Pre-trained Model Weights:**
   - Download pre-trained weights from [repository link]
   - Set the path to the weights directory as `$MODEL_DIR`

#### Feature Extraction and Evaluation

```bash
cd listen-eval-kit

# Define environment variables
embeddings_dir=/path/to/save/embeddings
tasks_dir=$TASKS_DIR
task_name=dcase2016_task2-hear2021-full

# Set model parameters
weights=$MODEL_DIR
model_name=hear_configs.GRAMM  # use GRAMMTime for the time-based model
strategy=raw
use_mwmae_decoder=true
in_channels=2
model_options="{\"strategy\": \"$strategy\",\"use_mwmae_decoder\": \"$use_mwmae_decoder\", \"in_channels\": $in_channels}"

# Extract features
python3 -m heareval.embeddings.runner "$model_name" --tasks-dir "$tasks_dir" --task "$task_name" --embeddings-dir "$embeddings_dir" --model-options "$model_options" --model "$weights"

# Execute task evaluation
python3 -m heareval.predictions.runner "$embeddings_dir/$model_name-strategy=$strategy-use-mwmae-decoder=$use_mwmae_decoder-in-channels=$in_channels/$task_name"
```
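The escaped JSON string and the embeddings path in the commands above are easy to get wrong by hand. The sketch below rebuilds both in Python so they can be checked before launching a run. The helper names `build_model_options` and `embeddings_task_dir` are hypothetical, and the directory pattern is inferred only from the evaluation command shown above, not from the kit's source.

```python
import json
from pathlib import PurePosixPath


def build_model_options(strategy, use_mwmae_decoder, in_channels):
    """Return the options dict and its JSON form, matching the shell snippet above.

    use_mwmae_decoder stays a string (e.g. "true") because the shell version quotes it.
    """
    options = {
        "strategy": strategy,
        "use_mwmae_decoder": use_mwmae_decoder,
        "in_channels": in_channels,
    }
    return options, json.dumps(options)


def embeddings_task_dir(embeddings_dir, model_name, options, task_name):
    """Build <embeddings_dir>/<model_name>-<key>=<value>-.../<task_name>.

    Underscores in option keys become hyphens, as in the command above
    (pattern inferred from that command; an assumption about the kit).
    """
    suffix = "-".join(f"{key.replace('_', '-')}={value}" for key, value in options.items())
    return PurePosixPath(embeddings_dir) / f"{model_name}-{suffix}" / task_name


if __name__ == "__main__":
    options, options_json = build_model_options("raw", "true", 2)
    print(options_json)
    print(embeddings_task_dir(
        "/path/to/save/embeddings", "hear_configs.GRAMM", options,
        "dcase2016_task2-hear2021-full",
    ))
```

The printed JSON can be passed to `--model-options`, and the printed path is the argument expected by `heareval.predictions.runner` under the assumptions above.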

### Nat-HEAR Benchmark Evaluation

#### Prerequisites

1. **Dataset Preparation:**
   - Access the Nat-HEAR dataset from [repository link]
   - Extract all files to a designated directory (`$NATURALISTIC_TASKS_DIR`)

2. **Pre-trained Model Weights:**
   - Download pre-trained weights from [repository link]
   - Set the path to the weights directory as `$MODEL_DIR`

#### Feature Extraction and Evaluation

```bash
cd listen-eval-kit

# Define environment variables
embeddings_dir=/path/to/save/embeddings
tasks_dir=$NATURALISTIC_TASKS_DIR
task_name=dcase2016_task2-hear2021-full

# Set model parameters
weights=$MODEL_DIR
model_name=hear_configs.GRAMM  # use GRAMMTime for the time-based model
strategy=raw
use_mwmae_decoder=true
in_channels=2
model_options="{\"strategy\": \"$strategy\",\"use_mwmae_decoder\": \"$use_mwmae_decoder\", \"in_channels\": $in_channels}"

# Extract features
python3 -m heareval.embeddings.runner "$model_name" --tasks-dir "$tasks_dir" --task "$task_name" --embeddings-dir "$embeddings_dir" --model-options "$model_options" --model "$weights"

# Execute task evaluation
python3 -m heareval.predictions.runner "$embeddings_dir/$model_name-strategy=$strategy-use-mwmae-decoder=$use_mwmae_decoder-in-channels=$in_channels/$task_name"
```

### Sound Localization on Nat-HEAR Benchmark

#### Prerequisites

1. **Dataset Preparation:**
   - Access the Nat-HEAR-localization dataset from [repository link]
   - Extract all files to a designated directory (`$LOCALIZATION_TASKS_DIR`)

2. **Pre-trained Model Weights:**
   - Download pre-trained weights from [repository link]
   - Set the path to the weights directory as `$MODEL_DIR`

#### Feature Extraction and Evaluation

```bash
cd listen-eval-kit

# Define environment variables
embeddings_dir=/path/to/save/embeddings
tasks_dir=$LOCALIZATION_TASKS_DIR
task_name=dcase2016_task2-hear2021-full

# Set model parameters
weights=$MODEL_DIR
model_name=hear_configs.GRAMM  # use GRAMMTime for the time-based model
strategy=mean
use_mwmae_decoder=true
in_channels=2
model_options="{\"strategy\": \"$strategy\",\"use_mwmae_decoder\": \"$use_mwmae_decoder\", \"in_channels\": $in_channels}"

# Extract features
python3 -m heareval.embeddings.runner "$model_name" --tasks-dir "$tasks_dir" --task "$task_name" --embeddings-dir "$embeddings_dir" --model-options "$model_options" --model "$weights"

# Execute localization evaluation
python3 -m heareval.predictions.runner "$embeddings_dir/$model_name-strategy=$strategy-use-mwmae-decoder=$use_mwmae_decoder-in-channels=$in_channels/$task_name" --localization cartesian-regression
```
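The `cartesian-regression` option name suggests that source directions are regressed as Cartesian coordinates rather than as angles. The sketch below illustrates that general idea only and is not taken from listen-eval-kit: the azimuth/elevation convention, the helper names `to_cartesian` and `angular_error_deg`, and the use of angular error as a metric are all assumptions.

```python
import math


def to_cartesian(azimuth_deg, elevation_deg):
    """Map a direction in degrees to a unit vector (x, y, z).

    Assumed convention: azimuth is measured in the x-y plane from the x axis,
    elevation is measured upward from that plane.
    """
    az = math.radians(azimuth_deg)
    el = math.radians(elevation_deg)
    return (math.cos(el) * math.cos(az), math.cos(el) * math.sin(az), math.sin(el))


def angular_error_deg(pred, target):
    """Return the angle in degrees between a predicted and a target vector."""
    dot = sum(p * t for p, t in zip(pred, target))
    norm = math.sqrt(sum(p * p for p in pred)) * math.sqrt(sum(t * t for t in target))
    if norm == 0.0:
        raise ValueError("zero-length vector has no direction")
    # Clamp to [-1, 1] to guard against floating-point drift before acos.
    cos_angle = max(-1.0, min(1.0, dot / norm))
    return math.degrees(math.acos(cos_angle))


if __name__ == "__main__":
    target = to_cartesian(30.0, 10.0)
    prediction = (0.8, 0.5, 0.2)  # hypothetical, unnormalized regression output
    print(f"angular error: {angular_error_deg(prediction, target):.2f} degrees")
```

Regressing Cartesian coordinates avoids the wrap-around discontinuity that azimuth has at 0/360 degrees, which is one common reason to prefer this target representation.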