# Towards Efficient and Scalable Training of Differentially Private Deep Learning

This repository contains the code to reproduce the experiments in "Towards Efficient and Scalable Training of Differentially Private Deep Learning".

There are two sets of experiments:
-   torch_exp
    -   The pipeline for all torch implementations is in the file pipeline_torch.py.
    -   To record the results, use thr_record.py, which creates a CSV file with the throughput and accuracy results.
-   jax_exp
    -   Contains the JAX implementations, each in its own file.
        -   pipeline_jax_naive.py is the original naive implementation. It uses the same BatchManager to split each batch into physical batches.
        -   To record its results, use thr_record.py, for example:
            ```
            python3 Towards-Efficient-Scalable-Training-DP-DL/jax_exp/thr_record.py --n_workers 16 --bs 32492 --phy_bs=32 --seed 20 --epsilon 8 --epochs 2 --lr=0.00031 --grad_norm 4.637 --ten 100 --clipping_mode='private-mini' --file 'mask_dp_test.csv' --model 'google/vit-base-patch16-224' --normalization='True'
            ```
        -   jax_mask_efficient.py is the more efficient implementation and the one presented in the paper. It can be run directly as a Python script.
        -   private_vit.py and private_resnet.py are helper files to load the respective models.

The experiments were run in an HPC environment, training on V100 and A100 GPUs.
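
The example command above passes both `--bs 32492` and `--phy_bs=32`. As a rough illustration of how those two values relate, the sketch below splits one batch into fixed-size physical batches and reports how much padding the last one would need. It is not taken from the repository: the function name `split_into_physical_batches` is hypothetical, `--bs` is treated as a fixed batch size for simplicity, and padding the final chunk to a full physical batch is an assumption rather than a description of what BatchManager or jax_mask_efficient.py actually does.

```python
import math


def split_into_physical_batches(logical_bs: int, physical_bs: int):
    """Yield (start, end, n_padding) for each physical batch.

    Hypothetical helper for illustration only; not part of this repository.
    n_padding is the number of slots that would have to be padded (and
    masked out) to keep every physical batch at exactly physical_bs.
    """
    n_steps = math.ceil(logical_bs / physical_bs)
    for step in range(n_steps):
        start = step * physical_bs
        end = min(start + physical_bs, logical_bs)
        yield start, end, physical_bs - (end - start)


# Values taken from the example command: --bs 32492 --phy_bs=32
chunks = list(split_into_physical_batches(32492, 32))
print(len(chunks))   # 1016 physical batches
print(chunks[-1])    # (32480, 32492, 20): 12 real examples, 20 padded slots
```

With these values, a single batch corresponds to 1016 forward and backward passes over physical batches before one optimizer step, which is why throughput is recorded alongside accuracy.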