
# Implementation of 'Learning Sample Difficulty from Pre-trained Models for Reliable Prediction'

Thank you for taking the time to review our code and datasets. This README describes how to run our proposed method. Note: the code is written for GPUs, so we recommend running it on a GPU.

## Environment Settings

- numpy==1.20.1
- pytorch_pretrained_vit==0.0.7
- scikit_learn==1.2.1
- scipy==1.6.2
- timm==0.5.4
- torch==1.9.1
- torchvision==0.10.1
- tqdm==4.59.0
- transformers==4.18.0
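
To check an existing environment against the pinned versions above, a small script like the following may help. It is a sketch that is not part of this repository. It assumes Python 3.8+ (for `importlib.metadata`) and looks packages up by the distribution names pip reports (for example `scikit-learn` rather than `scikit_learn`).

```python
from importlib import metadata

# Versions pinned in the list above, keyed by the distribution names pip uses.
PINNED = {
    "numpy": "1.20.1",
    "pytorch-pretrained-vit": "0.0.7",
    "scikit-learn": "1.2.1",
    "scipy": "1.6.2",
    "timm": "0.5.4",
    "torch": "1.9.1",
    "torchvision": "0.10.1",
    "tqdm": "4.59.0",
    "transformers": "4.18.0",
}


def find_mismatches(pinned):
    """Return {name: (expected, installed)} for packages that are missing or differ.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # CUDA builds of torch report versions like "1.9.1+cu111"; compare the base part only.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```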

## How to run

After downloading the repository, you can train a model with sample difficulty-aware entropy regularization in two steps, using the example commands below.

### First, generate the Gaussian mean and covariance matrix used to compute the sample difficulty score:

* For ImageNet-1k

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python generate_mahalanobis.py --dataset imagenet -a clip32 --maha_file ./ssl/maha_dict_clip32_imagenet_512.npy --batch_size 4096 --num_classes 1000
```


* For CIFAR-100

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python generate_mahalanobis.py --dataset cifar100 -a clip32 --maha_file ./ssl/maha_dict_clip32_cifar100_512.npy --batch_size 4096 --num_classes 100
```
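
For intuition, the sketch below shows one common way to turn frozen pre-trained features into a Mahalanobis-distance difficulty score: fit one Gaussian mean per class with a shared covariance matrix, then measure how far each sample lies from its own class mean. It is written in NumPy for self-containment and is not the code in `generate_mahalanobis.py`. The function names, the dictionary keys, the tied-covariance choice, the `eps` ridge term and the optional relative variant (subtracting the distance under a single class-agnostic Gaussian) are our assumptions, as is any guess about what the saved `--maha_file` actually contains.

```python
import numpy as np


def _sq_mahalanobis(diff, cov, eps=1e-6):
    """Squared Mahalanobis norm of each row of `diff` under covariance `cov`."""
    # A small ridge keeps the covariance invertible when D is large relative to N.
    precision = np.linalg.inv(cov + eps * np.eye(cov.shape[0]))
    return np.einsum("nd,de,ne->n", diff, precision, diff)


def fit_gaussians(features, labels, num_classes):
    """Fit per-class means, a shared covariance, and one class-agnostic Gaussian.

    features: (N, D) embeddings from a frozen pre-trained encoder (e.g. CLIP).
    labels:   (N,) integer labels in [0, num_classes); every class must appear.
    Returns a dict with HYPOTHETICAL keys chosen for this sketch.
    """
    means = np.stack([features[labels == c].mean(axis=0) for c in range(num_classes)])
    centered = features - means[labels]  # subtract each sample's own class mean
    return {
        "means": means,
        "cov": centered.T @ centered / len(features),  # tied (shared) covariance
        "global_mean": features.mean(axis=0),
        "global_cov": np.cov(features, rowvar=False, bias=True),
    }


def difficulty_scores(features, labels, g, relative=True, eps=1e-6):
    """Larger score = farther from the sample's class mean, i.e. "harder".

    With relative=True, the distance under the class-agnostic Gaussian is
    subtracted from the class-conditional distance.
    """
    md = _sq_mahalanobis(features - g["means"][labels], g["cov"], eps)
    if relative:
        md = md - _sq_mahalanobis(features - g["global_mean"], g["global_cov"], eps)
    return md
```

A dictionary like `g` can be stored with `np.save(path, g)` and read back with `np.load(path, allow_pickle=True).item()`; whether the repository's `.npy` files follow this layout is not stated here.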

### Then, train the model with sample difficulty-aware entropy regularization:

* For ImageNet-1k

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --master_port 7668 --nproc_per_node=8 train_imagenet_sample_weight.py --batch-size 1024 --lr 0.4 --num_classes 1000 --epochs 400 --loss_type aer_e02_07t_clip32 -a clip32 --maha_file ./ssl/maha_dict_clip32_imagenet_512.npy --e_lambda 0.2 --warmup -1 --method aer --left 1. --right 1. --T 0.7 --net resnet34 > imagenet_aer.log &
```


* For CIFAR-100

```bash
CUDA_VISIBLE_DEVICES=0 python train_cifar_sample_weight.py --batch_size 256 --save_model --num_epochs 200 --milestones 60 120 150 --lr 0.10 --gamma 0.2 --ensemble_num 1 --dataset cifar100 -a clip32 --maha_file ./ssl/maha_dict_clip32_cifar100_512.npy --e_lambda 0.3 --loss_type aer_e03_07t_clip32 --warmup -1 --method aer --left 1. --right 1. --T 0.7 > cifar100_aer.log &
```
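
The sketch below illustrates the general idea of difficulty-aware entropy regularization: standard cross-entropy plus an entropy term weighted per sample by its difficulty, so that harder samples are pushed toward less confident predictions. It is a NumPy illustration under stated assumptions, not the loss implemented in `train_imagenet_sample_weight.py` or `train_cifar_sample_weight.py`. In particular, the score-to-weight mapping (`difficulty_weights`, a batch softmax with temperature `T`) is hypothetical, and reading `--e_lambda` as the weight λ and `--T` as that temperature is our interpretation of the flag names, not something this README confirms.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable row-wise log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def difficulty_weights(scores, T=0.7):
    """HYPOTHETICAL mapping from raw difficulty scores to per-sample weights.

    Scores are standardised within the batch, passed through a softmax with
    temperature T, and rescaled so the weights average to 1. A lower T puts
    more weight on the hardest samples in the batch.
    """
    z = (scores - scores.mean()) / (scores.std() + 1e-8) / T
    w = np.exp(z - z.max())
    return len(scores) * w / w.sum()


def difficulty_aware_entropy_loss(logits, labels, scores, e_lambda=0.2, T=0.7):
    """Batch mean of cross-entropy minus a difficulty-weighted entropy term.

    Minimising -e_lambda * w_i * H(p_i) increases the predictive entropy H(p_i),
    so samples with larger weights (harder samples) are pushed towards less
    confident predictions, while easy samples stay closer to plain CE.
    """
    log_p = log_softmax(logits)
    p = np.exp(log_p)
    ce = -log_p[np.arange(len(labels)), labels]
    entropy = -(p * log_p).sum(axis=1)
    w = difficulty_weights(scores, T)
    return float(np.mean(ce - e_lambda * w * entropy))
```

With `e_lambda=0.0` the loss reduces to plain cross-entropy, which makes the effect of the regularizer easy to isolate. In an actual training loop the same arithmetic would be written with PyTorch tensors so that it can be backpropagated.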
