# Adversarially Robust Anomaly Detection through Spurious Negative Pair Mitigation

PyTorch implementation of
["**Adversarially Robust Anomaly Detection through Spurious Negative Pair Mitigation**"]() (submitted to ICLR 2025).


<p align="center">
    <img src=figures/cobra-pipline.jpg width="900"> 
</p>

## 1. Requirements
### Environments
The code currently requires the following packages:
- python 3.6+
- torch 1.4+
- torchvision 0.5+
- CUDA 10.1+
- scikit-learn 0.22+
- tensorboard 2.0+
- [torchlars](https://github.com/kakaobrain/torchlars) == 0.1.2 
- [pytorch-gradual-warmup-lr](https://github.com/ildoonet/pytorch-gradual-warmup-lr)
- [diffdist](https://github.com/ag14774/diffdist) == 0.1 

### Datasets 
CIFAR-10, CIFAR-100, ImageNet-30, Fashion-MNIST, SVHN, MNIST, Cityscapes

For ImageNet-30, please download the following files to `~/data`:
* [ImageNet-30-train](https://drive.google.com/file/d/1B5c39Fc3haOPzlehzmpTLz6xLtGyKEy4/view)
* [ImageNet-30-test](https://drive.google.com/file/d/13xzVuQMEhSnBRZr-YaaO08coLU2dxAUq/view)


## 2. Training

### One-Class Anomaly Detection
To train the one-class anomaly detection models from the paper, run this command:

```train
python train.py --dataset <DATASET> --model <NETWORK> --epochs 100 --shift_trans_type rotation --batch_size 128 --one_class_idx <One-Class-Index> --epsilon 0.0156
```

> Options
> * one_class_idx: index of the class used as the in-distribution (normal) data for one-class training.
> * epsilon: perturbation budget ε of the PGD attack used in training (0.0156 ≈ 4/255).

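
The `epsilon` value in the command, 0.0156, is 4/255 ≈ 0.01569 truncated, i.e. the ε = 4/255 budget used in the results below. For readers unfamiliar with PGD, here is a minimal NumPy sketch of an L∞ PGD attack with that budget. It is not this repository's implementation: the L∞ norm, the toy linear anomaly score, the function name `pgd_linf`, and the step size and step count are all illustrative assumptions.

```python
import numpy as np

def pgd_linf(x, grad_fn, eps=4 / 255, alpha=1 / 255, steps=10, rng=None):
    """Hypothetical L-infinity PGD that ascends an objective with gradient grad_fn.

    x:       clean input with values in [0, 1]
    grad_fn: callable returning d(objective)/d(input) at a given input
    """
    rng = np.random.default_rng(0) if rng is None else rng
    # Random start inside the eps-ball, clipped to the valid pixel range.
    x_adv = np.clip(x + rng.uniform(-eps, eps, size=x.shape), 0.0, 1.0)
    for _ in range(steps):
        # Signed-gradient ascent step: every coordinate moves by alpha.
        x_adv = x_adv + alpha * np.sign(grad_fn(x_adv))
        # Project back onto the eps-ball around x, then onto [0, 1].
        x_adv = np.clip(x_adv, x - eps, x + eps)
        x_adv = np.clip(x_adv, 0.0, 1.0)
    return x_adv

# Toy (hypothetical) anomaly score s(x) = w . x, whose gradient is simply w.
w = np.array([1.0, -2.0, 0.5])
x = np.array([0.5, 0.5, 0.999])
x_adv = pgd_linf(x, lambda z: w)
print(np.abs(x_adv - x).max() <= 4 / 255 + 1e-12)  # True
```

In a real PyTorch setup, `grad_fn` would be replaced by the autograd gradient of the training loss with respect to the input; the two projections are what keep the perturbation inside the ε-ball.
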
### Unlabeled Multi-Class Anomaly Detection
To train the unlabeled multi-class models from the paper, run this command:

```train
python train.py --dataset <Normal DATASET> --ood_dataset <Anomaly Dataset> --model <NETWORK> --epochs 100 --shift_trans_type rotation --batch_size 128 --epsilon <PGD config of NAT & PAT value>
```
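
Both training commands pass `--shift_trans_type rotation`. As an illustration only, the NumPy sketch below shows a rotation-based shifting transformation under a common convention, which is an assumption here rather than something this README states: each image is rotated by 0°, 90°, 180°, and 270°, and the rotation index k is kept as a label. The function `rotation_shift` is hypothetical and not part of this repository.

```python
import numpy as np

def rotation_shift(images):
    """Hypothetical helper: 4 rotated copies of a batch plus rotation labels.

    images: array of shape (N, C, H, W) with H == W.
    Returns (rotated, labels) with shapes (4N, C, H, W) and (4N,).
    """
    # Rotate in the spatial (H, W) plane by k * 90 degrees, for k = 0..3.
    rotated = np.concatenate([np.rot90(images, k=k, axes=(2, 3)) for k in range(4)])
    # Label k marks the copies rotated by k * 90 degrees.
    labels = np.repeat(np.arange(4), len(images))
    return rotated, labels

batch = np.arange(2 * 1 * 2 * 2, dtype=float).reshape(2, 1, 2, 2)
rotated, labels = rotation_shift(batch)
print(rotated.shape, labels.tolist())  # (8, 1, 2, 2) [0, 0, 1, 1, 2, 2, 3, 3]
```
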


## 3. Evaluation

We provide pre-trained COBRA checkpoints, which can be downloaded from the following link:
- [COBRA Checkpoints]()


### One-Class Anomaly Detection
To evaluate our model in the one-class anomaly detection setting, run this command:

```eval
python eval.py --out_attack --in_attack --dataset <DATASET> --model <NETWORK> --ood_score COBRA --shift_trans_type rotation --print_score --ood_samples 10 --resize_factor 0.54 --resize_fix --one_class_idx <One-Class-Index> --eps <PGD Attack Config> --load_path <MODEL_PATH> --batch_size 64 --test_batch_size 64
```

> Options
> - out_attack: apply a PGD-100 attack to the anomalous test samples only.
> - in_attack: apply a PGD-100 attack to the normal test samples only. Passing both flags, as above, attacks both.
> - one_class_idx: index of the class treated as the in-distribution (normal) class during one-class evaluation.
> - resize_factor, resize_fix: together, these fix the crop size used by `RandomResizedCrop()`.
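
A minimal sketch of how `--resize_factor` and `--resize_fix` could map to a crop-scale range, assuming (not confirmed by this README) that `resize_factor` is the lower bound of the kept area fraction and `resize_fix` pins both bounds to that value. The helper `crop_scale` is hypothetical; the tuple it returns has the `(min, max)` form accepted by the `scale` argument of `torchvision.transforms.RandomResizedCrop`.

```python
def crop_scale(resize_factor=0.54, resize_fix=False):
    """Hypothetical helper: (min, max) fraction of image area kept by the crop."""
    if not 0.0 < resize_factor <= 1.0:
        raise ValueError("resize_factor must be in (0, 1]")
    if resize_fix:
        # Fixed scale: every crop covers the same fraction of the image area.
        return (resize_factor, resize_factor)
    # Otherwise the area fraction is sampled from [resize_factor, 1.0].
    return (resize_factor, 1.0)

print(crop_scale(0.54, resize_fix=True))   # (0.54, 0.54)
print(crop_scale(0.54, resize_fix=False))  # (0.54, 1.0)

# With torchvision installed, the tuple could be passed as, e.g.:
# torchvision.transforms.RandomResizedCrop(32, scale=crop_scale(0.54, True))
```

With a fixed scale, only the crop area is deterministic; with torchvision's defaults, the crop position and aspect ratio are still sampled randomly.
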

### Unlabeled Multi-Class Anomaly Detection
To evaluate our model in the unlabeled multi-class anomaly detection setting, run this command:

```eval
python eval.py --out_attack --in_attack --ood_dataset <Anomaly Dataset> --dataset <DATASET> --model <NETWORK> --ood_score COBRA --shift_trans_type rotation --print_score --ood_samples 10 --resize_factor 0.54 --resize_fix --eps <PGD Attack Config> --load_path <MODEL_PATH> --batch_size 64 --test_batch_size 64
```

> Options
> - out_attack: apply a PGD-100 attack to the anomalous test samples only.
> - in_attack: apply a PGD-100 attack to the normal test samples only. Passing both flags, as above, attacks both.
> - ood_dataset: dataset whose samples are treated as anomalies.
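
The two attack flags select which test samples are perturbed. Under the usual convention for evaluating adversarially robust anomaly detectors (an assumption here, not stated in this README), an attack on a normal sample tries to raise its anomaly score, and an attack on an anomaly tries to lower it. The sketch below, built around the hypothetical helper `attack_signs`, returns the per-sample direction of the gradient step on the anomaly score: +1, -1, or 0 for samples left clean.

```python
def attack_signs(labels, in_attack, out_attack):
    """Hypothetical helper: per-sample sign for gradient steps on the anomaly score.

    labels: iterable of 0 (normal) / 1 (anomaly).
    Returns +1 to increase the score, -1 to decrease it, 0 to leave the sample clean.
    """
    signs = []
    for y in labels:
        if y == 0 and in_attack:
            signs.append(1)    # push normal samples toward "anomalous"
        elif y == 1 and out_attack:
            signs.append(-1)   # push anomalies toward "normal"
        else:
            signs.append(0)    # evaluated without perturbation
    return signs

labels = [0, 0, 1, 1]
print(attack_signs(labels, in_attack=True, out_attack=True))   # [1, 1, -1, -1]
print(attack_signs(labels, in_attack=False, out_attack=True))  # [0, 0, -1, -1]
```
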

## 4. Results

Our model achieves the following performance.

### One-Class Anomaly Detection

| Dataset     | Clean AUROC | PGD-1000 AUROC |
| ----------- | ----------- | -------------- |
| MVTec AD    | 89.1%       | 75.1%          |
| CIFAR-10    | 83.7%       | 62.3%          |
| CIFAR-100   | 76.9%       | 51.7%          |
| ImageNet-30 | 85.2%       | 57%            |

This repo reports only the results above, under a PGD attack with ε = 4/255. For other settings, please see our paper.
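
AUROC is the probability that a randomly chosen anomaly receives a higher anomaly score than a randomly chosen normal sample, with ties counted as one half. The pure-Python sketch below computes it that way and builds one-class labels from `one_class_idx` (that class is normal, every other class is an anomaly). The helper names are illustrative, not this repository's API; `sklearn.metrics.roc_auc_score(y_true, y_score)` from the listed scikit-learn dependency computes the same quantity.

```python
def one_class_labels(class_ids, one_class_idx):
    """Binary labels for one-class evaluation: 0 = normal class, 1 = anomaly."""
    return [0 if c == one_class_idx else 1 for c in class_ids]

def auroc(labels, scores):
    """AUROC as P(score(anomaly) > score(normal)); ties count 0.5. O(n0 * n1)."""
    normal = [s for y, s in zip(labels, scores) if y == 0]
    anomalous = [s for y, s in zip(labels, scores) if y == 1]
    if not normal or not anomalous:
        raise ValueError("need at least one normal and one anomalous sample")
    wins = 0.0
    for a in anomalous:
        for n in normal:
            wins += 1.0 if a > n else 0.5 if a == n else 0.0
    return wins / (len(normal) * len(anomalous))

# Toy example: class 3 is the in-distribution class; higher score = more anomalous.
class_ids = [3, 3, 1, 7, 5]
scores = [0.1, 0.4, 0.35, 0.8, 0.9]
labels = one_class_labels(class_ids, one_class_idx=3)
print(labels, auroc(labels, scores))  # [0, 0, 1, 1, 1] 0.8333333333333334
```

The "Clean" and "PGD-1000" columns are both AUROC values; they differ only in whether the scores come from unperturbed or attacked test inputs.
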

### Unlabeled Multi-Class Anomaly Detection


| Dataset   | OOD Dataset | Clean AUROC | PGD-1000 AUROC |
| --------- | ----------- | ----------- | -------------- |
| CIFAR-10  | CIFAR-100   | 76%         | 63.3%          |
| CIFAR-10  | ImageNet-30 | 85.1%       | 53.1%          |
| CIFAR-100 | CIFAR-10    | 48.7%       | 27.5%          |
| CIFAR-100 | ImageNet-30 | 69.1%       | 32.3%          |

Only a subset of the results is shown in this repo. For results on other OOD datasets, please see our paper.

## Citation

