# Neighbors search implementations

## Algorithms

We integrate three neighbor list routines in our codebase:

- `jaxmd_vmap`: refers to using the original cell list-based implementation from the [JAX-MD](https://github.com/jax-md/jax-md) library.
- `jaxmd_scan`: refers to using a more memory-efficient implementation of the JAX-MD function. We achieve this by partitioning the search over potential neighbors from the cell list-based candidate neighbors into `num_partitions` chunks. We need to define three variables to explain how our implementation works:
    - $X \in \mathbb{R}^{N\times d}$ - the particle coordinates of $N$ particles in $d$ dimensions.
    - $h \in \mathbb{N}^{N}$ - the list specifying to which cell a particle belongs.
    - $L \in \mathbb{N}^{C \times cand}$ - list specifying which particles are potential candidates to a particle in cell $c \in [1, ..., C]$. The number of potential candidates $cand$ is the product of the fixed cell capacity (needed for jit-ability) and the number of reachable cells, e.g. 27 in 3D.

    The `jaxmd_vmap` implementation essentially instantiates all possible connections by creating an object of size $N \cdot cand$, and only after all distances between potential neighbors have been computed the edge list is pruned to its actual size being ~6x smaller in 3D. This factor comes from the fact that the cell size is approximately equal to the cutoff radius and if we split a unit cube into $3^3$ cells, then the volume of a sphere with $r=1/3$ will be around $1/6$ the volume of the cube. By splitting $X$ and $h$ into `num_partitions` parts and iterating over $L$  with a `jax.lax.scan` loop, we can remove $~5/6$ of the edges before putting them together into one list.

- `matscipy`: to enable computations over systems with variable number of particles, none of the above implementation can be used and that is why we wrote a wrapper around the [matscipy](https://github.com/libAtoms/matscipy) neighbos search routine `matscipy.neighbours.neighbour_list`. This is again a cell list-based algorithms, however only available on CPU. Our wrapper essentially mimics the behavior of the JAX-MD function, but pads all non-existing particles to the maximal number of particles in the dataset.

## Performance

> Note: We observe reasonable performance from each of these implementations with up to ~10k particles, but more investigation need to be conducted towards comparing these algorithms on larger systems. Remember that we limit the system size of our benchmark datasets to 10k for memory reasons on the GNN side, and scaling eventually requires domain decomposition and parallelization.

### `vmap` vs `scan`

We compare the largest number of particles whole neighbor list computation fits into memory. We ran the script [`neighbours_search/scaling.sh`](./scaling.sh) on an A6000 GPU with 48GB memory and observed that the default vectorized implementation (`vmap`) can handle up to 1M particles before running out of memory, while our `scan` implementation reaches 3.3M. This happens at almost no additional time cost and holds for both allocating a system and updating it after jit compilation.

The output of the `scaling.sh` script looks like follows:

```
###################################################
###################################################
Start with Nx=100, mode=allocate, backend=jaxmd_vmap
Finish with 1000000 particles and 141283880 edges!
Start with Nx=102, mode=allocate, backend=jaxmd_vmap
Start with Nx=104, mode=allocate, backend=jaxmd_vmap
Start with Nx=106, mode=allocate, backend=jaxmd_vmap
Start with Nx=108, mode=allocate, backend=jaxmd_vmap
Start with Nx=110, mode=allocate, backend=jaxmd_vmap
###################################################
Start with Nx=150, mode=allocate, backend=jaxmd_scan
Finish with 3375000 particles and 476838165 edges!
Start with Nx=152, mode=allocate, backend=jaxmd_scan
Start with Nx=154, mode=allocate, backend=jaxmd_scan
Start with Nx=156, mode=allocate, backend=jaxmd_scan
Start with Nx=158, mode=allocate, backend=jaxmd_scan
Start with Nx=160, mode=allocate, backend=jaxmd_scan
###################################################
###################################################
Start with Nx=100, mode=update, backend=jaxmd_vmap
Finish with 1000000 particles and 141283880 edges!
Start with Nx=102, mode=update, backend=jaxmd_vmap
Start with Nx=104, mode=update, backend=jaxmd_vmap
Start with Nx=106, mode=update, backend=jaxmd_vmap
Start with Nx=108, mode=update, backend=jaxmd_vmap
Start with Nx=110, mode=update, backend=jaxmd_vmap
###################################################
Start with Nx=150, mode=update, backend=jaxmd_scan
Finish with 3375000 particles and 476838165 edges!
Start with Nx=152, mode=update, backend=jaxmd_scan
Start with Nx=154, mode=update, backend=jaxmd_scan
Start with Nx=156, mode=update, backend=jaxmd_scan
Start with Nx=158, mode=update, backend=jaxmd_scan
Start with Nx=160, mode=update, backend=jaxmd_scan
```

### `matscipy`

The matscipy implementation is extremely fast for small systems (10k particles) and doesn't take any GPU memory for the construction of the edge list, however, as the systems size increases, copying memory between CPU and GPU becomes a bottleneck. Also, it seems like matscipy uses a single CPU computation which is rather limiting.
