## Contents

1. [Installation](#installation)
2. [Overview](#overview)
3. [Dataset](#dataset)
4. [Commands](#commands)
   
## Installation
Follow this [documentation](/SlotFormer/docs/install.md) to set up an Anaconda virtual environment on your system.

The code is tested with PyTorch 2.3.0 and CUDA 11.8, installed with the commands below instead of those in the [documentation](/SlotFormer/docs/install.md).

```
conda create -n slotformer python=3.9 -y
conda activate slotformer
pip install torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118
git clone https://github.com/Wuziyi616/nerv.git
cd nerv
git checkout v0.1.0
pip install -e . # install nerv
pip install pycocotools scikit-image lpips

pip install phyre==0.2.2  # use v0.2.2, since the task split may differ slightly between versions
cd .. # move to the top-level folder of the code
pip install -e . # install slotformer
cd slotformer/BC/datasets/language-table
pip install -e . # install language-table
pip install einops==0.8.0
pip install transformers==4.25.1 sentence-transformers==2.2.2
pip install torchmetrics==0.10.0
```
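Because these pins differ from the upstream instructions, it can help to confirm what actually got installed. The following is a minimal sketch, not part of this repository: the `PINNED` table is copied from the commands above, and `check_pins` and `installed_version` are hypothetical helpers that read versions with the standard-library `importlib.metadata`. PyTorch wheels from the CUDA 11.8 index carry a local suffix such as `+cu118`, so the part after `+` is ignored when comparing.

```python
"""Sketch: compare installed package versions with the pins used in this README."""
from importlib import metadata
from typing import Callable, Dict, Optional

# Hypothetical table of pins, copied from the installation commands above.
PINNED: Dict[str, str] = {
    "torch": "2.3.0",
    "torchvision": "0.18.0",
    "torchaudio": "2.3.0",
    "phyre": "0.2.2",
    "einops": "0.8.0",
    "transformers": "4.25.1",
    "sentence-transformers": "2.2.2",
    "torchmetrics": "0.10.0",
}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_pins(
    pins: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed_version} for every package that does not match its pin.

    A local version suffix such as "+cu118" is dropped before comparing.
    """
    mismatches: Dict[str, Optional[str]] = {}
    for dist, wanted in pins.items():
        found = lookup(dist)
        base = found.split("+", 1)[0] if found is not None else None
        if base != wanted:
            mismatches[dist] = found
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINNED)
    if not problems:
        print("All pinned packages match.")
    for dist, found in problems.items():
        print(f"{dist}: expected {PINNED[dist]}, found {found or 'not installed'}")
```

Run it inside the activated `slotformer` environment (the file name is up to you). It only checks package versions; whether CUDA is actually usable still depends on your GPU driver.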

## Overview
This repository contains LSlotFormer, which builds on [SlotFormer](https://github.com/pairlab/SlotFormer), together with an action decoder.

## Dataset
After installing the packages, install the [Language-Table dataset](https://github.com/google-research/language-table).

Run `python ./data/langtable.py --dir [DIR]` to generate simulated Language-Table data with 4 blocks in `[DIR]`.

See the arguments in `./data_utils/langtable.py` to collect data with other settings.

For the original datasets used in [SlotFormer](https://github.com/pairlab/SlotFormer), see the [benchmark](/SlotFormer/docs/benchmark.md) and [data](/SlotFormer/docs/data.md) documentation.

## Commands
To train the SAVi slot model, use the following command:
```
python scripts/train.py --task base_slots \
       --params slotformer/base_slots/configs/[PARAM_FILE].py \
       --fp16 --ddp --cudnn
```
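The extraction command expects a checkpoint path of the form `checkpoint/[PARAM_FILE]/models/epoch/[CKPT_FILE].pth`. If you simply want the most recent file in that directory, a hypothetical helper like `latest_checkpoint` below can locate it. It assumes checkpoints are `.pth` files directly inside that directory and that the newest one by modification time is the one you want; neither assumption comes from the training script itself.

```python
"""Sketch: pick the newest checkpoint file in a directory (hypothetical helper)."""
import sys
from pathlib import Path
from typing import Optional


def latest_checkpoint(ckpt_dir: str, pattern: str = "*.pth") -> Optional[Path]:
    """Return the most recently modified file matching `pattern`, or None if there is none."""
    candidates = [p for p in Path(ckpt_dir).glob(pattern) if p.is_file()]
    if not candidates:
        return None
    # Newest by modification time; adjust if you select checkpoints by epoch or metric.
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    # Usage: python latest_ckpt.py checkpoint/[PARAM_FILE]/models/epoch
    if len(sys.argv) > 1:
        found = latest_checkpoint(sys.argv[1])
        print(found if found is not None else "no checkpoint found")
```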

To extract slots with the trained model, use the following command:
```
python slotformer/base_slots/extract_slots.py \
    --params slotformer/base_slots/configs/[PARAM_FILE].py \
    --weight checkpoint/[PARAM_FILE]/models/epoch/[CKPT_FILE].pth \
    --save_path [SLOT_FILE].pkl
```
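The exact layout of `[SLOT_FILE].pkl` is determined by `extract_slots.py` and is not described here. Assuming it is a standard pickle (as the `.pkl` extension suggests), the sketch below loads it and prints a depth-limited outline: dict keys, list lengths, and the `.shape` of any array-like values. `load_pickle` and `outline` are illustrative helpers, not repository functions. Unpickling a file that contains NumPy arrays or tensors requires those libraries to be installed, and you should only unpickle files you trust.

```python
"""Sketch: print a depth-limited outline of a pickled object, such as an extracted slot file."""
import pickle
import sys
from typing import Any, List


def load_pickle(path: str) -> Any:
    """Load a pickle file. Only open files you trust: unpickling can execute arbitrary code."""
    with open(path, "rb") as f:
        return pickle.load(f)


def outline(obj: Any, max_depth: int = 2, max_items: int = 3, depth: int = 0) -> List[str]:
    """Describe `obj` as indented lines: container lengths, array shapes, or type names."""
    pad = "    " * depth
    shape = getattr(obj, "shape", None)
    if shape is not None:  # NumPy arrays, torch tensors and other array-likes
        return [f"{pad}{type(obj).__name__} shape={tuple(shape)}"]
    if isinstance(obj, dict):
        children = list(obj.items())
    elif isinstance(obj, (list, tuple)):
        children = list(enumerate(obj))
    else:
        return [f"{pad}{type(obj).__name__}"]
    lines = [f"{pad}{type(obj).__name__} (len={len(obj)})"]
    if depth < max_depth:
        for key, value in children[:max_items]:
            sub = outline(value, max_depth, max_items, depth + 1)
            # Prefix the child's first line with its dict key or list index.
            sub[0] = f"{pad}    {key!r}: {sub[0].lstrip()}"
            lines.extend(sub)
        if len(children) > max_items:
            lines.append(f"{pad}    ... ({len(children) - max_items} more)")
    return lines


if __name__ == "__main__":
    # Usage: python inspect_slots.py [SLOT_FILE].pkl
    if len(sys.argv) > 1:
        print("\n".join(outline(load_pickle(sys.argv[1]))))
```

Raising `max_depth` or `max_items` shows more of the structure at the cost of longer output.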

To train the LSlotFormer model with a loss computed between slots, use the following command:
```
python scripts/train.py --task video_prediction \
       --params slotformer/video_prediction/configs/[PARAM_FILE].py \
       --fp16 --ddp --cudnn
```

To train the action decoder, use the following command:
```
python scripts/train.py --task BC \
       --params slotformer/BC/configs/[PARAM_FILE].py \
       --fp16 --ddp --cudnn
```