# DeepSplit
We provide a demonstration of the DeepSplit method proposed in 'DeepSplit: Scalable Verification of Deep Neural Networks via Operator Splitting'. The ADMM implementation is in the folder ADMM, and the scalability comparison experiment on ResNet18 from Section 4.2 is in Examples/LiRPA_comparison. The pre-trained ResNet18 is saved in Examples/LiRPA_comparison/resnet18.pth.

## Implementation of ADMM
The ADMM implementation is in admm.py, where we define the ADMM modules according to the given neural network structure. For each layer in the network (linear, ReLU, batch normalization, etc.), we first find its corresponding `ADMMLayer`, which implements the projection onto the convex hull of that layer's graph. The `ADMM_forward_block` class in admm.py then assembles multiple sequentially arranged ADMM layers; it executes the ADMM updates and computes the primal/dual residuals and tolerances for the assembled layers.

The `ADMM_residual_block` models the residual blocks in ResNet18 and consists of three components: an `ADMM_forward_block` that models the identity connection, an `ADMM_forward_block` that models the residual connection, and an `ADMM_sum_block` that implements the summation. Each of these components can run the ADMM updates and compute the primal/dual residuals and tolerances. Finally, we assemble all the `ADMM_forward_block`s and `ADMM_res_block`s into one `ADMM_session`, which models the whole ResNet18 and solves the LP relaxation with ADMM.
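To make the projection step more concrete, here is a minimal pure-Python sketch of two projections of the kind described above: one per neuron for a ReLU with assumed preactivation bounds `l` and `u`, and one per coordinate for the summation `y = x1 + x2` handled by a sum block. The function names are hypothetical and are not the `ADMMLayer` or `ADMM_sum_block` API in admm.py, which operates on tensors; we also assume a plain Euclidean projection.

```python
def _closest_on_segment(p, a, b):
    """Closest point to p on the segment from a to b (all 2-tuples)."""
    dx, dy = b[0] - a[0], b[1] - a[1]
    denom = dx * dx + dy * dy
    if denom == 0.0:
        return a
    t = ((p[0] - a[0]) * dx + (p[1] - a[1]) * dy) / denom
    t = min(1.0, max(0.0, t))  # clamp so the point stays on the segment
    return (a[0] + t * dx, a[1] + t * dy)


def project_relu_hull(x, y, l, u):
    """Euclidean projection of (x, y) onto the convex hull of {(s, max(0, s)) : l <= s <= u}."""
    if l >= 0.0:  # always active: the hull is the segment y = s on [l, u]
        return _closest_on_segment((x, y), (l, l), (u, u))
    if u <= 0.0:  # always inactive: the hull is the segment y = 0 on [l, u]
        return _closest_on_segment((x, y), (l, 0.0), (u, 0.0))
    # Unstable neuron: triangle with vertices (l, 0), (0, 0), (u, u),
    # i.e. y >= 0, y >= x, and y below the chord from (l, 0) to (u, u).
    if y >= 0.0 and y >= x and y <= u * (x - l) / (u - l):
        return (x, y)
    edges = [((l, 0.0), (0.0, 0.0)), ((0.0, 0.0), (u, u)), ((l, 0.0), (u, u))]
    candidates = [_closest_on_segment((x, y), a, b) for a, b in edges]
    # Outside a convex polygon, the nearest point lies on one of its edges.
    return min(candidates, key=lambda q: (q[0] - x) ** 2 + (q[1] - y) ** 2)


def project_sum(a, b, c):
    """Euclidean projection of (a, b, c) onto {(x1, x2, y) : y = x1 + x2}, elementwise."""
    x1, x2, y = [], [], []
    for ai, bi, ci in zip(a, b, c):
        r = (ai + bi - ci) / 3.0  # one third of the constraint violation goes to each entry
        x1.append(ai - r)
        x2.append(bi - r)
        y.append(ci + r)
    return x1, x2, y
```

Because the ReLU acts elementwise, the Euclidean projection for a whole ReLU layer reduces to applying `project_relu_hull` to each neuron independently.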

After constructing the `ADMM_session`, the ADMM algorithm is run by `run_ADMM`, which iteratively performs a forward pass through all the `ADMM_forward_block`s and `ADMM_res_block`s contained in the `ADMM_session` until the stopping criterion is met. For feedforward fully-connected or convolutional neural networks, `ADMM_forward_block` alone is sufficient to encode the whole network and construct the `ADMM_session`. Because the network is decomposed into these modular blocks, DeepSplit is easy to adapt to networks with general computational graphs.
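The loop in `run_ADMM` alternates updates and checks residuals against tolerances. The sketch below shows that pattern on a toy problem of our own choosing (projecting a vector `a` onto a box by splitting it as `x = z`), not on the DeepSplit LP. The primal/dual residuals and tolerances follow the standard scaled-form criterion from Boyd et al.'s ADMM monograph, which we assume is similar in spirit to the one used here; the function name and default parameters are hypothetical.

```python
import math


def norm(v):
    return math.sqrt(sum(t * t for t in v))


def admm_box_projection(a, lo, hi, rho=1.0, eps_abs=1e-6, eps_rel=1e-6, max_iter=1000):
    """Toy ADMM for min 0.5*||x - a||^2 s.t. x = z, lo <= z <= hi (scaled dual variable w)."""
    n = len(a)
    x, z, w = [0.0] * n, [0.0] * n, [0.0] * n
    for k in range(1, max_iter + 1):
        # x-update: closed-form minimizer of the quadratic plus the augmented penalty
        x = [(ai + rho * (zi - wi)) / (1.0 + rho) for ai, zi, wi in zip(a, z, w)]
        z_prev = z
        # z-update: Euclidean projection onto the box
        z = [min(h, max(l, xi + wi)) for xi, wi, l, h in zip(x, w, lo, hi)]
        # scaled dual update
        w = [wi + xi - zi for wi, xi, zi in zip(w, x, z)]
        r = norm([xi - zi for xi, zi in zip(x, z)])               # primal residual
        s = rho * norm([zi - zp for zi, zp in zip(z, z_prev)])    # dual residual
        eps_pri = math.sqrt(n) * eps_abs + eps_rel * max(norm(x), norm(z))
        eps_dual = math.sqrt(n) * eps_abs + eps_rel * rho * norm(w)
        if r <= eps_pri and s <= eps_dual:                        # stopping criterion
            return z, k
    return z, max_iter
```

For example, `admm_box_projection([2.0, -3.0, 0.5], [-1.0] * 3, [1.0] * 3)` converges to approximately `[1.0, -1.0, 0.5]`, the clipped vector, well before `max_iter`.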

The class `Layer_section` is a helper for generating an `ADMM_forward_block` from a list of sequentially arranged neural network layers, and the functions `generate_layer_sections` and `init_ADMM_session` build on top of it to initialize an `ADMM_session` for running ADMM.
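As a hypothetical illustration of the kind of grouping these helpers perform, the sketch below walks a flat list of layers, collects each maximal run of sequential layers into one section (which would become one forward block), and gives every residual block its own section. The classes `ResidualBlock` and `Section` and the function `group_into_sections` are our own stand-ins; the actual `Layer_section`/`generate_layer_sections` logic in admm.py may differ.

```python
from dataclasses import dataclass, field


@dataclass
class ResidualBlock:
    """Hypothetical stand-in for a residual block such as a ResNet18 basic block."""
    name: str


@dataclass
class Section:
    kind: str                                   # "forward" or "residual"
    layers: list = field(default_factory=list)


def group_into_sections(layers):
    """Group a flat layer list into sequential runs; each residual block is its own section."""
    sections = []
    for layer in layers:
        if isinstance(layer, ResidualBlock):
            sections.append(Section("residual", [layer]))
        elif sections and sections[-1].kind == "forward":
            sections[-1].layers.append(layer)              # extend the current sequential run
        else:
            sections.append(Section("forward", [layer]))   # start a new sequential run
    return sections
```

With layers `["conv1", "bn1", "relu1", ResidualBlock("layer1.0"), ResidualBlock("layer1.1"), "avgpool", "linear"]`, this yields four sections of kinds forward, residual, residual, forward.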

## Running the code
A running example is presented in Examples/LiRPA_comparison, where, for the first test sample of CIFAR10, we compare the output bounds of ResNet18 obtained from ADMM and from [LiRPA](https://github.com/KaidiXu/auto_LiRPA), using the preactivation bounds of all layers obtained from LiRPA. All the run data are saved in Examples/LiRPA_comparison/ADMM_LiRPA_comparison, and main_output_bounds_comparison_plot.py generates Figures 2 and 9 in the paper. The CIFAR10 dataset is downloaded automatically when it is first loaded.

To see how ADMM works, simply run main_ADMM_LiRPA_output_bounds_comparison.py; LiRPA does not need to be installed for this script. On the other hand, main_LiRPA_layerwise_bounds.py computes all the preactivation bounds with LiRPA. To run this script, you need to download [LiRPA](https://github.com/KaidiXu/auto_LiRPA) first and replace the code at [Line 295](https://github.com/KaidiXu/auto_LiRPA/blob/c8935c6d22cd76e137b1a9b1b3ea67f7d234601d/auto_LiRPA/bound_general.py#L295) of [auto_LiRPA/bound_general.py](https://github.com/KaidiXu/auto_LiRPA/blob/master/auto_LiRPA/bound_general.py):
```python
    def _convert_nodes(self, model, global_input):
        global_input_cpu = self._to(global_input, 'cpu')
        model.train()
        model.to('cpu')
        nodesOP, nodesIn, nodesOut, template = parse_module(model, global_input_cpu)
        model.to(self.device)
```
with:
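```python
    def _convert_nodes(self, model, global_input):
        global_input_cpu = self._to(global_input, 'cpu')
        model.eval()   # parse the model in evaluation mode (e.g., BatchNorm uses running statistics)
        model.to('cpu')
        nodesOP, nodesIn, nodesOut, template = parse_module(model, global_input_cpu)
        model.train()  # switch back to training mode, as the original code left the model in it
        model.to(self.device)
```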
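The README does not explain why this patch is needed; one plausible reason (our assumption) is batch normalization. In train mode, BatchNorm normalizes with the statistics of the current batch, so an output depends on the other inputs in the batch; in eval mode, it uses stored running statistics and reduces to a fixed per-channel affine map that can be bounded like a linear layer. The simplified single-channel sketch below (hypothetical function `batchnorm_channel`, not PyTorch's implementation, and without running-statistic updates) illustrates the difference.

```python
import math


def batchnorm_channel(batch, running_mean, running_var, training, gamma=1.0, beta=0.0, eps=1e-5):
    """Simplified single-channel batch normalization in train or eval mode."""
    if training:
        # Train mode: normalize with the current batch's mean and biased variance.
        mean = sum(batch) / len(batch)
        var = sum((v - mean) ** 2 for v in batch) / len(batch)
    else:
        # Eval mode: stored statistics, so each output is a fixed affine function of its input.
        mean, var = running_mean, running_var
    scale = gamma / math.sqrt(var + eps)
    return [scale * (v - mean) + beta for v in batch]


# The same input 1.0, placed in two different batches:
train_a = batchnorm_channel([1.0, 2.0], 0.0, 1.0, training=True)[0]
train_b = batchnorm_channel([1.0, 5.0], 0.0, 1.0, training=True)[0]
eval_a = batchnorm_channel([1.0, 2.0], 0.0, 1.0, training=False)[0]
eval_b = batchnorm_channel([1.0, 5.0], 0.0, 1.0, training=False)[0]
print(train_a == train_b, eval_a == eval_b)  # False True
```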

