.. _knn_benchmark:

Benchmark the performance of KNN algorithms
===========================================

In this doc, we benchmark the performance of several K-Nearest Neighbor (KNN) algorithms implemented by :func:`dgl.knn_graph`.

Given a dataset of ``N`` samples with ``D`` dimensions, the most common use of KNN algorithms in graph learning is to build a KNN graph by finding, for each of the ``N`` samples, its ``K`` nearest neighbors within the dataset.
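
To make the definition concrete, the following plain-Python sketch builds a KNN graph by exhaustive (brute-force) search. It only illustrates the definition and is not how :func:`dgl.knn_graph` is implemented. The function name ``knn_graph_bruteforce`` is hypothetical, and two conventions are assumptions of this sketch: each point counts as one of its own neighbors, and edges point from a neighbor to its query point.

.. code-block:: python

    import math


    def knn_graph_bruteforce(points, k):
        """Return (src, dst) edge lists of a KNN graph built by exhaustive search.

        Hypothetical helper: for every query point i, all N points (including i
        itself) are ranked by Euclidean distance, and an edge j -> i is added for
        each of the k closest points j.
        """
        n = len(points)
        src, dst = [], []
        for i in range(n):
            # O(N * D) work per query: distance from point i to every point.
            dists = [(math.dist(points[i], points[j]), j) for j in range(n)]
            # Keep the k nearest; sorting makes this step O(N log N) per query.
            for _, j in sorted(dists)[:k]:
                src.append(j)
                dst.append(i)
        return src, dst


    points = [(0.0, 0.0), (1.0, 0.0), (5.0, 5.0), (5.0, 6.0)]
    src, dst = knn_graph_bruteforce(points, k=2)
    print(list(zip(src, dst)))
    # → [(0, 0), (1, 0), (1, 1), (0, 1), (2, 2), (3, 2), (3, 3), (2, 3)]

Every query touches all ``N`` samples, so the exact construction costs roughly ``O(N^2 * D)`` distance work in total.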

Empirically, all three parameters, ``N``, ``D``, and ``K``, affect the computation cost. To benchmark the algorithms, we pick a few representative datasets that cover the most common scenarios:

* A synthetic dataset of samples drawn from a mixture of Gaussians: ``N = 1000``, ``D = 3``.
* A point cloud sample from ModelNet: ``N = 10000``, ``D = 3``.
* Subsets of MNIST:

  - A small subset: ``N = 1000``, ``D = 784``.
  - A medium subset: ``N = 10000``, ``D = 784``.
  - A large subset: ``N = 50000``, ``D = 784``.
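
As a back-of-the-envelope illustration of why ``N`` and ``D`` matter, assume an exact brute-force search that compares every sample with every other sample, so the distance work grows as ``N * N * D``. This simplified cost model is an assumption for illustration only: it ignores ``K``, memory traffic, and the pruning done by ``kd-tree`` or ``nn-descent``.

.. code-block:: python

    # Assumed cost model: all pairwise distances, i.e. N * N * D scalar
    # operations (up to a constant factor). K is deliberately ignored.
    datasets = {
        "Mixed Gaussian": (1_000, 3),
        "Point Cloud": (10_000, 3),
        "Small MNIST": (1_000, 784),
        "Medium MNIST": (10_000, 784),
        "Large MNIST": (50_000, 784),
    }

    for name, (n, d) in datasets.items():
        ops = n * n * d  # pairwise-distance work under the assumed model
        print(f"{name:<15} N={n:>6} D={d:>3} ~{ops:.1e} ops")

Under this model, the large MNIST subset needs about 2,500 times (``50^2``) the distance work of the small subset.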

Some notes:

* ``bruteforce-sharemem`` is a GPU-only variant of ``bruteforce`` optimized with CUDA shared memory.
* ``kd-tree`` is currently only implemented on CPU.
* ``bruteforce-blas`` computes pairwise distances via matrix multiplication, which is fast but memory-inefficient.
* ``nn-descent`` is an approximate algorithm, so we also report the recall rate of its results (shown as ``R`` in the tables).
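
The notes do not define recall precisely. A common definition, assumed in the sketch below, is the fraction of the exact ``K`` nearest neighbors that the approximate result also returns, pooled over all queries (this equals the per-query average when every query has exactly ``K`` neighbors). The function ``recall`` and its dict-of-lists input format are hypothetical.

.. code-block:: python

    def recall(approx, exact):
        """Fraction of the exact K nearest neighbors recovered by `approx`.

        Hypothetical helper: both arguments map a query index to a list of
        neighbor indices; hits are pooled over all queries.
        """
        hits = 0
        total = 0
        for i, true_nbrs in exact.items():
            # Count true neighbors of query i that the approximate list also contains.
            hits += len(set(approx[i]) & set(true_nbrs))
            total += len(true_nbrs)
        return hits / total


    exact = {0: [0, 1, 2], 1: [1, 0, 3]}
    approx = {0: [0, 2, 4], 1: [1, 3, 0]}
    print(recall(approx, exact))  # 5 of the 6 true neighbors are found
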

Results
-------

In this section, we show the runtime and recall rate (where applicable) for the algorithms under various scenarios.

The experiments are run on an Amazon EC2 p3.2xlarge instance, which has 8 vCPUs with 61GB of RAM and one Tesla V100 GPU with 16GB of memory. The numbers were obtained with DGL==0.7.0 (`64d0f3f <https://github.com/dmlc/dgl/commit/64d0f3f3554911ec06d015f1c9659180796adf9a>`_), PyTorch==1.8.1, and CUDA==11.1 on Ubuntu 18.04.5 LTS.
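
The benchmark script itself is not shown in this doc. As a hedged sketch of how such runtimes can be measured, the helper below reports the median wall-clock time over a few repetitions after warm-up runs; the name ``time_call`` and its default counts are hypothetical and are not the settings used for the numbers below. Because PyTorch launches CUDA kernels asynchronously, a synchronization call such as ``torch.cuda.synchronize`` should be passed as ``sync`` when timing GPU code, otherwise only the launch time may be measured.

.. code-block:: python

    import statistics
    import time


    def time_call(fn, *args, warmup=1, repeat=5, sync=None, **kwargs):
        """Return the median wall-clock time of fn(*args, **kwargs) in seconds.

        Hypothetical helper: `sync` is an optional callable run before and after
        each timed call, e.g. torch.cuda.synchronize for asynchronous GPU work.
        """
        for _ in range(warmup):
            fn(*args, **kwargs)  # absorb one-time costs such as lazy initialization
        times = []
        for _ in range(repeat):
            if sync is not None:
                sync()
            start = time.perf_counter()
            fn(*args, **kwargs)
            if sync is not None:
                sync()
            times.append(time.perf_counter() - start)
        return statistics.median(times)


    print(f"{time_call(sorted, list(range(100_000, 0, -1))):.4f} s")

With DGL and a CUDA tensor ``x``, a call might look like ``time_call(dgl.knn_graph, x, 8, algorithm="bruteforce-sharemem", sync=torch.cuda.synchronize)``, assuming the ``algorithm`` keyword of :func:`dgl.knn_graph` in DGL 0.7.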

* **Mixed Gaussian** (``N = 1000``, ``D = 3``)

+---------------------+------------------+-------------------+------------------+------------------+
| Algorithm           | CPU                                  | GPU                                 |
|                     +------------------+-------------------+------------------+------------------+
|                     | K = 8            | K = 64            | K = 8            | K = 64           |
+=====================+==================+===================+==================+==================+
| bruteforce-blas     | 0.010            | 0.011             | 0.002            | 0.003            |
+---------------------+------------------+-------------------+------------------+------------------+
| kd-tree             | 0.004            | 0.006             | n/a              | n/a              |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce          | 0.004            | 0.006             | 0.126            | 0.009            |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce-sharemem | n/a              | n/a               | 0.002            | 0.003            |
+---------------------+------------------+-------------------+------------------+------------------+
| nn-descent          | 0.014 (R: 0.985) | 0.148 (R: 1.000)  | 0.016 (R: 0.973) | 0.077 (R: 1.000) |
+---------------------+------------------+-------------------+------------------+------------------+

* **Point Cloud** (``N = 10000``, ``D = 3``)

+---------------------+------------------+-------------------+------------------+------------------+
| Algorithm           | CPU                                  | GPU                                 |
|                     +------------------+-------------------+------------------+------------------+
|                     | K = 8            | K = 64            | K = 8            | K = 64           |
+=====================+==================+===================+==================+==================+
| bruteforce-blas     | 0.359            | 0.432             | 0.010            | 0.010            |
+---------------------+------------------+-------------------+------------------+------------------+
| kd-tree             | 0.007            | 0.026             | n/a              | n/a              |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce          | 0.074            | 0.167             | 0.008            | 0.039            |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce-sharemem | n/a              | n/a               | 0.004            | 0.017            |
+---------------------+------------------+-------------------+------------------+------------------+
| nn-descent          | 0.161 (R: 0.977) | 1.345 (R: 0.999)  | 0.086 (R: 0.966) | 0.445 (R: 0.999) |
+---------------------+------------------+-------------------+------------------+------------------+

* **Small MNIST** (``N = 1000``, ``D = 784``)

+---------------------+------------------+-------------------+------------------+------------------+
| Algorithm           | CPU                                  | GPU                                 |
|                     +------------------+-------------------+------------------+------------------+
|                     | K = 8            | K = 64            | K = 8            | K = 64           |
+=====================+==================+===================+==================+==================+
| bruteforce-blas     | 0.014            | 0.015             | 0.002            | 0.002            |
+---------------------+------------------+-------------------+------------------+------------------+
| kd-tree             | 0.179            | 0.182             | n/a              | n/a              |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce          | 0.173            | 0.228             | 0.123            | 0.170            |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce-sharemem | n/a              | n/a               | 0.045            | 0.054            |
+---------------------+------------------+-------------------+------------------+------------------+
| nn-descent          | 0.060 (R: 0.878) | 1.077 (R: 0.999)  | 0.030 (R: 0.952) | 0.457 (R: 0.999) |
+---------------------+------------------+-------------------+------------------+------------------+

* **Medium MNIST** (``N = 10000``, ``D = 784``)

+---------------------+------------------+-------------------+------------------+------------------+
| Algorithm           | CPU                                  | GPU                                 |
|                     +------------------+-------------------+------------------+------------------+
|                     | K = 8            | K = 64            | K = 8            | K = 64           |
+=====================+==================+===================+==================+==================+
| bruteforce-blas     | 0.897            | 0.970             | 0.019            | 0.023            |
+---------------------+------------------+-------------------+------------------+------------------+
| kd-tree             | 18.902           | 18.928            | n/a              | n/a              |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce          | 14.495           | 17.652            | 2.058            | 2.588            |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce-sharemem | n/a              | n/a               | 2.257            | 2.524            |
+---------------------+------------------+-------------------+------------------+------------------+
| nn-descent          | 0.804 (R: 0.755) | 14.108 (R: 0.999) | 0.158 (R: 0.900) | 1.794 (R: 0.999) |
+---------------------+------------------+-------------------+------------------+------------------+

* **Large MNIST** (``N = 50000``, ``D = 784``)

+---------------------+------------------+-------------------+------------------+------------------+
| Algorithm           | CPU                                  | GPU                                 |
|                     +------------------+-------------------+------------------+------------------+
|                     | K = 8            | K = 64            | K = 8            | K = 64           |
+=====================+==================+===================+==================+==================+
| bruteforce-blas     | 21.829           | 22.135            | Out of Memory    | Out of Memory    |
+---------------------+------------------+-------------------+------------------+------------------+
| kd-tree             | 542.688          | 573.379           | n/a              | n/a              |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce          | 373.823          | 432.963           | 10.317           | 12.639           |
+---------------------+------------------+-------------------+------------------+------------------+
| bruteforce-sharemem | n/a              | n/a               | 53.133           | 58.419           |
+---------------------+------------------+-------------------+------------------+------------------+
| nn-descent          | 4.995 (R: 0.658) | 75.487 (R: 0.999) | 1.478 (R: 0.860) | 15.698 (R: 0.999)|
+---------------------+------------------+-------------------+------------------+------------------+
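
One plausible explanation for the ``Out of Memory`` entries is the size of the distance matrix, assuming that ``bruteforce-blas`` materializes a dense ``N x N`` float32 matrix of pairwise distances, as a matrix-multiplication formulation typically does. The helper name below is hypothetical.

.. code-block:: python

    def dense_distance_matrix_bytes(n, bytes_per_value=4):
        """Size in bytes of one dense N x N distance matrix (float32 by default).

        Hypothetical helper; real peak memory also includes inputs and temporaries.
        """
        return n * n * bytes_per_value


    for n in (1_000, 10_000, 50_000):
        gib = dense_distance_matrix_bytes(n) / 2**30
        print(f"N={n:>6}: {gib:.2f} GiB")

For the large subset, a single such matrix is about 9.3 GiB, more than half of the V100's 16GB before counting the input features or any temporary buffers, whereas the CPU run has 61GB of RAM to work with.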

Conclusion
----------

- If memory allows, ``bruteforce-blas`` is the recommended default algorithm.
- The exception is low-dimensional data (small ``D``) on CPU, where ``kd-tree`` performs best.
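
The two rules above can be encoded as a small decision helper. This is a hypothetical sketch: ``choose_knn_algorithm``, its parameters, and the ``small_d`` threshold (set to 3 only because that is the low-dimensional case benchmarked here) are assumptions. Only exact algorithms are considered, and the fallback to ``bruteforce`` when ``bruteforce-blas`` does not fit in memory is read off the Large MNIST results rather than stated in the conclusions. The returned strings are the algorithm names used in this doc.

.. code-block:: python

    def choose_knn_algorithm(d, on_gpu, blas_fits_in_memory, small_d=3):
        """Return an exact KNN algorithm name following the conclusions above.

        Hypothetical helper: `small_d` is an assumed cutoff for "low-dimensional",
        and the caller decides whether bruteforce-blas fits in memory.
        """
        if not on_gpu and d <= small_d:
            # Low-dimensional data on CPU: kd-tree was fastest (or tied) here.
            return "kd-tree"
        if blas_fits_in_memory:
            # The recommended default whenever memory allows.
            return "bruteforce-blas"
        # Fallback: on Large MNIST, plain bruteforce was the fastest exact GPU
        # method once bruteforce-blas ran out of memory, and it beat kd-tree on CPU.
        return "bruteforce"


    # Large MNIST on the GPU, where bruteforce-blas ran out of memory.
    print(choose_knn_algorithm(784, on_gpu=True, blas_fits_in_memory=False))
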

