This repository contains the code for the paper "Distribution Shift-Aware Prediction Refinement for Test-Time Adaptation" (DART).
It is built on the publicly released SAR code "https://github.com/mr-eggplant/SAR". For anything not covered here, follow the instructions in that repository.

##### Dependencies #####
Python       3.8.13
PyTorch      2.0.1
Torchvision  0.15.2
CUDA         11.5
NumPy        1.24.4
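
To compare the current environment with the versions listed above, a minimal Python sketch such as the following can be used. It assumes the PyPI distribution names `torch`, `torchvision`, and `numpy`; `check_versions` is a hypothetical helper, not part of this repository. CUDA is not checked here, because its version is reported by PyTorch itself (`torch.version.cuda`) rather than by package metadata.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Versions listed in the Dependencies section of this README.
EXPECTED = {
    "torch": "2.0.1",
    "torchvision": "0.15.2",
    "numpy": "1.24.4",
}
EXPECTED_PYTHON = "3.8.13"


def check_versions(expected=EXPECTED):
    """Return {package: (expected, installed_or_None)} for every mismatch.

    Hypothetical helper, not part of the DART repository. Local build
    suffixes such as "+cu117" are ignored when comparing.
    """
    mismatches = {}
    for name, want in expected.items():
        try:
            have = version(name)
        except PackageNotFoundError:
            have = None
        if have is None or have.split("+")[0] != want:
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    if platform.python_version() != EXPECTED_PYTHON:
        print(f"python: expected {EXPECTED_PYTHON}, found {platform.python_version()}")
    for name, (want, have) in check_versions().items():
        print(f"{name}: expected {want}, found {have or 'not installed'}")
```

Nothing is printed when every listed version matches; a newer patch release may still work, so a mismatch is a hint rather than an error.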

##### Data #####
Download the ImageNet training set from "https://www.image-net.org/download.php" and ImageNet-C from "https://github.com/hendrycks/robustness".
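
ImageNet-C is typically organized as one folder per corruption type, each holding severity levels 1 to 5 with the usual ImageNet class subfolders (e.g. `fog/3/n01440764/`). Assuming `--data_corruption` should point at the directory that contains these corruption folders (as in the SAR code base; please verify against `main.py`), the sketch below lists which of the 15 standard corruption/severity folders are missing. `missing_imagenet_c_dirs` is a hypothetical helper name.

```python
import os
import sys

# The 15 corruption types of the standard ImageNet-C benchmark.
CORRUPTIONS = [
    "gaussian_noise", "shot_noise", "impulse_noise",
    "defocus_blur", "glass_blur", "motion_blur", "zoom_blur",
    "snow", "frost", "fog", "brightness",
    "contrast", "elastic_transform", "pixelate", "jpeg_compression",
]
SEVERITIES = range(1, 6)


def missing_imagenet_c_dirs(root, corruptions=CORRUPTIONS, severities=SEVERITIES):
    """Return the <root>/<corruption>/<severity> folders that do not exist.

    Assumes the layout root/<corruption>/<severity>/<wnid>/*.JPEG.
    """
    missing = []
    for corruption in corruptions:
        for level in severities:
            path = os.path.join(root, corruption, str(level))
            if not os.path.isdir(path):
                missing.append(path)
    return missing


if __name__ == "__main__":
    root = sys.argv[1] if len(sys.argv) > 1 else "."
    missing = missing_imagenet_c_dirs(root)
    total = len(CORRUPTIONS) * len(SEVERITIES)
    print(f"{len(missing)} of {total} corruption/severity folders missing")
```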

##### Intermediate-time training #####
Train the distribution shift-aware module g_phi with a command such as the following:
python3 int_time.py --data <path to the ImageNet training set> --data_corruption <path to ImageNet-C>

The trained g_phi and the training logs are then saved to "./eval_results/trained_gphi/cifar10/trained_gphi.pt" and "./eval_results/trlog/cifar10/trlog", respectively.

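
To script this step, a minimal sketch such as the following assembles the `int_time.py` command and, after training, checks that both output paths exist. It uses only the `--data` and `--data_corruption` flags and the output paths given in this README, and assumes it is run from the repository root, since those paths are relative. `int_time_command` and `missing_outputs` are hypothetical helper names; whether `trlog` is a file or a directory is not stated here, so the check only tests for existence.

```python
import os
import subprocess
import sys

# Output locations written by int_time.py, as documented in this README.
GPHI_CKPT = "./eval_results/trained_gphi/cifar10/trained_gphi.pt"
TRAIN_LOG = "./eval_results/trlog/cifar10/trlog"


def int_time_command(imagenet_train, imagenet_c):
    """Build the argument list for the intermediate-time training run."""
    return [sys.executable, "int_time.py",
            "--data", imagenet_train,
            "--data_corruption", imagenet_c]


def missing_outputs(paths=(GPHI_CKPT, TRAIN_LOG)):
    """Return the expected output paths that do not exist yet."""
    return [p for p in paths if not os.path.exists(p)]


if __name__ == "__main__":
    # Usage: python run_int_time.py <ImageNet train dir> <ImageNet-C dir>
    if len(sys.argv) == 3:
        subprocess.run(int_time_command(sys.argv[1], sys.argv[2]), check=True)
        print("missing outputs:", missing_outputs() or "none")
    else:
        print("usage: run_int_time.py <ImageNet train dir> <ImageNet-C dir>")
```

Passing an argument list to `subprocess.run` (rather than a shell string) avoids quoting problems when dataset paths contain spaces.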
##### Test-time training #####
At test time, the pre-trained classifier is adapted using the trained g_phi. Select the variant with --method (bnadapt_dart or tent_dart), e.g.:
python3 main.py --method=<bnadapt_dart/tent_dart> --data_corruption <path to ImageNet-C>

The test logs are then saved to "./eval_results/tta/tslog".
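
To run both DART variants back to back, a minimal sketch such as the following can be used. It passes only the `--method` values and the `--data_corruption` flag listed in this README; every other setting is left to the defaults of `main.py`, which this README does not document. `main_command` and `run_all` are hypothetical helper names, and the script assumes it is run from the repository root.

```python
import os
import subprocess
import sys

METHODS = ("bnadapt_dart", "tent_dart")  # --method values listed in this README
TEST_LOG_DIR = "./eval_results/tta/tslog"


def main_command(method, imagenet_c):
    """Build the argument list for one test-time adaptation run."""
    if method not in METHODS:
        raise ValueError(f"unknown method {method!r}; expected one of {METHODS}")
    return [sys.executable, "main.py", f"--method={method}",
            "--data_corruption", imagenet_c]


def run_all(imagenet_c, dry_run=False):
    """Run main.py once per DART variant and return the commands used."""
    cmds = [main_command(m, imagenet_c) for m in METHODS]
    for cmd in cmds:
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run
    return cmds


if __name__ == "__main__":
    # Usage: python run_tta.py <ImageNet-C dir>
    if len(sys.argv) == 2:
        run_all(sys.argv[1])
        print("test logs are expected under", os.path.abspath(TEST_LOG_DIR))
    else:
        print("usage: run_tta.py <ImageNet-C dir>")
```

With `dry_run=True` the commands are only built and returned, which is useful for checking paths before launching a long evaluation.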