# BATTA-RL: Binary-feedback Active Test-Time Adaptation

This is the PyTorch implementation of "BATTA-RL: Binary-feedback Active Test-Time Adaptation".

## Installation Guide

1. Download or clone our repository
2. Set up a Python environment using conda (see below)
3. Prepare datasets (see below)
4. Run the code (see below)

## Python Environment

We use a [Conda environment](https://docs.conda.io/).
If you do not have conda yet, you can get it by installing [Anaconda](https://www.anaconda.com/).

We provide our Python environment, including all required packages, in the `./batta.yml` file.

You can create the environment from this file using conda:

    conda env create -f batta.yml -n batta

Reference: https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-from-an-environment-yml-file
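
Running `conda env create` a second time reports an error because the environment already exists. The sketch below makes setup safe to rerun by creating `batta` only if it is missing. Both functions (`conda_env_in_list`, `create_batta_env`) are our own hypothetical helpers, not part of the repository, and the parser assumes the usual `conda env list` format: `#` comment lines plus one `NAME [*] PATH` line per named environment.

```shell
# Hypothetical helper: succeed (exit 0) if environment NAME appears in
# `conda env list` output read from stdin. Comment lines start with '#'.
conda_env_in_list() {
  awk -v name="$1" '$1 !~ /^#/ && $1 == name { found = 1 } END { exit !found }'
}

# Hypothetical helper: create the environment from batta.yml only if it
# does not exist yet. Run it from the project root, where batta.yml lives.
create_batta_env() {
  local name=${1:-batta}
  if ! command -v conda >/dev/null 2>&1; then
    echo "conda not found; install Anaconda first" >&2
    return 1
  fi
  if conda env list | conda_env_in_list "$name"; then
    echo "conda environment '$name' already exists"
  else
    conda env create -f batta.yml -n "$name"
  fi
}
```

After `create_batta_env` finishes, activate the environment with `conda activate batta`.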

## Prepare Datasets

To run our code, you first need to download at least one of the datasets. Run the corresponding commands from the project root:

    $ cd .                              # make sure you are in the project root
    $ . download_cifar10.sh             # download the CIFAR10 dataset
    $ . download_cifar100.sh            # download the CIFAR100 dataset
    $ . download_tiny_imagenet.sh       # download the Tiny-ImageNet dataset
    $ . download_pacs.sh                # download the PACS dataset
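
To fetch several datasets in one go, a small wrapper can map dataset names to the download scripts listed above. This is a hypothetical helper: the function name `download_datasets` and the short keys (`cifar10`, `cifar100`, `tiny-imagenet`, `pacs`) are our own, and only the script names come from this README. As a precaution, it returns to the project root before each script in case a script changes directory.

```shell
# Hypothetical helper: source the download script for each named dataset.
# Must be called from the project root, like the commands above.
download_datasets() {
  local root=$PWD name script
  for name in "$@"; do
    case "$name" in
      cifar10)       script=download_cifar10.sh ;;
      cifar100)      script=download_cifar100.sh ;;
      tiny-imagenet) script=download_tiny_imagenet.sh ;;
      pacs)          script=download_pacs.sh ;;
      *) echo "unknown dataset: $name" >&2; return 1 ;;
    esac
    # Go back to the project root in case a previous script changed directory.
    cd "$root" || return 1
    if [ ! -f "./$script" ]; then
      echo "missing ./$script (run from the project root)" >&2
      return 1
    fi
    # Source the script with '.', as in the commands above.
    . "./$script" || return 1
  done
  cd "$root"
}
```

For example, `download_datasets pacs cifar10` fetches PACS and then CIFAR10.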

## Run

### Prepare Source model
First, create a directory for the pre-trained weights:

    $ cd .                              # make sure you are in the project root
    $ mkdir pretrained_weights          # create an empty directory for the pre-trained weights

The pre-trained and fine-tuned models used for adaptation are available on [Google Drive](https://drive.google.com/drive/folders/1gJt0uRVQRVML-kk6aLgLFLMFxnUJ-k4y?usp=sharing). The pre-trained models are taken directly from [ATTA](https://github.com/divelab/ATTA). Place the pre-trained and fine-tuned model files in the `./pretrained_weights` folder as follows:
```
BATTA
│   README.md
│   tta.sh
│   main.py
│   download_pacs.sh
│   ...
│
└───pretrained_weights
│   └───pacs
│   │   │   normal_cp
│   │   │   enhanced_cp
│   │
│   └───tiny-imagenet
│       │   normal_cp
│       │   enhanced_cp
│
...
```
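
Before running TTA, you can check that the checkpoints are where the tree above expects them. This sketch assumes every dataset folder follows the same layout as `pacs` and `tiny-imagenet` (one `normal_cp` and one `enhanced_cp` entry); the helper `check_weights` is hypothetical and not part of the repository.

```shell
# Hypothetical helper: verify that both checkpoints exist for a dataset.
# Usage: check_weights DATASET [WEIGHTS_ROOT]  (root defaults to ./pretrained_weights)
check_weights() {
  local dataset=$1 root=${2:-./pretrained_weights} missing=0 cp
  for cp in normal_cp enhanced_cp; do
    # -e accepts either a file or a directory, since the tree does not say which.
    if [ ! -e "$root/$dataset/$cp" ]; then
      echo "missing: $root/$dataset/$cp" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

For example, `check_weights pacs && check_weights tiny-imagenet` reports any missing checkpoint and fails if one is absent.
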
### Run Test-Time Adaptation (TTA) & Estimate Accuracy

Once the source models are in place, you can run TTA with:

    $ . tta.sh                      # runs online TTA on PACS by default

To change the dataset or the method, edit the corresponding settings in `tta.sh`.

## Log

### Raw logs

In addition to the console output, the results are saved as a log file at the following path: `./log/{DATASET}/{METHOD}_outdist/{TGT}/{LOG_PREFIX}_{SEED}_{DIST}/online_eval.json`
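
When inspecting a specific run, the log path can be assembled from its placeholders. The sketch below only fills in the template above. The function name `batta_log_path` and its argument order are our own choices. Valid values for `{TGT}`, `{LOG_PREFIX}` and `{DIST}` depend on your run configuration in `tta.sh` and are not listed in this README.

```shell
# Hypothetical helper: print the online_eval.json path for one run by
# filling the log path template placeholders in order.
batta_log_path() {
  if [ "$#" -ne 6 ]; then
    echo "usage: batta_log_path DATASET METHOD TGT LOG_PREFIX SEED DIST" >&2
    return 1
  fi
  printf './log/%s/%s_outdist/%s/%s_%s_%s/online_eval.json\n' \
    "$1" "$2" "$3" "$4" "$5" "$6"
}
```

Call it as `batta_log_path DATASET METHOD TGT LOG_PREFIX SEED DIST`, replacing each argument with the value used in your run.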

### Obtaining results

To print the accuracy-estimation mean absolute errors (%) on the test set, run the following commands:

    # print the results in the continual TTA setting
    $ python print_acc.py --dataset pacs --target BATTA --seed 0 1 2 --cont

    # print the results in the online TTA setting
    $ python print_acc.py --dataset pacs --target BATTA --seed 0 1 2
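
Because results are stored in the log files described under Raw logs, it can help to first confirm which runs have produced an `online_eval.json`. The hypothetical helper below lists every such file under `./log/{DATASET}/{METHOD}_outdist/`. It does not parse the JSON, since the file contents are not documented here.

```shell
# Hypothetical helper: list all online_eval.json files for one dataset/method.
# Usage: list_eval_logs DATASET METHOD [LOG_ROOT]  (LOG_ROOT defaults to ./log)
list_eval_logs() {
  local dataset=$1 method=$2 log_root=${3:-./log}
  local dir="$log_root/$dataset/${method}_outdist"
  if [ ! -d "$dir" ]; then
    echo "no logs under $dir" >&2
    return 1
  fi
  # Sorted so that runs with different seeds appear next to each other.
  find "$dir" -type f -name online_eval.json | sort
}
```

For example, `list_eval_logs pacs BATTA` shows the PACS runs logged for the `BATTA` method.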


## Tested Environment

We tested our code in the following environment:

- OS: Ubuntu 20.04.4 LTS
- GPU: NVIDIA GeForce TITAN RTX
- GPU Driver Version: 470.57
- CUDA Version: 11.4
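
To compare your machine with this setup, you can print a short report using standard tools. This is a sketch, not part of the repository. It reads `/etc/os-release`, queries `nvidia-smi` if it is installed, and asks PyTorch which CUDA version it was built with. The PyTorch build version is not necessarily the same number as the driver-level CUDA version that `nvidia-smi` shows in its header.

```shell
# Hypothetical helper: print OS, GPU/driver, and PyTorch CUDA build version.
report_env() {
  local os gpu
  if [ -r /etc/os-release ]; then
    os=$(. /etc/os-release && printf '%s' "${PRETTY_NAME:-unknown}")
  else
    os=$(uname -s)
  fi
  echo "OS: $os"
  if command -v nvidia-smi >/dev/null 2>&1; then
    gpu=$(nvidia-smi --query-gpu=name,driver_version --format=csv,noheader 2>/dev/null | head -n 1) \
      || gpu="query failed"
    echo "GPU, driver: $gpu"
  else
    echo "GPU, driver: nvidia-smi not found"
  fi
  if command -v python >/dev/null 2>&1; then
    # torch.version.cuda is the CUDA version PyTorch was compiled against.
    echo "PyTorch CUDA: $(python -c 'import torch; print(torch.version.cuda)' 2>/dev/null || echo 'torch not importable')"
  fi
}
```

Run `report_env` inside the activated `batta` environment so that `python` resolves to the environment's interpreter.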
