This code accompanies the paper "Distribution Shift-Aware Prediction Refinement for Test-Time Adaptation (DART)".
It builds on the publicly released code at "https://github.com/DequanWang/tent" and "https://github.com/locuslab/tta_conjugate".

##### Dependencies #####
Python       3.8.13
PyTorch      2.0.1
Torchvision  0.15.2
CUDA         11.5
NumPy        1.24.4
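For convenience, the Python package versions above can be pinned in a requirements file (a sketch based only on the versions listed here; CUDA 11.5 must be installed separately, and any further dependencies of the repo are not listed):

```text
torch==2.0.1
torchvision==0.15.2
numpy==1.24.4
```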

##### Data #####
[Training datasets] You can download the clean CIFAR-10/100 training sets via RobustBench, and the ImageNet training set from "https://www.image-net.org/download.php".
[Test datasets] You can download CIFAR-10/100-C and ImageNet-C from "https://github.com/hendrycks/robustness". You can download the PACS dataset as described in "https://github.com/matsuolab/T3A".

##### Pre-training #####
We use publicly released pre-trained models and code for a fair comparison.
Specifically, for CIFAR-10/100 and digit classification (https://github.com/locuslab/tta_conjugate), we train for 200 epochs with batch size 200, using the SGD optimizer with learning rate 0.1, momentum 0.9, and weight decay 0.0005.
For PACS, we use the released pre-trained models of TTAB (https://github.com/LINs-lab/ttab).
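The pre-training hyperparameters above can be written down as a PyTorch optimizer setup (a minimal sketch; the `Linear` model is a toy stand-in for the actual classifier, and only the hyperparameter values come from the text):

```python
import torch

# Toy stand-in for the actual classifier architecture.
model = torch.nn.Linear(32, 10)

# Optimizer hyperparameters as stated in the README.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,             # learning rate
    momentum=0.9,
    weight_decay=5e-4,  # 0.0005
)

EPOCHS = 200      # number of training epochs
BATCH_SIZE = 200  # mini-batch size
```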

##### Intermediate-time training #####
You can train the distribution shift-aware module g_\phi with the following command:
e.g.) python int_time_cifar10.py

The trained g_\phi and the training logs are then recorded in "./eval_results/trained_gphi/cifar10/trained_gphi.pt" and "./eval_results/trlog/cifar10/trlog", respectively.

##### Test-time adaptation #####
At test time, we adapt the trained classifiers with the trained g_\phi.
e.g.)
for CIFAR-10C-LT with \rho=1:   python tta_cifar10.py --rho=1
for CIFAR-10C-LT with \rho=10:  python tta_cifar10.py --rho=0.1
for CIFAR-10C-LT with \rho=100: python tta_cifar10.py --rho=0.01
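The long-tailed (LT) test streams above are indexed by the imbalance ratio \rho. For reference, a common way to build such an exponential imbalance profile over classes is sketched below (a generic construction, not necessarily this repo's exact recipe; the formula n_k = n_max * \rho^{-k/(K-1)} is an assumption):

```python
def long_tailed_counts(n_max, num_classes, rho):
    """Per-class sample counts for an exponentially long-tailed split.

    rho is the imbalance ratio n_max / n_min; rho=1 yields a balanced split.
    Class 0 is the head class with n_max samples; counts decay
    exponentially so the tail class keeps roughly n_max / rho samples.
    """
    return [round(n_max * rho ** (-k / (num_classes - 1)))
            for k in range(num_classes)]

counts = long_tailed_counts(n_max=1000, num_classes=10, rho=10)
# head class keeps 1000 samples, tail class keeps 1000 / rho = 100
```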


You can use test-time adaptation algorithms such as TENT and TENT+ours, which are implemented in utils_baseline.py.
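For background, TENT adapts at test time by minimizing the entropy of the model's predictions on unlabeled test batches, updating only the batch-norm affine parameters. A minimal self-contained sketch of that idea (the model and data are toy placeholders; see utils_baseline.py for the repo's actual implementations):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a classifier containing batch norm.
model = nn.Sequential(
    nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 10)
)

# TENT-style setup: freeze everything except BN affine parameters.
for p in model.parameters():
    p.requires_grad_(False)
bn_params = []
for m in model.modules():
    if isinstance(m, nn.BatchNorm1d):
        m.weight.requires_grad_(True)
        m.bias.requires_grad_(True)
        bn_params += [m.weight, m.bias]

optimizer = torch.optim.SGD(bn_params, lr=1e-2)

def entropy_loss(logits):
    # Mean Shannon entropy of the softmax predictions.
    probs = logits.softmax(dim=1)
    return -(probs * probs.log()).sum(dim=1).mean()

x = torch.randn(64, 16)  # one unlabeled "test" batch
model.train()            # BN uses batch statistics during adaptation
initial = entropy_loss(model(x)).item()
for _ in range(20):      # a few adaptation steps on the same batch
    optimizer.zero_grad()
    loss = entropy_loss(model(x))
    loss.backward()
    optimizer.step()
final = entropy_loss(model(x)).item()
# entropy after adaptation should be lower than before
```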

The test logs are then recorded in "./eval_results/tta/cifar10/tslog".