[![Python](https://img.shields.io/badge/python-3.9+-informational.svg)](https://www.python.org/downloads/release/python-3918/)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=black)](https://pycqa.github.io/isort)
[![documentation](https://img.shields.io/badge/docs-mkdocs%20material-blue.svg?style=flat)](https://mkdocstrings.github.io)
[![wandb](https://img.shields.io/badge/tracking-wandb-blue)](https://wandb.ai/site)
[![dvc](https://img.shields.io/badge/data-dvc-9cf)](https://dvc.org)
[![Hydra](https://img.shields.io/badge/Config-Hydra-89b8cd)](https://hydra.cc)
[![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit)

# DisCoNet

<p align="center">
  <img src="imgs/disconet.png" width="100%" alt="Overview of DisCoNet's architecture." caption="Overview of DisCoNet's architecture.">
</p>

The official implementation of DisCoNet in PyTorch.

## Prerequisites

You will need:

- `python` (see `pyproject.toml` for the exact version)
- `Git`
- `Make`
- a `.secrets` file with the required secrets and credentials
- a `.env` file with the environment variables to load
- `CUDA >= 12.1`
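
The checks in the list above can be scripted. This is a minimal sketch, assuming Python 3.9 as the minimum (taken from the badge; `pyproject.toml` remains the authority) and using `torch.version.cuda`, which reports the CUDA version PyTorch was built with rather than the installed driver. `check_prerequisites` is a hypothetical helper, not part of this repository.

```python
import importlib.util
import sys
from pathlib import Path

MIN_PYTHON = (3, 9)  # from the README badge; check pyproject.toml for the exact pin
MIN_CUDA = (12, 1)   # from the prerequisites list


def parse_version(text):
    """Turn '12.1' or '12.1.105' into a tuple of ints, skipping non-numeric parts."""
    return tuple(int(part) for part in text.split(".") if part.isdigit())


def check_prerequisites(repo_root="."):
    """Hypothetical helper: return a list of problems (empty if every check passes)."""
    problems = []
    if sys.version_info[:2] < MIN_PYTHON:
        problems.append(f"Python {sys.version.split()[0]} found, need >= 3.9")
    for name in (".secrets", ".env"):
        if not (Path(repo_root) / name).is_file():
            problems.append(f"missing {name} file")
    if importlib.util.find_spec("torch") is None:
        problems.append("PyTorch not installed, CUDA version not checked")
    else:
        import torch

        cuda = torch.version.cuda  # CUDA version of the PyTorch build; None on CPU-only builds
        if cuda is None or parse_version(cuda) < MIN_CUDA:
            problems.append(f"PyTorch CUDA build {cuda} found, need >= 12.1")
    return problems


if __name__ == "__main__":
    print(check_prerequisites() or "All prerequisite checks passed.")
```

Run it from the repository root once the environments are installed. Git and Make are left out because they are easier to check directly from the shell.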

## Installation

Clone this repository (HTTPS also works):

    git clone --recursive <ssh link>
    cd disconet

Create and activate the Conda environment:

    conda env create -f environment.yml
    conda activate python3.9

### On Linux

Then set up all virtual environments using the Makefile recipe:

    (python3.9) $ make setup-all

You might need to run the following command once to set up automatic activation of the conda environment and the virtualenv:

    direnv allow

Feel free to edit the [`.envrc`](.envrc) file if you prefer to activate the environments manually.

### On Windows

Set up the virtualenv by running the following commands in PowerShell:

    python -m venv .venv
    .venv/Scripts/Activate.ps1
    python -m pip install --upgrade pip setuptools
    pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
    python -m pip install -r requirements/requirements-win.txt

Remember to activate both environments every time you run the code:

    conda activate python3.9
    .venv/Scripts/Activate.ps1
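
One way to confirm that both environments are active is to inspect the variables their activation scripts set: `conda activate` sets `CONDA_DEFAULT_ENV` and the venv scripts (including `Activate.ps1`) set `VIRTUAL_ENV`. This is a sketch, assuming the conda environment is named `python3.9` as in the commands above. `active_environments` and `both_active` are hypothetical helpers, not part of the repository.

```python
import os


def active_environments(environ=None):
    """Return (conda_env_name, venv_path); an entry is None when that environment is inactive."""
    env = os.environ if environ is None else environ
    # CONDA_DEFAULT_ENV comes from `conda activate`, VIRTUAL_ENV from the venv activation script.
    return env.get("CONDA_DEFAULT_ENV"), env.get("VIRTUAL_ENV")


def both_active(environ=None, conda_name="python3.9"):
    """True only if the named conda environment and some venv are both active."""
    conda_env, venv = active_environments(environ)
    return conda_env == conda_name and venv is not None


if __name__ == "__main__":
    print("conda env: %s, venv: %s" % active_environments())
    print("both active:", both_active())
```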

## OOD Benchmark

The evaluation of these models closely follows the [OpenOOD](https://github.com/jingkang50/openood) benchmark, which defines three levels of out-of-distribution (OOD) data: Near-OOD, which exhibits semantic shifts with respect to the in-distribution (ID) dataset; Far-OOD, which combines semantic and domain shifts; and Covariate Shift OOD, which consists of corrupted versions of the ID set. Four ID datasets are used, each paired with its own OOD datasets:

1. **MNIST**
    - **Near-OOD**: FashionMNIST
    - **Far-OOD**: CIFAR-10, TinyImageNet, DTD, Places365
    - **Covariate Shift OOD**: MNIST(-C)

2. **CIFAR-10**
    - **Near-OOD**: CIFAR-100, TinyImageNet
    - **Far-OOD**: MNIST, SVHN, DTD, Places365
    - **Covariate Shift OOD**: CIFAR-10(-C)

3. **TinyImageNet**
    - **Near-OOD**: SSB-hard, NINCO
    - **Far-OOD**: iNaturalist, DTD, OpenImage-O
    - **Covariate Shift OOD**: TinyImageNet(-C)

4. **ImageNet-1K**
    - **Near-OOD**: SSB-hard, NINCO
    - **Far-OOD**: iNaturalist, DTD, OpenImage-O
    - **Covariate Shift OOD**: ImageNet(-C)
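
For scripting over the benchmark, the pairs above can be written as a plain mapping. This is an illustrative sketch: `OOD_BENCHMARK` and `evaluation_pairs` are hypothetical names and do not reflect how the repository's own configuration stores these datasets.

```python
# Hypothetical mapping of the ID datasets to their OOD datasets, copied from the list above.
OOD_BENCHMARK = {
    "MNIST": {
        "near": ["FashionMNIST"],
        "far": ["CIFAR-10", "TinyImageNet", "DTD", "Places365"],
        "covariate": ["MNIST-C"],
    },
    "CIFAR-10": {
        "near": ["CIFAR-100", "TinyImageNet"],
        "far": ["MNIST", "SVHN", "DTD", "Places365"],
        "covariate": ["CIFAR-10-C"],
    },
    "TinyImageNet": {
        "near": ["SSB-hard", "NINCO"],
        "far": ["iNaturalist", "DTD", "OpenImage-O"],
        "covariate": ["TinyImageNet-C"],
    },
    "ImageNet-1K": {
        "near": ["SSB-hard", "NINCO"],
        "far": ["iNaturalist", "DTD", "OpenImage-O"],
        "covariate": ["ImageNet-C"],
    },
}


def evaluation_pairs(benchmark=OOD_BENCHMARK):
    """Yield (id_dataset, ood_level, ood_dataset) triples in insertion order."""
    for id_name, levels in benchmark.items():
        for level, ood_names in levels.items():
            for ood_name in ood_names:
                yield id_name, level, ood_name


if __name__ == "__main__":
    for triple in evaluation_pairs():
        print(*triple, sep=" | ")
```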

## Dataset Availability

MNIST, FashionMNIST, CIFAR-10, CIFAR-100 and SVHN are downloaded automatically by the PyTorch data loaders used in the code.

TinyImageNet must be downloaded from [Kaggle](https://www.kaggle.com/datasets/nikhilshingadiya/tinyimagenet200). After downloading it:
1. Unzip the archive.
2. Rename the extracted folder to `tinyimagenet`.
3. Move it to [`data/raw`](data/raw).
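
The three TinyImageNet steps can be scripted with the standard library. This is a minimal sketch with two assumptions: the archive path is whatever your Kaggle download is called, and the archive holds a single top-level folder whose name you pass as `extracted_name` (the default `tiny-imagenet-200` is a guess, so check your download). `prepare_tinyimagenet` is a hypothetical helper.

```python
import tempfile
import zipfile
from pathlib import Path


def prepare_tinyimagenet(archive, raw_dir="data/raw", extracted_name="tiny-imagenet-200"):
    """Hypothetical helper: unzip `archive` and store its folder as <raw_dir>/tinyimagenet.

    Assumes the archive holds one top-level folder called `extracted_name`.
    """
    raw_dir = Path(raw_dir)
    target = raw_dir / "tinyimagenet"
    if target.exists():
        raise FileExistsError(f"{target} already exists")
    raw_dir.mkdir(parents=True, exist_ok=True)
    # Extract next to the target so the final rename stays on one filesystem.
    with tempfile.TemporaryDirectory(dir=str(raw_dir)) as tmp:
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(tmp)  # step 1: unzip
        source = Path(tmp) / extracted_name
        if not source.is_dir():
            raise FileNotFoundError(f"{extracted_name!r} not found in {archive}")
        source.rename(target)  # steps 2 and 3: rename and move into data/raw
    return target
```

Example: `prepare_tinyimagenet("archive.zip")`, run from the repository root, where `archive.zip` stands for your downloaded file.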

ImageNet-1K must be downloaded from [Kaggle](https://www.kaggle.com/c/imagenet-object-localization-challenge/data). After downloading it:
1. Unzip the archive.
2. Rename the folder `ILSVRC/Data/CLS-LOC` to `imagenet`.
3. Move only `imagenet` to [`data/raw`](data/raw). The other folders can be deleted.
4. Run the [`process_imagenet.py`](src/disconet/process_imagenet.py) script. This can take a while because of the dataset size.
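
Steps 2 and 3 for ImageNet-1K can be sketched in Python as well, assuming `extracted_root` is the folder the Kaggle archive was unzipped into. `prepare_imagenet` is a hypothetical helper. It does not replace step 4, so `process_imagenet.py` still has to be run afterwards.

```python
import shutil
from pathlib import Path


def prepare_imagenet(extracted_root, raw_dir="data/raw", delete_rest=False):
    """Hypothetical helper: move <extracted_root>/ILSVRC/Data/CLS-LOC to <raw_dir>/imagenet.

    With delete_rest=True the leftover extracted folders are removed afterwards.
    """
    extracted_root = Path(extracted_root)
    source = extracted_root / "ILSVRC" / "Data" / "CLS-LOC"
    target = Path(raw_dir) / "imagenet"
    if not source.is_dir():
        raise FileNotFoundError(f"expected folder not found: {source}")
    if target.exists():
        raise FileExistsError(f"{target} already exists")
    # Guard against deleting the freshly moved data (Path.is_relative_to needs Python 3.9+).
    if delete_rest and target.resolve().is_relative_to(extracted_root.resolve()):
        raise ValueError("raw_dir lies inside extracted_root; refusing to delete it")
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.move(str(source), str(target))  # steps 2 and 3: rename and move
    if delete_rest:
        shutil.rmtree(extracted_root)  # optional cleanup of the other folders
    return target
```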

The remaining datasets can be downloaded using [`datasets_download.py`](src/disconet/datasets_download.py) by running the following commands:

    cd src/disconet
    python datasets_download.py [--imagenet]

**Note: Use the `--imagenet` flag if you want to download ImageNet-C.**

## Models

In addition to DisCoNet, this repository includes two other adversarial models that serve as baselines:

- DisCoNet [`Code`](src/disconet/models/DisCoNet.py)|[`Train Script`](src/disconet/train_disconet.py)|[`Eval Script`](src/disconet/eval_disconet.py)|[`Documentation`](docs/DisCoNet.md)
- DC-GAN [`Paper`](https://arxiv.org/abs/1511.06434)|[`Code`](src/disconet/models/DCGAN.py)|[`Train Script`](src/disconet/train_dcgan.py)|[`Eval Script`](src/disconet/eval_dcgan.py)|[`Documentation`](docs/DCGAN.md)
- PresGAN [`Paper`](https://arxiv.org/abs/1910.04302)|[`Code`](src/disconet/models/PresGAN.py)|[`Train Script`](src/disconet/train_presgan.py)|[`Eval Script`](src/disconet/eval_presgan.py)|[`Documentation`](docs/PresGAN.md)

### Train and Evaluate Models

The commands to train and evaluate each model are given in its documentation: [`DisCoNet.md`](docs/DisCoNet.md), [`DCGAN.md`](docs/DCGAN.md), and [`PresGAN.md`](docs/PresGAN.md).

### Pre-trained Models

You can download the pre-trained DisCoNet checkpoints from [Google Drive](https://drive.google.com/file/d/1PdJ0UNIPDkTGY-0U7F0wmWTRsYmyb09a/view?usp=sharing).

### Results

#### Covariate Shift OOD

| ID Dataset         |      AUROC      |     FPR@95     |
| ------------------ |---------------- | -------------- |
| MNIST              |   99.9%         |      0.1%      |
| CIFAR-10           |   96.2%         |      11.4%     |
| TinyImageNet       |   99.7%         |      1.5%      |
| ImageNet-1K        |   98.9%         |      3.8%      |

#### Near-OOD

1. **MNIST**

| OOD Dataset        |      AUROC      |     FPR@95     |
| ------------------ |---------------- | -------------- |
| FashionMNIST       |   100.0%        |      0.0%      |

2. **CIFAR-10**

| OOD Dataset        |      AUROC      |     FPR@95     |
| ------------------ |---------------- | -------------- |
| CIFAR-100          |    75.0%        |      75.4%     |
| TinyImageNet       |    91.4%        |      37.6%     |

3. **TinyImageNet**

| OOD Dataset        |      AUROC      |     FPR@95     |
| ------------------ |---------------- | -------------- |
| SSB-hard           |   100.0%        |      0.0%      |
| NINCO              |   100.0%        |      0.0%      |

4. **ImageNet-1K**

| OOD Dataset        |      AUROC      |     FPR@95     |
| ------------------ |---------------- | -------------- |
| SSB-hard           |   99.0%         |      0.0%      |
| NINCO              |   99.7%         |      0.1%      |

#### Far-OOD

1. **MNIST**

| OOD Dataset        |      AUROC      |     FPR@95     |
| ------------------ |---------------- | -------------- |
| CIFAR-10           |   100.0%        |      0.0%      |
| TinyImageNet       |   100.0%        |      0.0%      |
| DTD                |   100.0%        |      0.0%      |
| Places365          |   100.0%        |      0.0%      |

2. **CIFAR-10**

| OOD Dataset        |      AUROC      |     FPR@95     |
| ------------------ |---------------- | -------------- |
| MNIST              |   100.0%        |      0.0%      |
| SVHN               |   100.0%        |      0.0%      |
| DTD                |   66.2%         |     97.3%      |
| Places365          |   92.6%         |     34.0%      |

3. **TinyImageNet**

| OOD Dataset        |      AUROC      |     FPR@95     |
| ------------------ |---------------- | -------------- |
| iNaturalist        |   100.0%        |      0.0%      |
| DTD                |   99.5%         |      0.9%      |
| OpenImage-O        |   100.0%        |      0.0%      |

4. **ImageNet-1K**

| OOD Dataset        |      AUROC      |     FPR@95     |
| ------------------ |---------------- | -------------- |
| iNaturalist        |   100.0%        |      0.0%      |
| DTD                |   87.6%         |      84.0%     |
| OpenImage-O        |   99.4%         |      0.3%      |
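
The tables report AUROC (area under the ROC curve, higher is better) and FPR@95 (the false positive rate on OOD samples at the threshold where 95% of ID samples are accepted, lower is better). As a sketch of what these numbers measure, the functions below compute both from raw detector scores in plain Python, assuming ID is the positive class and a higher score means "more in-distribution". They are not the evaluation code that produced the tables (that lives in the eval scripts), and `auroc` and `fpr_at_tpr` are hypothetical names.

```python
def auroc(id_scores, ood_scores):
    """Probability that a random ID sample scores higher than a random OOD sample.

    Uses the rank-sum (Mann-Whitney U) formulation; tied scores count as 1/2.
    """
    labelled = sorted(
        [(s, 1) for s in id_scores] + [(s, 0) for s in ood_scores],
        key=lambda pair: pair[0],
    )
    n_id, n_ood = len(id_scores), len(ood_scores)
    id_rank_sum = 0.0
    i = 0
    while i < len(labelled):
        # labelled[i:j] is a block of tied scores; each gets the block's average rank.
        j = i
        while j < len(labelled) and labelled[j][0] == labelled[i][0]:
            j += 1
        average_rank = (i + 1 + j) / 2  # mean of the 1-based ranks i+1 .. j
        id_rank_sum += average_rank * sum(label for _, label in labelled[i:j])
        i = j
    u_statistic = id_rank_sum - n_id * (n_id + 1) / 2
    return u_statistic / (n_id * n_ood)


def fpr_at_tpr(id_scores, ood_scores, tpr_percent=95):
    """Fraction of OOD samples accepted at the strictest threshold keeping >= tpr_percent% of ID."""
    ranked_id = sorted(id_scores, reverse=True)
    keep = -(-tpr_percent * len(ranked_id) // 100)  # integer ceiling, avoids float rounding
    threshold = ranked_id[keep - 1]
    return sum(s >= threshold for s in ood_scores) / len(ood_scores)
```

Multiplying either value by 100 gives the percentages used in the tables. The pairwise-free rank formulation keeps `auroc` at O(n log n); for large score arrays, scikit-learn's `roc_auc_score` and `roc_curve` are the usual tools.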

## Experiment Tracking

The code examples are set up to use [Weights & Biases](https://wandb.ai/home) to track training runs. Refer to the [`full documentation`](https://docs.wandb.ai/quickstart) if needed, or follow these steps:

1. Create an account in [Weights & Biases](https://wandb.ai/home)
2. **If you have installed the requirements you can skip this step**. If not, activate the conda environment and the virtualenv and run:

    ```bash
    pip install wandb
    ```
3. Run the following command and enter your [`API key`](https://wandb.ai/authorize) when prompted:

    ```bash
    wandb login
    ```
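
Once logged in, training metrics can be sent to a run. This is a hypothetical helper, not the repository's own logging code; it relies only on `wandb.init()` returning a run object with a `log(data, step=...)` method. The loss names and project name in the usage comment are placeholders. Passing `mode="offline"` to `wandb.init()` keeps runs local until you sync them with `wandb sync`.

```python
def log_epoch(run, epoch, losses):
    """Send `losses` (name -> number) to `run` under a 'train/' prefix and return the payload.

    `run` is the object returned by wandb.init(); any object with a compatible
    log(dict, step=int) method works, which keeps this helper testable offline.
    """
    payload = {f"train/{name}": float(value) for name, value in losses.items()}
    payload["epoch"] = epoch
    run.log(payload, step=epoch)
    return payload


# Example usage (requires `wandb login`; names below are placeholders):
#   import wandb
#   run = wandb.init(project="disconet-sketch", config={"dataset": "CIFAR-10"})
#   for epoch in range(3):
#       log_epoch(run, epoch, {"loss_d": 0.7, "loss_g": 1.2})
#   run.finish()
```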

## Repository Information

### Documentation

Full documentation is available here: [`docs/`](docs).

### Dev

See the [Developer](docs/DEVELOPER.md) guidelines for more information.

### Contributing

Contributions of any kind are welcome. Please read [CONTRIBUTING.md](docs/CONTRIBUTING.md) for details and
the process for submitting pull requests to us.

**Please read [MODELRULES.md](docs/MODELRULES.md) for details on how you should build your models for this repository.**

## License

This project is licensed under the terms of the `CC-BY-4.0` license.
See [LICENSE](LICENSE) for more details.

## References

The repositories this code builds on are credited in the corresponding source files. They are listed here in no particular order:

- [PyTorch-VAE](https://github.com/AntixK/PyTorch-VAE)
- [conditional-GAN](https://github.com/TeeyoHuang/conditional-GAN)
- [PresGANs](https://github.com/adjidieng/PresGANs/)


## Citation

If you publish work that uses DisCoNet, please cite DisCoNet.

**BibTex information will be added later**
