# Sharpness-Aware Geometric Defense for Robust Out-Of-Distribution Detection

This codebase provides a PyTorch implementation of SaGD. It builds on the code released for CIDER: [How to Exploit Hyperspherical Embeddings for Out-of-Distribution Detection?](https://openreview.net/forum?id=aEFaE0W5pAd) (ICLR 2023).

### Abstract

Out-of-distribution (OOD) detection is a vital technique supporting safe and reliable model deployment. 
Contemporary OOD algorithms based on geometric projection can distinguish OOD or adversarial samples from clean in-distribution (ID) samples. However, this setting treats adversarial ID samples as OOD, leading to incorrect OOD predictions. Existing work on OOD detection under adversarial conditions is still very limited. In this paper, we develop a robust OOD detection method that distinguishes adversarial ID samples from OOD ones. The sharp loss landscape created by adversarial training hinders model convergence, degrading the quality of the latent embeddings used to compute OOD scores. Therefore, we introduce a **Sharpness-aware Geometric Defense (SaGD)** framework to smooth out the rugged adversarial loss landscape in the projected latent geometry. Better convergence in the geometric embedding enables accurate characterization of ID data, benefiting OOD detection under adversarial attacks. We use jitter-based perturbation in adversarial training to extend the defense to unseen attacks. Our SaGD framework significantly improves FPR and AUC over state-of-the-art defense approaches when differentiating CIFAR-100 from six OOD datasets under various attacks. We further examine the effects of perturbations at various adversarial training levels, revealing the relationship between the sharp loss landscape and adversarial OOD detection.

### Illustration

![fig1](readme_figs/framework.png)



## Quick Start


### Data Preparation

The default root directory for ID and OOD datasets is `datasets/`. We consider the following (in-distribution) datasets: CIFAR-10, CIFAR-100, and ImageNet-100. 

* **ID datasets:** CIFAR-10, CIFAR-100
* **OOD datasets:** TinyImageNet, Places365, LSUN-C (`LSUN`), LSUN-R (`LSUN_resize`), iSUN, and Textures (`dtd`).

OOD datasets can be downloaded via the following links (source: [ATOM](https://github.com/jfc43/informative-outlier-mining/blob/master/README.md)):

* [TinyImageNet](https://www.kaggle.com/c/tiny-imagenet): download it and place it in `datasets`.
* [Places365](http://data.csail.mit.edu/places/places365/test_256.tar): download it and place it in `datasets/ood_datasets/places365/test_subset`. We randomly sample 10,000 images from the original test set.
* [LSUN-C](https://www.dropbox.com/s/fhtsw1m3qxlwj6h/LSUN.tar.gz): download it and place it in `datasets/small_OOD_dataset/LSUN`.
* [LSUN-R](https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz): download it and place it in `datasets/small_OOD_dataset/LSUN_resize`.
* [iSUN](https://www.dropbox.com/s/ssz7qxfqae0cca5/iSUN.tar.gz): download it and place it in `datasets/small_OOD_dataset/iSUN`.
* [Textures](https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz): download it and place it in `datasets/small_OOD_dataset/dtd`.
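
The Places365 note mentions a random 10,000-image subset but not how it was drawn. The following is a minimal sketch for producing a comparable subset, assuming a fixed seed and a flat folder of image files such as the extracted `test_256` archive. `sample_subset` is a hypothetical helper, not part of this repository, and it will not reproduce the authors' exact subset.

```python
import random
import shutil
from pathlib import Path


def sample_subset(src_dir, dst_dir, n=10_000, seed=0, exts=(".jpg", ".jpeg", ".png")):
    """Copy a random subset of n images from src_dir into dst_dir.

    Hypothetical helper: the seed and extension filter are assumptions, since
    the README does not say how the Places365 subset was sampled.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    # Sort first so the same seed always yields the same subset.
    images = sorted(p for p in src.rglob("*") if p.suffix.lower() in exts)
    if n > len(images):
        raise ValueError(f"requested {n} images but only {len(images)} found")
    chosen = random.Random(seed).sample(images, n)
    dst.mkdir(parents=True, exist_ok=True)
    for p in chosen:
        # Assumes unique file names (true for the flat Places365 test folder).
        shutil.copy2(p, dst / p.name)
    return [p.name for p in chosen]
```

For example, `sample_subset("test_256", "datasets/ood_datasets/places365/test_subset")` fills the target folder named in the Places365 bullet.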


The directory structure looks like:

```
datasets/
---CIFAR10/
---CIFAR100/
---small_OOD_dataset/
------dtd/
------iSUN/
------LSUN/
------LSUN_resize/
------places365/
------tinyimagenet/
```
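
To catch misplaced folders before running the scripts, a small check like the following can help. It is a hypothetical helper, not part of this repository, and its directory list simply mirrors the tree above; if you placed TinyImageNet or Places365 according to the download notes, whose target folders differ from the tree, edit `EXPECTED_DIRS` accordingly.

```python
from pathlib import Path

# Sub-directories copied from the tree above (an assumption about what the loaders read).
EXPECTED_DIRS = [
    "CIFAR10",
    "CIFAR100",
    "small_OOD_dataset/dtd",
    "small_OOD_dataset/iSUN",
    "small_OOD_dataset/LSUN",
    "small_OOD_dataset/LSUN_resize",
    "small_OOD_dataset/places365",
    "small_OOD_dataset/tinyimagenet",
]


def missing_dirs(root="datasets"):
    """Return the expected sub-directories that do not exist under root."""
    root = Path(root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    missing = missing_dirs()
    if missing:
        print("Missing:", *missing, sep="\n  ")
    else:
        print("Dataset layout looks complete.")
```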


## Training and Evaluation 

### Model Checkpoints

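**Checkpoints**: the evaluation scripts take a checkpoint name (`ckpt_name`) and expect checkpoints in the following layout: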
```
checkpoints/
---CIFAR-10/
------ckpt_name/
---------checkpoint_500.pth.tar
---CIFAR-100/
------ckpt_name/
---------checkpoint_500.pth.tar
```

The following scripts can be used to evaluate the OOD detection performance:

```
bash scripts/eval_ckpt_cifar10.sh ckpt_name
bash scripts/eval_ckpt_cifar100.sh ckpt_name
```
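
The README does not state which OOD score or metrics these scripts compute. As a rough sketch only, assuming a CIDER-style k-nearest-neighbour score on L2-normalised (hyperspherical) embeddings and the common FPR-at-95%-TPR and AUROC metrics, the scoring step could look like this. `knn_score`, `fpr_at_95_tpr`, `auroc`, and the default `k=300` are illustrative assumptions, not this repository's code.

```python
import numpy as np


def knn_score(train_feats, test_feats, k=300):
    """Negative distance to the k-th nearest ID training embedding.

    Both sets are L2-normalised first, so every embedding lies on the unit
    hypersphere. Higher scores mean "more in-distribution".
    """
    if k > len(train_feats):
        raise ValueError("k must not exceed the number of training embeddings")
    train = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    test = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    # For unit vectors, ||a - b||^2 = 2 - 2 * a.b; clip tiny negative round-off.
    dists = np.sqrt(np.clip(2.0 - 2.0 * test @ train.T, 0.0, None))
    kth = np.sort(dists, axis=1)[:, k - 1]
    return -kth


def fpr_at_95_tpr(id_scores, ood_scores):
    """Fraction of OOD samples accepted at the threshold that keeps ~95% of ID samples."""
    threshold = np.percentile(id_scores, 5)
    return float(np.mean(np.asarray(ood_scores) >= threshold))


def auroc(id_scores, ood_scores):
    """AUROC with ID as the positive class (pairwise Mann-Whitney form, ties count 0.5)."""
    pos = np.asarray(id_scores, dtype=float)[:, None]
    neg = np.asarray(ood_scores, dtype=float)[None, :]
    return float((pos > neg).mean() + 0.5 * (pos == neg).mean())
```

The pairwise `auroc` uses memory quadratic in the number of samples; for full test sets, `sklearn.metrics.roc_auc_score` computes the same quantity more efficiently.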

Adversarial attacks for evaluation are selected through an argument in `eval_ckpt_cifar10.sh` and `eval_ckpt_cifar100.sh`: add `--attack "pgd"` to the command there. Supported values are `pgd`, `fgsm`, `cw`, `FAB`, and `jitter`. These attacks depend on [Torchattacks](https://github.com/Harry24k/adversarial-attacks-pytorch).
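
A hypothetical sketch of how an `--attack` flag could map onto Torchattacks classes. The class names `PGD`, `FGSM`, `CW`, `FAB`, and `Jitter` exist in Torchattacks, but `parse_args`, `build_attack`, and the `ATTACKS` mapping are illustrative assumptions rather than this repository's code, and the attack hyperparameters used in the paper are not given here.

```python
import argparse

# Hypothetical mapping from the README's --attack values to torchattacks class names.
ATTACKS = {"pgd": "PGD", "fgsm": "FGSM", "cw": "CW", "FAB": "FAB", "jitter": "Jitter"}


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Hypothetical attack selector")
    parser.add_argument("--attack", choices=sorted(ATTACKS), default=None,
                        help="adversarial attack applied to test inputs")
    return parser.parse_args(argv)


def build_attack(name, model, **kwargs):
    """Instantiate a torchattacks attack against a model that returns class logits."""
    import torchattacks  # imported lazily so argument parsing works without it

    return getattr(torchattacks, ATTACKS[name])(model, **kwargs)
```

For example, `build_attack("pgd", model, eps=8/255, alpha=2/255, steps=10)(images, labels)` returns adversarial images. Torchattacks expects inputs scaled to [0, 1] by default, and keyword arguments differ per attack: `CW`, for instance, takes `c`, `kappa`, `steps`, and `lr` rather than `eps`.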


**Training** 

```
bash scripts/train_cider_cifar10.sh
bash scripts/train_cider_cifar100.sh
```
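
The abstract describes SaGD as smoothing the sharp adversarial loss landscape. The sketch below shows only the generic sharpness-aware minimisation (SAM) update that this idea builds on, applied to a toy quadratic in NumPy; it is not the SaGD training step. In adversarial training, `grad_fn` would return the gradient of the loss on adversarial inputs. `sam_step`, `lr`, and `rho` are illustrative assumptions.

```python
import numpy as np


def sam_step(w, grad_fn, lr=0.05, rho=0.05, eps=1e-12):
    """One plain SAM update (Foret et al., 2021) on parameter vector w."""
    g = grad_fn(w)
    # First-order worst-case weight perturbation inside an L2 ball of radius rho.
    e = rho * g / (np.linalg.norm(g) + eps)
    # Gradient at the perturbed point, applied to the original weights.
    return w - lr * grad_fn(w + e)


# Toy ill-conditioned quadratic L(w) = 0.5 * w^T A w standing in for the training loss.
A = np.diag([1.0, 10.0])


def loss(w):
    return 0.5 * w @ A @ w


def grad(w):
    return A @ w


w0 = np.array([1.0, 1.0])
w = w0.copy()
for _ in range(200):
    w = sam_step(w, grad)
print(f"loss: {loss(w0):.3f} -> {loss(w):.5f}")
```

Because the gradient is taken at the perturbed point `w + e`, the update penalises directions where the loss rises quickly, which is what biases SAM toward flatter minima.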




### Citation

If you find our work useful, please consider citing our paper:

```

```

### Acknowledgements
We acknowledge the code released by [CIDER](https://github.com/deeplearning-wisc/cider), [Hyperbolic Image Embedding](https://github.com/leymir/hyperbolic-image-embeddings), and [RSAM](https://arxiv.org/pdf/2309.17215.pdf).
