# The official implementation of "Gradient Rectification for Robust Calibration under Distribution Shift"


*Thank you for your careful review of our work.*

### Training

Baseline models are trained with `train.py`. To train a ResNet-50 with cross-entropy loss on CIFAR-10, run:
```
CUDA_VISIBLE_DEVICES=0 python train.py \
--dataset cifar10 \
--model resnet50 \
--loss cross_entropy 
```
The trained model weights will be saved in the "./checkpoints/cifar10" directory.


Our method, FGR, is trained with `train_gard.py`. To fine-tune the CIFAR-10 ResNet-50 trained with cross-entropy above using our method, run:
```
CUDA_VISIBLE_DEVICES=0 python train_gard.py \
--dataset cifar10 \
--model resnet50 --first-milestone 40 --second-milestone 70 -e 100 \
--loss dual_focal_loss --gamma 5.0  --lr 0.01 \
--load --saved_model_name resnet50_cross_entropy_350.model  --freeze \
--use-corruption --corruption-prob 0.05 \
--use-grad-correction  
```
The trained model weights will be saved in the "./checkpoints/cifar10_gard" directory.
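The fine-tuning command above selects `--loss dual_focal_loss --gamma 5.0`. As a rough illustration of what such a loss computes, here is a minimal pure-Python sketch, assuming it follows the dual focal loss of Tao et al. (2023): each example's cross-entropy is scaled by (1 - p_y + p_j)^γ, where p_y is the true-class probability and p_j is the largest non-target probability. The helper names are hypothetical, the implementation in `train_gard.py` may differ (for example in reduction or numerical details), and the sketch does not cover what `--use-grad-correction` does.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one example's logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dual_focal_loss(logits_batch: Sequence[Sequence[float]],
                    targets: Sequence[int],
                    gamma: float = 5.0) -> float:
    """Mean dual focal loss over a batch (illustrative sketch, not the repo's code).

    Per example: -(1 - p_y + p_j) ** gamma * log(p_y), where p_y is the
    true-class probability and p_j is the largest probability among the
    remaining classes.
    """
    total = 0.0
    for logits, y in zip(logits_batch, targets):
        probs = softmax(logits)
        p_true = probs[y]
        # Probability of the highest-scoring wrong class.
        p_other = max(p for k, p in enumerate(probs) if k != y)
        weight = (1.0 - p_true + p_other) ** gamma
        total += -weight * math.log(p_true)
    return total / len(targets)


# A confident, correct prediction gets a much smaller loss than under
# plain cross-entropy; gamma=0 recovers cross-entropy exactly.
logits = [[3.0, 0.0, 0.0]]
print(dual_focal_loss(logits, [0], gamma=5.0))
print(dual_focal_loss(logits, [0], gamma=0.0))
```

With `gamma=0` the weight is 1 and the loss reduces to ordinary cross-entropy, which makes a quick sanity check for any implementation.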


### Evaluation

To evaluate the cross-entropy ResNet-50 above, both as-is and with temperature scaling, on the in-distribution and distribution-shifted test sets, run:
```
CUDA_VISIBLE_DEVICES=0 python evaluate.py \
--dataset cifar10 \
--model resnet50 \
--save-path ./checkpoints/ \
--saved_model_name resnet50_cross_entropy_350.model  --remark ce
```
For our model, use the following command to evaluate:
```
CUDA_VISIBLE_DEVICES=0 python evaluate.py \
--dataset cifar10 \
--model resnet50 \
--saved_model_name resnet50_dual_focal_loss_gamma_5.0_corrupted_0.05_450.model \
--gard  --remark LP_gard_corrupt
```
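`evaluate.py` reports results both as-is and with temperature scaling. The sketch below illustrates that idea under stated assumptions: a single temperature T is chosen to minimise negative log-likelihood on held-out logits, and calibration is measured with a 15-bin expected calibration error (ECE). Temperature scaling is commonly fitted with a gradient-based optimiser (Guo et al., 2017); a grid search is swapped in here to keep the example dependency-free. All helper names are hypothetical, and whether `evaluate.py` uses these exact choices (bin count, fitting procedure) is not stated in this README.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits divided by a temperature T > 0."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nll(logits_batch, labels, temperature: float) -> float:
    """Mean negative log-likelihood at a given temperature."""
    return -sum(math.log(softmax(z, temperature)[y])
                for z, y in zip(logits_batch, labels)) / len(labels)


def fit_temperature(val_logits, val_labels, grid=None) -> float:
    """Pick T minimising validation NLL by grid search (hypothetical helper)."""
    if grid is None:
        grid = [0.05 * k for k in range(1, 101)]  # T in (0, 5]
    return min(grid, key=lambda t: nll(val_logits, val_labels, t))


def expected_calibration_error(logits_batch, labels,
                               temperature: float = 1.0,
                               n_bins: int = 15) -> float:
    """ECE with equal-width confidence bins."""
    # Per bin: [sum of confidences, number of correct predictions].
    bins = [[0.0, 0.0] for _ in range(n_bins)]
    for z, y in zip(logits_batch, labels):
        probs = softmax(z, temperature)
        conf = max(probs)
        pred = probs.index(conf)
        b = min(int(conf * n_bins), n_bins - 1)
        bins[b][0] += conf
        bins[b][1] += float(pred == y)
    # sum_b |B_b|/n * |acc_b - conf_b| == sum_b |correct_b - conf_sum_b| / n
    return sum(abs(correct - conf_sum) for conf_sum, correct in bins) / len(labels)


# Toy data: confident logits that are right only half the time.
# (In practice, fit T on a validation split and apply it to the test set.)
val_logits = [[4.0, 0.0], [4.0, 0.0]]
val_labels = [0, 1]
t = fit_temperature(val_logits, val_labels)
print(t,
      expected_calibration_error(val_logits, val_labels),
      expected_calibration_error(val_logits, val_labels, temperature=t))
```

Dividing logits by T > 1 softens overconfident predictions without changing the argmax, so accuracy is unchanged while ECE can drop.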

* More training and evaluation commands are in the `./train_scripts` folder.


This project's template comes from the official implementation of "Calibrating Deep Neural Networks using Focal Loss": https://github.com/torrvision/focal_calibration.