About the files:

The code folder contains:

main.py: entry point; run main.py to reproduce the results
data_prepare.py: loads the data. We use MNIST, Fashion-MNIST, and CIFAR-10; for CIFAR-10, the features are the outputs of a VGG16 network
model.py: defines the L_relu_enn, relu_enn, and INN models
train.py: implements our training procedure
utils.py: mathematical helper functions for computing tau
data: the datasets used in the experiments
error_wrt_p_n.m: MATLAB script that computes the approximation error \|G^* - \overline{G}\| with respect to p and n

pretrained: the pretrained VGG16 model

Usage:

usage: python main.py [-h] [--dataset DATASET] [--model MODEL] [--dim DIM] [--lr LR] [--epoch EPOCH] [--batch_size BATCH_SIZE] [--save_path SAVE_PATH] [--dataset_path DATASET_PATH] [--device DEVICE] [--vgg_path VGG_PATH] [--output_path OUTPUT_PATH]

example: python main.py --dataset cifar10 --dim 1024 --model l_relu_enn
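For readers who want to call main.py from their own scripts, the following is a minimal sketch of an argparse parser that accepts the flags in the usage line above. It is a hypothetical reconstruction, not the repository's actual code: the flag names come from this README, but the argument types are assumptions, and no defaults are set because the README does not state them (unset flags parse as None).

```python
import argparse


# Hypothetical reconstruction of the parser behind the usage line.
# Flag names are taken from the README; the types are assumptions.
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="main.py")
    parser.add_argument("--dataset", type=str)        # e.g. cifar10
    parser.add_argument("--model", type=str)          # e.g. l_relu_enn
    parser.add_argument("--dim", type=int)            # assumed integer
    parser.add_argument("--lr", type=float)           # assumed float
    parser.add_argument("--epoch", type=int)          # assumed integer
    parser.add_argument("--batch_size", type=int)     # assumed integer
    parser.add_argument("--save_path", type=str)
    parser.add_argument("--dataset_path", type=str)
    parser.add_argument("--device", type=str)
    parser.add_argument("--vgg_path", type=str)
    parser.add_argument("--output_path", type=str)
    return parser


if __name__ == "__main__":
    # Parse the README's example command instead of sys.argv.
    args = build_parser().parse_args(
        ["--dataset", "cifar10", "--dim", "1024", "--model", "l_relu_enn"]
    )
    print(args.dataset, args.dim, args.model)  # prints: cifar10 1024 l_relu_enn
```

argparse derives each metavar from the flag name (for example, --batch_size becomes BATCH_SIZE), which is why the generated usage text matches the bracketed groups in the usage line.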

To reproduce the results in Figure 3, run the script for the corresponding dataset:

    bash fig3_mnist.sh
    bash fig3_fashion_mnist.sh
    bash fig3_cifar10.sh