# Task Conditioned Stochastic Subsampling

This repository contains the accompanying code for SSS. 

## Dependencies
```
torch==1.8.0+cu101
seaborn==0.11.1
tensorboard==2.5.0
torchvision==0.9.0+cu101
```
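
One way to install these pinned versions is sketched below. The file name `requirements.txt` is an assumption (the repository may not ship one), and the `-f` wheel index is the usual PyTorch location for `+cu101` builds rather than something this README specifies.

```shell
# Write the pinned versions listed above to a requirements file.
# NOTE: "requirements.txt" is a hypothetical file name, not part of the repository.
cat > requirements.txt <<'EOF'
torch==1.8.0+cu101
seaborn==0.11.1
tensorboard==2.5.0
torchvision==0.9.0+cu101
EOF

# The +cu101 builds are not on PyPI, so pip needs an extra wheel index.
# Run this inside a fresh virtual environment (assumed index URL):
#   pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
```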

## Computational Resources
Training SSS on the CelebA classification task is configured for 4 GPUs with ~12GB of memory each. The number of GPUs can be changed in the corresponding scripts; if you change it, adjust the batch size accordingly.

## Training

To train the 1-D function reconstruction task, run the following:

```train
bash scripts/function/train/XXXX/XXXX.sh
```
where both occurrences of `XXXX` are replaced by either `random` or `sss`.
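
As a minimal sketch of the substitution, the loop below fills in the selector name and runs each script in turn. The helper `train_script` is hypothetical (it is not part of the repository); it only builds the path pattern shown above and assumes you run it from the repository root.

```shell
# Hypothetical helper: build the 1-D function training script path for a selector.
train_script() {
    printf 'scripts/function/train/%s/%s.sh\n' "$1" "$1"
}

# Train both selectors one after the other, skipping any script that is missing.
for selector in random sss; do
    script="$(train_script "$selector")"
    if [ -f "$script" ]; then
        bash "$script"
    else
        echo "skipping, not found: $script" >&2
    fi
done
```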

To train the models used for ablation studies, run the following:
```train
bash scripts/function/train/ablation/XXXX.sh
```
where `XXXX` is one of `autoregressive`, `candidate`, or `random_autoregressive`.

To train the model for CelebA image reconstruction, run the following:
```train
bash scripts/celeba/train/XXXX/XXXX.sh
```
where both occurrences of `XXXX` are replaced by either `random` or `sss`.

## Evaluation

We provide pretrained checkpoints for both the random selector and SSS on the 1-D function reconstruction task and the CelebA image reconstruction task. They are located in the `checkpoints` directory.

To evaluate the 1-D function models, run:
```eval
bash scripts/function/valid/XXXX/metric/XXXX.sh
```
where both occurrences of `XXXX` are replaced by either `random` or `sss`.

To visualize reconstructed functions using the provided checkpoints, run the following:
```eval
bash scripts/function/valid/XXXX/visualize/XXXX.sh
```
where both occurrences of `XXXX` are replaced by either `random` or `sss`.

For the CelebA experiments, run the following:
```eval
bash scripts/celeba/valid/XXXX/metric/XXXX.sh
bash scripts/celeba/valid/XXXX/visualize/XXXX.sh
```
where both occurrences of `XXXX` are replaced by either `random` or `sss`. The first script computes the metrics and the second produces the visualizations, which are written to the visualization folder.
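
The evaluation scripts for both tasks follow the same path pattern, so they can be driven from one small wrapper. The sketch below assumes it is run from the repository root; `eval_script` and `run_eval` are hypothetical helpers introduced here for illustration, not part of the repository.

```shell
# Hypothetical helper: build the path of an evaluation script.
#   $1 = task (function | celeba)
#   $2 = selector (random | sss)
#   $3 = mode (metric | visualize)
eval_script() {
    printf 'scripts/%s/valid/%s/%s/%s.sh\n' "$1" "$2" "$3" "$2"
}

# Run one evaluation script, reporting a missing file instead of failing silently.
run_eval() {
    script="$(eval_script "$1" "$2" "$3")"
    if [ -f "$script" ]; then
        bash "$script"
    else
        echo "not found: $script" >&2
        return 1
    fi
}

# Example: compute metrics, then visualizations, for SSS on CelebA.
# "|| true" lets the loop continue even if one of the scripts is missing.
for mode in metric visualize; do
    run_eval celeba sss "$mode" || true
done
```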
