## Introduction
This repo is the implementation of our paper TestRank. It currently supports the 'cifar10', 'svhn', and 'stl10' datasets.

The main files are:
   --train_classifier.py: trains the DL classifiers.
   --selection.py: the core of TestRank; the selection strategy is implemented in this file.
   --byol/*: contains the code to train the unsupervised feature extractor.

## Environment
The environment information is stored in 'environment.yml'.
The environment can be restored by running "conda env create -f environment.yml".

Please run the code on a GPU.

## Steps to run TestRank
The only file you need to edit to change the configuration is 'run.sh'.
There are three main steps:
### Prepare DL model under test
   We have prepared pretrained classifiers for the STL10 dataset. They can be downloaded from: https://drive.google.com/drive/folders/1sLXG9wlLHjUVF_FRNW0nX9h4M7EXGjBS?usp=sharing. Please download the three classifiers "resnet34_0_b.t7/resnet34_1_b.t7/resnet34_2_b.t7" to the folder "./checkpoint/stl10/ckpt_bias/", then check the md5 of each downloaded weight file:
      --resnet34_0_b.t7: e6e518998e9be957c77afe8a33aff590
      --resnet34_1_b.t7: 44a5f49cc833421f0e489a5e0aa37bac
      --resnet34_2_b.t7: 388598538a54aa2f96c082c07a08fbc3
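
   The md5 check above can be automated. The following is a minimal sketch, not part of the repo: the helper names `md5_of_file` and `check_checkpoints` are hypothetical, and it assumes the files were saved under "./checkpoint/stl10/ckpt_bias/" with the names listed above.

```python
import hashlib
from pathlib import Path

# Checksums copied from the list above.
EXPECTED_MD5 = {
    "resnet34_0_b.t7": "e6e518998e9be957c77afe8a33aff590",
    "resnet34_1_b.t7": "44a5f49cc833421f0e489a5e0aa37bac",
    "resnet34_2_b.t7": "388598538a54aa2f96c082c07a08fbc3",
}


def md5_of_file(path, chunk_size=1 << 20):
    """Return the hex md5 of a file, reading it in chunks to limit memory use."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def check_checkpoints(folder, expected=EXPECTED_MD5):
    """Return {file name: 'ok' | 'mismatch' | 'missing'} for each expected file."""
    status = {}
    for name, wanted in expected.items():
        path = Path(folder) / name
        if not path.is_file():
            status[name] = "missing"
        elif md5_of_file(path) == wanted:
            status[name] = "ok"
        else:
            status[name] = "mismatch"
    return status


if __name__ == "__main__":
    # Hypothetical entry point: report the state of each downloaded classifier.
    for name, state in check_checkpoints("./checkpoint/stl10/ckpt_bias").items():
        print(f"{name}: {state}")
```

   A file reported as "mismatch" was most likely truncated during download and should be fetched again.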

   If you want to train your own classifiers instead, the training code resides in the 'train_classifier.py' file.
   To train three different DL models for a dataset, run "./run.sh trainm".
   To change the dataset, set 'DATASET=dataset_name' (line 51) in the 'run.sh' file to the desired dataset name.
   The trained models will be stored under './checkpoint/dataset/ckpt_bias/*'.
      -- Each model is assigned a unique ID (e.g. 0, 1, 2).
   
### Prepare feature extractor
   We have prepared a pretrained feature extractor for the STL10 dataset. It can be downloaded from: https://drive.google.com/file/d/1KSqCBgaxow93gogIlS5BuwFxcOHIvF2c/view?usp=sharing. Please put the downloaded file in the "./byol/checkpoints/official-stl10/" folder.

   The md5 of this file is: fe7e3bc9f846e0250c7e6951034ec13f

### Perform test selection: "./run.sh selection"
   Call the 'run.sh' file with argument 'selection':
   <mark>./run.sh selection</mark>

      The parameters are explained as follows:
      python selection.py \
                  --dataset $DATASET \                   # specify the dataset to use
                  --manualSeed ${RANDOM_SEED} \          # random seed
                  --model2test_arch $MODEL2TEST \        # architecture of the model under test (e.g. resnet18)
                  --model2test_path $MODEL2TESTPATH \    # the path storing the model weights 
                  --model_number $MODEL_NO \             # which model to test, model 0, 1, or 2?
                  --save_path ${save_path} \             # the results will be stored here
                  --data_path ${DATA_ROOT} \             # dataset root path
                  --graph_nn \                           # use the graph neural network in TestRank
                  --feature_extractor_id ${feature_extractor_id} \ # type of feature extractor
                  --no_neighbors ${no_neighbors} \       # number of neighbors used to construct the graph
                  --learn_mixed \                        # use an MLP to combine intrinsic and contextual attributes; otherwise the two scores are combined by brute force (multiplied)
                  --baseline_gini
   <mark>The result is stored in 'save_path/date/dataset_model/xxx_result.csv', where xxx stands for the selection method used (e.g. for TestRank, the file is gnn_result.csv).
   The TRC value is in the last column, and the fourth column shows the corresponding budget in percent.</mark>
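
   To pull the budget and TRC values out of a result file, something like the following may help. It is a rough sketch under assumptions: the function names are hypothetical, both columns are assumed to hold plain numbers, and any row that does not parse as numeric (for example a header row, if there is one) is skipped.

```python
import csv


def read_budget_trc(lines):
    """Return (budget, trc) pairs from the fourth and last CSV columns.

    `lines` can be an open text file or any iterable of CSV lines.
    """
    pairs = []
    for row in csv.reader(lines):
        if len(row) < 4:
            continue  # too short to contain a budget column
        try:
            pairs.append((float(row[3]), float(row[-1])))
        except ValueError:
            continue  # assumed to be a header or other non-numeric row
    return pairs


def print_results(path):
    """Print one 'budget -> TRC' line per row of the given result file."""
    with open(path, newline="") as handle:
        for budget, trc in read_budget_trc(handle):
            print(f"budget {budget}% -> TRC {trc}")
```

   For a TestRank run, `print_results` would be called with the path to the `gnn_result.csv` file inside your save path.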

   - To compare with baselines, specify the corresponding baseline method (e.g. baseline_gini, baseline_uncertainty, baseline_dsa, baseline_mcp).
   For example, to compare with the DeepGini baseline, change the command in the 'run.sh' file to the following:
      python selection.py \
               --dataset $DATASET \
               --manualSeed ${RANDOM_SEED} \
               --model2test_arch $MODEL2TEST \
               --model2test_path $MODEL2TESTPATH \
               --model_number $MODEL_NO \
               --save_path ${save_path} \
               --data_path ${DATA_ROOT} \
               --baseline_gini
   - To evaluate different models, change MODEL_NO to the corresponding model ID: [0, 1, 2].
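
   For readers unfamiliar with the DeepGini baseline mentioned above, the score it ranks by is the Gini impurity of the classifier's softmax output, 1 - sum(p_i^2). The sketch below only illustrates that formula; the function names are hypothetical and it is not necessarily how 'selection.py' computes the baseline.

```python
def gini_score(probabilities):
    """Gini impurity of one softmax output: 1 - sum(p_i^2)."""
    return 1.0 - sum(p * p for p in probabilities)


def rank_by_gini(batch_probabilities):
    """Return sample indices ordered from highest to lowest impurity."""
    scores = [gini_score(p) for p in batch_probabilities]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
```

   A confident prediction such as [1.0, 0.0] scores 0, while an evenly split one such as [0.5, 0.5] scores 0.5, so the latter is selected for labeling first.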