## Reference code for "Leveraging Color Channel Independence for Improved Unsupervised Object Detection"

The following code contains a reference implementation for training Slot Attention models on different color spaces. One of the main strengths of our approach is that it requires only a small architectural change. We extended the codebase of the original Slot Attention paper, which can be found at https://github.com/google-research/google-research/tree/master/slot_attention. Please follow the installation steps described there.
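To illustrate why the change is small: in the original Slot Attention autoencoder, each slot is decoded into RGB values plus one mask logit, the masks are normalized with a softmax across slots, and the slot reconstructions are blended with those masks. Switching to another color space mainly changes how many color channels are decoded. The NumPy sketch below shows this combination step for an arbitrary `num_channels` (e.g. 4 for a space such as RGB-S). The function name `combine_slot_decodings` is hypothetical, and the actual codebase implements this step in TensorFlow.

```python
import numpy as np


def combine_slot_decodings(decoder_output, num_channels):
    """Blend per-slot decodings into one image, Slot Attention style (sketch).

    decoder_output: array [batch, num_slots, height, width, num_channels + 1],
    i.e. `num_channels` color channels plus one mask logit per slot.
    Returns (combined, recons, masks).
    """
    if decoder_output.shape[-1] != num_channels + 1:
        raise ValueError("expected num_channels color channels plus 1 mask logit")
    # Split the last axis into color channels and the single mask-logit channel.
    recons = decoder_output[..., :num_channels]
    mask_logits = decoder_output[..., num_channels:]
    # Numerically stable softmax over the slot axis (axis=1):
    # at every pixel the masks sum to 1 across slots.
    mask_logits = mask_logits - mask_logits.max(axis=1, keepdims=True)
    masks = np.exp(mask_logits)
    masks = masks / masks.sum(axis=1, keepdims=True)
    # Mask-weighted sum over slots yields the reconstruction in the chosen color space.
    combined = (recons * masks).sum(axis=1)
    return combined, recons, masks
```

With `num_channels=3` the sketch reduces to the original RGB setting; nothing after the channel split depends on the color space.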

## Data Preparation
The following code runs experiments on the Clevrtex dataset. To download the data, follow the steps described at https://github.com/google-research/google-research/tree/master/invariant_slot_attention/datasets/clevrtex. The TensorFlow dataset paths must then be adjusted in the files "data.py" and "data_inf.py".

## Model Training
Our experiments comprise training on multiple color spaces. To start training, execute the Python files stored in the folder "object_discovery". We included training files for all combined RGB/HSV color spaces, for both a CNN and a ResNet architecture. For example, to train a CNN model on the RGB-S space, execute:

```shell
python -m clevrtex_2.object_discovery.cnn_rgbs --seed=3 --model_dir="my_checkpoint"
```

This stores the encoder and decoder checkpoints in the folders "my_checkpoint_enc" and "my_checkpoint_dec".
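The README does not define the RGB-S input precisely. Assuming RGB-S denotes the three RGB channels stacked with the saturation (S) channel of HSV, the model input could be built as in the sketch below. The helpers `hsv_saturation` and `rgb_to_rgbs` are hypothetical, and the final rescaling to [-1, 1] is an assumption chosen to mirror the normalization the original Slot Attention data pipeline applies to RGB images.

```python
import numpy as np


def hsv_saturation(rgb):
    """HSV saturation S = (max - min) / max for RGB values in [0, 1]; S = 0 where max = 0.

    rgb: float array [..., 3]. Returns [..., 1].
    """
    c_max = rgb.max(axis=-1)
    c_min = rgb.min(axis=-1)
    # Guard the denominator so black pixels (max = 0) never divide by zero.
    sat = np.where(c_max > 0, (c_max - c_min) / np.maximum(c_max, 1e-12), 0.0)
    return sat[..., np.newaxis]


def rgb_to_rgbs(rgb):
    """Stack RGB with the HSV saturation channel (assumed RGB-S layout).

    rgb: float array [..., 3] in [0, 1]. Returns [..., 4] in [-1, 1].
    """
    rgb = np.asarray(rgb, dtype=np.float32)
    rgbs = np.concatenate([rgb, hsv_saturation(rgb)], axis=-1)
    # Assumed normalization: map [0, 1] to [-1, 1] for all four channels.
    return rgbs * 2.0 - 1.0
```

The saturation formula matches `colorsys.rgb_to_hsv` from the Python standard library; inside a `tf.data` pipeline, `tf.image.rgb_to_hsv` would be the natural TensorFlow equivalent.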

## Inference
To inspect the predicted segmentations and reconstructions, use the inference code. To run inference on the Clevrtex test set with the previously trained RGB-S model, run:

```shell
python -m clevrtex_2.object_discovery.inference_rgb_rgb --seed=3 --model_dir="my_checkpoint"
```

This loads the model stored in "my_checkpoint" and creates five NumPy arrays:

- `all_gt_images.npy`
- `all_gt_masks.npy`
- `all_pred_images.npy`
- `all_pred_masks.npy`
- `all_slots.npy`

These arrays contain, respectively, the ground-truth images, the ground-truth object masks, the reconstructions, the predicted object masks, and the slot representations.
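For custom analyses, the arrays can also be loaded directly. This is a minimal sketch, assuming the files are written to the current working directory and that the predicted masks follow the Slot Attention layout `[batch, num_slots, height, width, 1]`; both assumptions should be checked against the inference script. `load_inference_outputs` and `masks_to_labels` are hypothetical helpers.

```python
import os

import numpy as np

# File names produced by the inference script (from the list above).
OUTPUT_NAMES = (
    "all_gt_images",
    "all_gt_masks",
    "all_pred_images",
    "all_pred_masks",
    "all_slots",
)


def load_inference_outputs(directory="."):
    """Load the five inference arrays from `directory` into a dict keyed by name."""
    return {name: np.load(os.path.join(directory, name + ".npy")) for name in OUTPUT_NAMES}


def masks_to_labels(masks, slot_axis=1):
    """Turn soft slot masks [batch, num_slots, H, W, 1] into hard labels [batch, H, W].

    Each pixel is assigned to the slot with the largest mask value.
    """
    labels = np.argmax(masks, axis=slot_axis)
    # Drop the trailing singleton channel axis (assumed present in the mask layout).
    if masks.shape[-1] == 1:
        labels = labels[..., 0]
    return labels
```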

To evaluate the segmentation performance and the reconstruction MSE, run:

```shell
python3 evaluate.py
```

The script loads the NumPy arrays and computes the scores. We provide model checkpoints for the CNN models in the folder "checkpoints".
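The README does not state which segmentation metric `evaluate.py` reports. As a hedged sketch, assuming a foreground adjusted Rand index (FG-ARI, common on ClevrTex) with background label 0, the scores could be computed per image as below. The function names are hypothetical and not taken from `evaluate.py`; the MSE is taken over whatever value range the arrays are stored in.

```python
import numpy as np


def mse(gt_images, pred_images):
    """Mean squared error over all pixels and channels."""
    gt = np.asarray(gt_images, dtype=np.float64)
    pred = np.asarray(pred_images, dtype=np.float64)
    return float(np.mean((gt - pred) ** 2))


def _pairs(counts):
    """Sum of (n choose 2) over an integer count array, as an exact Python int."""
    counts = np.asarray(counts, dtype=np.int64)
    return int((counts * (counts - 1) // 2).sum())


def adjusted_rand_index(true_labels, pred_labels):
    """Adjusted Rand index between two labelings (1.0 = identical partitions)."""
    true_labels = np.asarray(true_labels).ravel()
    pred_labels = np.asarray(pred_labels).ravel()
    n = true_labels.size
    if n < 2:
        return 1.0
    # Contingency table: rows = ground-truth segments, columns = predicted segments.
    _, t = np.unique(true_labels, return_inverse=True)
    _, p = np.unique(pred_labels, return_inverse=True)
    table = np.zeros((t.max() + 1, p.max() + 1), dtype=np.int64)
    np.add.at(table, (t, p), 1)
    index = _pairs(table)
    sum_rows = _pairs(table.sum(axis=1))
    sum_cols = _pairs(table.sum(axis=0))
    total = n * (n - 1) // 2
    # ARI = (index - expected) / (max_index - expected) with
    # expected = sum_rows * sum_cols / total and max_index = (sum_rows + sum_cols) / 2;
    # numerator and denominator are scaled by 2 * total to stay in exact integers.
    numerator = 2 * (index * total - sum_rows * sum_cols)
    denominator = (sum_rows + sum_cols) * total - 2 * sum_rows * sum_cols
    if denominator == 0:
        # Both labelings are trivial in the same way (one cluster, or all singletons).
        return 1.0
    return numerator / denominator


def foreground_ari(true_labels, pred_labels, background=0):
    """ARI restricted to pixels whose ground-truth label is not `background`."""
    true_labels = np.asarray(true_labels).ravel()
    pred_labels = np.asarray(pred_labels).ravel()
    keep = true_labels != background
    return adjusted_rand_index(true_labels[keep], pred_labels[keep])
```

To score a batch, one would apply `foreground_ari` to each image's ground-truth label map and the argmax of its predicted masks over the slot axis, then average the results.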