This code accompanies the paper "Label Distribution Shift-Aware Prediction Refinement for Test-Time Adaptation" (DART).
It builds on the publicly released code at "https://github.com/DequanWang/tent" and "https://github.com/locuslab/tta_conjugate".

##### Dependencies #####
Python       3.8.17
PyTorch      2.0.1
Torchvision  0.15.2
CUDA         11.5
NumPy        1.24.4

##### Data #####
[training dataset] You can download the clean CIFAR-10/100 training datasets via RobustBench.
[test dataset] You can download CIFAR-10/100C from "https://github.com/hendrycks/robustness".

##### Pre-training #####
We use publicly released pre-trained models and code for a fair comparison.
Specifically, for CIFAR-10/100 we follow https://github.com/locuslab/tta_conjugate and train for 200 epochs with batch size 200, the SGD optimizer, learning rate 0.1, momentum 0.9, and weight decay 0.0005.
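For reference, the pre-training hyperparameters above can be gathered in one place. The sketch below is a hypothetical helper (`PretrainConfig` is not part of this repository) that also works out the optimization budget, assuming the standard 50,000-image CIFAR-10/100 training split. The learning-rate schedule is not stated in this README, so it is left out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PretrainConfig:
    """Hypothetical container for the CIFAR-10/100 pre-training hyperparameters."""
    epochs: int = 200
    batch_size: int = 200
    lr: float = 0.1
    momentum: float = 0.9
    weight_decay: float = 5e-4
    num_train_images: int = 50_000  # size of the CIFAR-10/100 training split

    @property
    def steps_per_epoch(self) -> int:
        # Ceiling division so a final partial batch would still count as a step.
        return -(-self.num_train_images // self.batch_size)

    @property
    def total_steps(self) -> int:
        return self.epochs * self.steps_per_epoch

    def sgd_kwargs(self) -> dict:
        # Keyword arguments in the form accepted by torch.optim.SGD(params, **kwargs).
        return {"lr": self.lr, "momentum": self.momentum,
                "weight_decay": self.weight_decay}


cfg = PretrainConfig()
print(cfg.steps_per_epoch, cfg.total_steps)  # 250 50000
```

With PyTorch installed, the optimizer would then be built as `torch.optim.SGD(model.parameters(), **cfg.sgd_kwargs())`.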

##### Generate online imbalanced indices for CIFAR-10/100C-imb #####
You can generate the online imbalanced indices with the following command:
e.g.) python generate_shifted_sample_indices.py
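This README does not describe how generate_shifted_sample_indices.py builds its index stream, so the following is only a hypothetical sketch of one way to produce an online label-shifted stream. It assumes the stream is split into one segment per class, where the segment's dominant class is drawn with probability `rho` and the remaining classes share `1 - rho` equally. The function name and parameters are illustrative, not the script's actual interface.

```python
import random
from collections import defaultdict
from typing import List, Sequence


def online_imbalanced_indices(labels: Sequence[int], num_classes: int,
                              rho: float, seg_len: int, seed: int = 0) -> List[int]:
    """Hypothetical sketch of an online label-shifted test stream.

    The stream has num_classes consecutive segments of seg_len samples.
    In segment k, class k is drawn with probability rho and the other
    classes share 1 - rho equally. Indices are drawn without replacement.
    """
    if num_classes < 2 or not 0.0 < rho < 1.0:
        raise ValueError("need num_classes >= 2 and 0 < rho < 1")
    rng = random.Random(seed)
    pools = defaultdict(list)  # class -> unused sample indices
    for idx, y in enumerate(labels):
        pools[y].append(idx)
    for pool in pools.values():
        rng.shuffle(pool)

    stream = []
    for k in range(num_classes):
        probs = [rho if c == k else (1.0 - rho) / (num_classes - 1)
                 for c in range(num_classes)]
        for _ in range(seg_len):
            # Only classes that still have unused samples are eligible.
            avail = [c for c in range(num_classes) if pools[c]]
            if not avail:
                return stream
            c = rng.choices(avail, weights=[probs[c] for c in avail])[0]
            stream.append(pools[c].pop())
    return stream
```

For example, `online_imbalanced_indices([i % 10 for i in range(1000)], 10, 0.9, 50)` yields 500 distinct indices whose first 50 are dominated by class 0, the next 50 by class 1, and so on.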

##### Intermediate-time training #####
You can train the prediction refinement module g_phi with the following command:
e.g.) python int_time_dart.py
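Training g_phi requires batches whose class distributions vary. The README does not state how int_time_dart.py generates them. One common approach, shown here purely as an assumption, draws a class distribution from a symmetric Dirichlet distribution and then samples a batch of training indices according to it. Both helper names are hypothetical.

```python
import random
from typing import Dict, List


def sample_dirichlet(alpha: float, k: int, rng: random.Random) -> List[float]:
    # A symmetric Dirichlet sample is a vector of Gamma(alpha, 1) draws, normalized.
    g = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(g)
    if total == 0.0:  # guard against underflow for very small alpha
        return [1.0 / k] * k
    return [x / total for x in g]


def label_shifted_batch(pools: Dict[int, List[int]], batch_size: int,
                        alpha: float, rng: random.Random) -> List[int]:
    """Hypothetical: draw a batch whose class mix follows a Dirichlet sample.

    pools maps each class to its sample indices; sampling is with replacement.
    """
    classes = sorted(pools)
    probs = sample_dirichlet(alpha, len(classes), rng)
    picked = rng.choices(classes, weights=probs, k=batch_size)
    return [rng.choice(pools[c]) for c in picked]
```

A smaller `alpha` concentrates each batch on fewer classes, which produces more severe simulated label shift. A larger `alpha` keeps batches closer to class-balanced.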

The trained g_phi and the training logs are saved to "./eval_results/trained_gphi/cifar10/trained_gphi.pt" and "./eval_results/trlog/cifar10/trlog", respectively.

##### Test-time adaptation #####
At test time, we adapt the trained classifiers using the trained g_phi. To obtain the results of both the naive TTA methods and their DART-applied versions, use the following commands.
e.g.) 
For BNAdapt, run "python tta_dart.py --tta_method=bnadapt"
For TENT, run "python tta_dart.py --tta_method=tent"
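To illustrate what applying a batch-level prediction refinement can look like, the sketch below makes an explicit assumption that is not taken from tta_dart.py: g_phi maps the batch-averaged softmax prediction to a K x K matrix T, and each sample's logits z are refined as z T before the softmax. Here `g_phi` is a hypothetical stand-in passed in as a plain function, and all helper names are illustrative.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]


def softmax(z: List[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def refine_batch(logits: Matrix, g_phi: Callable[[List[float]], Matrix]) -> Matrix:
    """Hypothetical batch-level prediction refinement.

    1. Average the softmax predictions over the batch (a pseudo-label distribution).
    2. Ask g_phi for a K x K matrix T conditioned on that average.
    3. Return softmax(z @ T) for each sample's logits z.
    """
    k = len(logits[0])
    probs = [softmax(z) for z in logits]
    avg = [sum(p[j] for p in probs) / len(probs) for j in range(k)]
    t = g_phi(avg)
    refined = []
    for z in logits:
        zt = [sum(z[i] * t[i][j] for i in range(k)) for j in range(k)]
        refined.append(softmax(zt))
    return refined
```

If g_phi returns the identity matrix, the refined predictions equal the original softmax outputs. Any learned correction therefore comes entirely from how T departs from the identity for a given batch.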

Other test-time adaptation algorithms, such as LAME, PL, NOTE, ODS, and DELTA, are implemented in "utils_baseline_dart.py".
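To run several methods in one sweep, a small driver can loop over `--tta_method` values. This is a hypothetical convenience script, not part of the repository. Only `bnadapt` and `tent` are confirmed flag values above, so check tta_dart.py for the exact names of the other methods before adding them to the list.

```python
import subprocess
import sys
from typing import List, Sequence


def build_commands(methods: Sequence[str], python: str = sys.executable,
                   script: str = "tta_dart.py") -> List[List[str]]:
    # One argv list per method, mirroring "python tta_dart.py --tta_method=<name>".
    return [[python, script, f"--tta_method={m}"] for m in methods]


def run_all(methods: Sequence[str], dry_run: bool = True) -> None:
    for cmd in build_commands(methods):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sweep at the first failing run.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(["bnadapt", "tent"], dry_run=True)
```

With `dry_run=True` the driver only prints the commands, which is a cheap way to verify the sweep before launching the actual experiments.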

The test logs are saved to "./eval_results/tta/cifar10/tslog".

The results of DART-split can be obtained with the same procedure by running 'int_time_dart_split.py' and then 'tta_dart_split.py'.