# seq-jepa

Autoregressive Predictive Learning of Invariant-Equivariant World Models

## Supplemental Data and Checkpoints

You can download the pre-computed STL-10 saliency maps and two model checkpoints from the anonymous Zenodo record below:

- seq-JEPA trained on 3DIEBench, conditioned on rotation, with a training sequence length of 3
- seq-JEPA trained via predictive learning across saccades on STL-10, with a training sequence length of 4

https://zenodo.org/records/15492927?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6ImEyMDlkZjkxLWNhZjItNGI4MC05ZWIzLWY3MzE3YmE3ZWVjNiIsImRhdGEiOnt9LCJyYW5kb20iOiJkMTI2NjE2MDg0NGUzY2M5OTM4NzM0YWE5ODQ1OTJkNCJ9.g5dg8j38x9vefp-0r7xG3fAAWTkn0373p9BXRoXG33Ybnx-FHAvkqkac9aepu-PT0v4CWlp5p1QZx6u0a0ZemQ

## Installation

```bash
pip install -r requirements.txt
```

## Usage

### Training seq-JEPA on 3DIEBench

Train seq-JEPA on 3DIEBench with a training sequence length of 3, conditioned on rotation.

**Prerequisites:** Download 3DIEBench from https://dl.fbaipublicfiles.com/SIE/3DIEBench.tar.gz and extract it. Pass the extracted folder to `--data-root`.
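The download and extraction step can be scripted. The sketch below is an assumption-laden helper, not part of this repository: the function names `fetch_3diebench` and `check_data_root` are hypothetical, it assumes `wget` and `tar` are installed, and it does not know which folder name the archive extracts to, so list the destination afterwards and pass the dataset folder to `--data-root`.

```bash
#!/usr/bin/env bash
# Hypothetical helpers for fetching 3DIEBench (not part of this repository).
# Assumptions: wget and tar are available; the name of the folder the archive
# extracts to is not documented here, so inspect DEST_DIR after extraction.

IEBENCH_URL="https://dl.fbaipublicfiles.com/SIE/3DIEBench.tar.gz"

fetch_3diebench() {
  # $1: directory to download into and extract under
  if [ $# -ne 1 ]; then
    echo "usage: fetch_3diebench DEST_DIR" >&2
    return 2
  fi
  mkdir -p "$1"
  # -c resumes a partial download; -P saves the file inside DEST_DIR
  wget -c -P "$1" "$IEBENCH_URL"
  # tar detects the compression format automatically when extracting
  tar -xf "$1/3DIEBench.tar.gz" -C "$1"
}

check_data_root() {
  # Cheap sanity check before training: succeeds only if $1 is a non-empty directory
  [ -d "$1" ] && [ -n "$(ls -A "$1")" ]
}
```

For example, `fetch_3diebench ./data` followed by `check_data_root ./data/<extracted folder>` before launching training.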

**Training command:**
```bash
python src/main_aug.py --data-root "3DIEBench_DATA_ROOT_FOLDER" --wandb --seed 42 \
  --seq-len 3 --model "seqjepa" --run-id "seqjepa-3db-seq3tr" \
  --latent-type "rot" --eval-type "rot" --num-workers 8 \
  --ema --ema-decay 0.996 --num-heads 4 --num-enc-layers 3 --pred-hidden 1024 \
  --act-projdim 128 --act-cond 1 --learn-act-emb 1 \
  --backbone "resnet18" --dataset "3diebench" --img-size 128 \
  --epochs 1000 --save-freq 50 --batch-size 512 --lr 0.0004 --weight-decay 0.001 \
  --optimizer "AdamW" --warmup 20 --scheduler
```
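The 3DIEBench training and evaluation commands share most of their flags (sequence length, encoder and predictor sizes, action conditioning, backbone, image size), and a checkpoint generally has to be loaded with the same model configuration it was trained with. One way to keep the two runs consistent is to hold the shared flags in a single bash array. This is a hedged sketch: every flag value is copied from the commands in this README, but the names `DATA_ROOT`, `COMMON_ARGS`, `run`, `train_3diebench`, and `eval_3diebench` are hypothetical, and `DRY_RUN=1` prints the command instead of running it.

```bash
#!/usr/bin/env bash
# Hypothetical wrapper: shared flags live in one array so training and
# evaluation cannot drift apart. Set DATA_ROOT before sourcing this file,
# because the array captures its value at definition time.

DATA_ROOT="${DATA_ROOT:-3DIEBench_DATA_ROOT_FOLDER}"

COMMON_ARGS=(
  --data-root "$DATA_ROOT" --wandb --seed 42 --seq-len 3 --model "seqjepa"
  --latent-type "rot" --num-workers 8 --ema --ema-decay 0.996
  --num-heads 4 --num-enc-layers 3 --pred-hidden 1024 --act-projdim 128
  --act-cond 1 --learn-act-emb 1 --backbone "resnet18" --dataset "3diebench"
  --img-size 128 --save-freq 50
)

run() {
  # Print the command when DRY_RUN is non-empty, otherwise execute it
  if [ -n "${DRY_RUN:-}" ]; then
    echo "$*"
  else
    "$@"
  fi
}

train_3diebench() {
  run python src/main_aug.py "${COMMON_ARGS[@]}" --run-id "seqjepa-3db-seq3tr" \
    --eval-type "rot" --epochs 1000 --batch-size 512 --lr 0.0004 \
    --weight-decay 0.001 --optimizer "AdamW" --warmup 20 --scheduler
}

eval_3diebench() {
  # $1: path to a checkpoint produced by train_3diebench
  run python src/main_aug.py "${COMMON_ARGS[@]}" --is-eval --load-path "$1" \
    --run-id "eval-seqjepa-3db-seq3tr" --epochs 300 --batch-size 256 \
    --lr 0.001 --weight-decay 0.0 --optimizer "Adam"
}
```

Changing, say, `--seq-len` then only requires editing `COMMON_ARGS` (and the run IDs, if you want them to reflect it).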

**Evaluation command:**
```bash
# Replace PATH_TO_CHECKPOINT with a trained checkpoint, e.g. seqjepa-3diebench-ckpt_epoch_1000.pth
python src/main_aug.py --data-root "3DIEBench_DATA_ROOT_FOLDER" --is-eval \
  --load-path "PATH_TO_CHECKPOINT" --wandb --seed 42 \
  --seq-len 3 --model "seqjepa" --run-id "eval-seqjepa-3db-seq3tr" \
  --latent-type "rot" --num-workers 8 \
  --ema --ema-decay 0.996 --num-heads 4 --num-enc-layers 3 --pred-hidden 1024 \
  --act-projdim 128 --act-cond 1 --learn-act-emb 1 \
  --backbone "resnet18" --dataset "3diebench" --img-size 128 \
  --epochs 300 --save-freq 50 --batch-size 256 --lr 0.001 --weight-decay 0.0 \
  --optimizer "Adam"
```
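With `--save-freq 50`, a 1000-epoch run leaves many checkpoints, and a plain `ls | tail -1` picks the wrong one because names sort lexicographically (`epoch_950` sorts after `epoch_1000`). The sketch below selects the highest epoch numerically. It assumes checkpoint files follow the `<name>_epoch_<N>.pth` pattern seen in the example checkpoint names in this README; the directory they are saved to is not documented here, so it is passed in, and `latest_ckpt` is a hypothetical helper name.

```bash
#!/usr/bin/env bash
# Hypothetical helper: print the checkpoint with the highest epoch number in a
# directory. Assumes names like "seqjepa-3diebench-ckpt_epoch_1000.pth".

latest_ckpt() {
  local dir="$1" best="" best_epoch=-1 f epoch
  for f in "$dir"/*_epoch_*.pth; do
    [ -e "$f" ] || continue              # the glob matched nothing
    epoch="${f##*_epoch_}"               # keep the text after the last "_epoch_"
    epoch="${epoch%.pth}"                # drop the extension
    [[ "$epoch" =~ ^[0-9]+$ ]] || continue
    # 10# forces base-10 so leading zeros are not read as octal
    if (( 10#$epoch > best_epoch )); then
      best_epoch=$((10#$epoch))
      best="$f"
    fi
  done
  [ -n "$best" ] || return 1             # no matching checkpoint found
  printf '%s\n' "$best"
}
```

For example, pass `--load-path "$(latest_ckpt CHECKPOINT_DIR)"` to the evaluation command, where `CHECKPOINT_DIR` is wherever your training run saved its checkpoints.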

### Training seq-JEPA on STL-10

Train seq-JEPA on STL-10 via predictive learning across saccades, with a training sequence length of 4, conditioned on fixation position.

**Prerequisites:** Download the original STL-10 dataset, and download and unzip the pre-computed STL-10 saliency maps from the Zenodo record above.
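Both STL-10 inputs can be prepared with a short script. This is a sketch under several assumptions that should be checked against `src/main_pls.py`: the saliency maps arrive as a `.zip` archive (the commands refer to "unzipped" maps) whose exact file name is not given here, so its path is passed in; and the STL-10 loader accepts the official binary release extracted in the layout torchvision's `STL10` dataset uses (`<root>/stl10_binary/`). If the code downloads STL-10 itself, skip `fetch_stl10`. Both function names are hypothetical.

```bash
#!/usr/bin/env bash
# Hypothetical helpers for the STL-10 prerequisites (not part of this repository).
# Assumptions: unzip, wget and tar are installed; the loader reads the official
# STL-10 binary release from <root>/stl10_binary/ (torchvision's layout).

STL10_URL="http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"

unpack_saliency() {
  # $1: saliency-map archive downloaded from Zenodo
  # $2: output folder, later passed as --data-path-sal
  if [ $# -ne 2 ]; then
    echo "usage: unpack_saliency ZIP_FILE OUT_DIR" >&2
    return 2
  fi
  mkdir -p "$2"
  unzip -q "$1" -d "$2"
}

fetch_stl10() {
  # $1: root folder, later passed as --data-path-img
  if [ $# -ne 1 ]; then
    echo "usage: fetch_stl10 DEST_DIR" >&2
    return 2
  fi
  mkdir -p "$1"
  wget -c -P "$1" "$STL10_URL"               # resumable download into DEST_DIR
  tar -xf "$1/stl10_binary.tar.gz" -C "$1"   # creates DEST_DIR/stl10_binary/
}
```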

**Training command:**
```bash
python src/main_pls.py --data-path-sal 'PATH_TO_UNZIPPED_SAL_MAPS' \
  --data-path-img "ORIG_STL-10_DATASET_PATH" --wandb --epochs 2000 \
  --num-saccades 5 --run-id "seqjepa-stl10_pls-seq4tr" --num-workers 8 \
  --ema --ema-decay 0.996 --num-heads 4 --num-enc-layers 3 --act-projdim 128 \
  --use-sal 1 --ior 1 --act-cond 1 --learn-act-emb 1 \
  --backbone "resnet18" --cifar-resnet --dataset "stl10" --img-size 96 --fovea-size 32 \
  --save-freq 50 --seed 42 --batch-size 512 --lr 0.0004 --weight-decay 0.001 \
  --optimizer "AdamW" --warmup 20 --scheduler
```

**Evaluation command:**
```bash
# Replace PATH_TO_CHECKPOINT with a trained checkpoint, e.g. seqjepa-stl10_pls-ckpt_epoch_2000.pth
python src/main_pls.py --is-eval --load-path "PATH_TO_CHECKPOINT" \
  --data-path-sal 'PATH_TO_UNZIPPED_SAL_MAPS' --data-path-img "ORIG_STL-10_DATASET_PATH" \
  --wandb --epochs 300 --gpu-id 0 --num-saccades 5 --run-id "eval-seqjepa-stl10_pls-seq4tr" \
  --num-workers 8 --ema --ema-decay 0.996 --act-projdim 128 \
  --use-sal 1 --ior 1 --act-cond 1 --learn-act-emb 1 \
  --backbone "resnet18" --cifar-resnet --dataset "stl10" --img-size 96 --fovea-size 32 \
  --seed 42 --batch-size 256 --num-heads 4 --num-enc-layers 3 \
  --lr 0.001 --weight-decay 0.0 --optimizer "Adam" --warmup 0
```