
# Factor-Wise Homogeneity of Slot-Attention for Continual Object-Centric Learning

This repository is the official implementation of Factor-Wise Homogeneity of Slot-Attention for Continual Object-Centric Learning. 


## Requirements

To install requirements:

```setup
pip install -r requirements.txt
```

## Continual Object-Centric Learning (C-OCL) benchmark

- C-Tetrominoes
- C-CLEVR

The supplementary (.zip) file contains a subset of our Continual Object-Centric Learning benchmark. We provide datasets for the Single Step addition of Two classes (SST) setting.

For C-Tetrominoes, we provide code for dataset generation, so you can train the model on C-Tetrominoes without modifying the provided code.

For C-CLEVR, we provide code for data loading and pre-processing. C-CLEVR additionally requires rendering with Blender. Because of the dataset size, the supplementary (.zip) file includes only a subset of the evaluation samples for SST task 0: 100 images sampled from the 5,000-image SST task 0 evaluation set.
- C-CLEVR SST task 0: `subset_C_CLEVR_shape_T0_test.h5.h5` (sampled version)

Check `./task_configs` for the configuration of each dataset, and set `dataset_path` to the path of the corresponding .h5 file.
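Setting `dataset_path` can also be scripted. The helper below is a hypothetical sketch that uses only the standard library. It does not parse YAML. Instead, it assumes the config stores the key on a single `dataset_path: <value>` line, possibly indented, and rewrites only that value. The function names `set_dataset_path` and `update_config_file` are illustrative and do not come from this repository.

```python
import re
from pathlib import Path

# Assumption: the key appears as `dataset_path: <value>` on its own line.
# [ \t]* (not \s*) keeps the match from running across line breaks.
_DATASET_PATH_LINE = re.compile(r"^([ \t]*dataset_path[ \t]*:[ \t]*).*$", re.MULTILINE)


def set_dataset_path(config_text: str, h5_path: str) -> str:
    """Return config_text with every `dataset_path:` value replaced by h5_path."""
    # Single-quoted YAML scalar; a literal ' is escaped by doubling it.
    quoted = "'" + h5_path.replace("'", "''") + "'"
    # A function replacement avoids backslash handling in Windows-style paths.
    new_text, count = _DATASET_PATH_LINE.subn(lambda m: m.group(1) + quoted, config_text)
    if count == 0:
        raise KeyError("no `dataset_path:` entry found in config")
    return new_text


def update_config_file(config_file: str, h5_path: str) -> None:
    """Rewrite a config file in place so that it points at h5_path."""
    path = Path(config_file)
    path.write_text(set_dataset_path(path.read_text(), h5_path))
```

For example, `update_config_file("<your config>.yaml", "/data/subset_C_CLEVR_shape_T0_test.h5.h5")` would point a config at the sampled C-CLEVR file. A line-based rewrite keeps comments and key order intact. A YAML round-trip would normally discard them.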


## Training

To train the base Slot Attention model from the paper, run the provided shell script:

```train
bash ./scripts/run_base.sh
```

```bash
EP=0
LR=0.0004
SLOT=4
CONT=base
MODEL=base
CONFIG=tetrominos_shape_5-2-2
torchrun --nproc_per_node=4 main.py         \
--project 'Slot Attention'                  \
--amp                                       \
--dataset continual_tetrominoes             \
--num_task 2                                \
--resolution 64 64                          \
--output ./results                          \
--arch ${MODEL}                             \
--continual_arch ${CONT}                    \
--use_fp16                                  \
--batch_size 16                             \
--n_samples 10                              \
--sample_interval 10                        \
--lr ${LR}                                  \
--num_slots ${SLOT}                         \
--num_iterations 3                          \
--seed 43                                   \
--task_config ./task_configs/${CONFIG}.yaml \
--save_weights                          
```

- `--arch`: model architecture (e.g., "base" for Slot Attention)
- `--continual_arch`: continual-learning method (e.g., "base" for the baseline, "dpr" for DPR)
- `--dataset`: training dataset
- `--n_samples`: number of reconstruction samples to record
- `--sample_interval`: interval at which reconstruction samples are recorded during training
- `--task_config`: path to the task configuration file (e.g., SST)

To train Slot Attention with the proposed Decoder-only Post Replay (DPR) from the paper, run the provided shell script (with `CONT` set to "dpr"):

```train
bash ./scripts/run_dpr.sh
```

```bash
...
CONT=dpr
MODEL=base
CONFIG=tetrominos_shape_5-2-2
torchrun --nproc_per_node=4 main.py         \
...                         
```

Additional DPR options:
- `--replay_size`: size of the replay buffer
- `--replay_epochs`: number of replay training epochs
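The README does not specify how the replay buffer is populated. The sketch below is a hypothetical illustration of a buffer whose capacity is fixed by a value like `--replay_size`. It uses reservoir sampling, a swapped-in technique that keeps a uniform random sample of every item seen so far. `ReservoirReplayBuffer` is not a class from this repository, and DPR's actual selection policy may differ.

```python
import random
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")


class ReservoirReplayBuffer(Generic[T]):
    """Fixed-capacity buffer holding a uniform random sample of all items added."""

    def __init__(self, replay_size: int, seed: Optional[int] = None) -> None:
        if replay_size <= 0:
            raise ValueError("replay_size must be positive")
        self.replay_size = replay_size
        self.items: List[T] = []
        self.num_seen = 0
        self.rng = random.Random(seed)

    def add(self, item: T) -> None:
        """Reservoir sampling (Algorithm R): each seen item is kept with prob k/n."""
        self.num_seen += 1
        if len(self.items) < self.replay_size:
            self.items.append(item)
            return
        j = self.rng.randrange(self.num_seen)  # uniform in [0, num_seen)
        if j < self.replay_size:
            self.items[j] = item

    def sample(self, batch_size: int) -> List[T]:
        """Draw up to batch_size distinct stored items for a replay step."""
        return self.rng.sample(self.items, min(batch_size, len(self.items)))

    def __len__(self) -> int:
        return len(self.items)
```

For example, adding the indices of 5,000 training images to `ReservoirReplayBuffer(replay_size=100, seed=0)` leaves 100 of them in the buffer. `sample(16)` then yields a replay mini-batch. Memory stays bounded by the buffer size regardless of how many tasks have been seen.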


## Evaluation

To evaluate a model with ARI, MSE, mSC, and SC, add `--eval_metrics` to the shell script and run:

```eval
bash ./scripts/run_base.sh
```

## Results

Our model achieves the following performance on:
- C-Tetrominoes (E0/T0)

| Model name         | ARI             | MSE            | mSC            | SC             |
| ------------------ |---------------- | -------------- | -------------- | -------------- |
| Slot Attention     |     0.9992      |  0.00004805    |      0.9192    |    0.9192      |
| Slot Attention DPR |     0.9991      |   0.00006119   |      0.9192    |    0.9192     |

- C-Tetrominoes (E0/T1)

| Model name         | ARI             | MSE            | mSC            | SC             |
| ------------------ |---------------- | -------------- | -------------- | -------------- |
| Slot Attention     |     0.4036      |      0.0439    |      0.2766    |    0.2766      |
| Slot Attention DPR |     0.9918      |   0.0006953    |      0.8082    |      0.8082    |

