---
layout: post
title: Stochastic Autograd
authors: If you know them, let us know
tags: [auto-grad, backpropagation, tutorial]
---

- [Background](#background)
  - [Manual Calculations](#manual-calculations)
  - [Using Programs](#using-programs)
- [Automatic Differentiation](#automatic-differentiation)
  - [Chain Rule](#chain-rule)
  - [Forward Accumulation](#forward-accumulation)
  - [Reverse Accumulation](#reverse-accumulation)
- [Automatic Differentiation in Practice](#automatic-differentiation-in-practice)
  - [Jacobians](#jacobians)
  - [Jacobian Calculation as Computational Graph](#jacobian-calculation-as-computational-graph)
  - [First-order Optimization Methods](#first-order-optimization-methods)
- [Here Comes the Challenger](#here-comes-the-challenger)
  - [Implementation](#implementation)
  - [Questions](#questions)
- [Acknowledgments](#acknowledgments)
- [References](#references)

Background
----------

Ever since the [seminal work of Rumelhart, Hinton and Williams in 1986](https://www.nature.com/articles/323533a0) popularized backpropagation, most machine learning optimization methods have relied on derivatives, so there is a pressing need to compute them efficiently.

### Manual Calculations

>Implementing backpropagation by hand is like programming in assembly language. You will probably never do it, but it is important for having a mental model of how everything works. $-$ Roger Grosse

Time to admit something: I have known about neural networks since 2008, but it wasn't until 2020 that I could implement backpropagation by hand (!!). Funnily, once I did, it took only a split second to realize how easy it was, and I couldn't agree more with Prof. Grosse about its importance for building a _mental model_ of neural network training. Call it good (or bad) luck for my students, but ever since then I have **made** them do backpropagation by hand.

That said, we should be careful not to _milk_ manual backpropagation too much; ultimately, we need elegant computational solutions. We are lucky to live (and to have _grown up_) amongst libraries like PyTorch, JAX, etc., but most early ML researchers and scientists (for example, [Bottou, 1998 for Stochastic Gradient Descent](https://leon.bottou.org/publications/pdf/online-1998.pdf)) didn't have this luxury and had to go through the slow, laborious (and not guaranteed to be error-free) process of deriving _analytical_ derivatives by hand.
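To make the "by hand" part concrete, here is a minimal sketch of manual backpropagation for a single sigmoid neuron with a squared-error loss. The neuron, the loss, the function names and all the numbers are hypothetical choices for illustration. The hand-derived gradients are then checked against central finite differences, the usual sanity check for manual derivations.

```python
import math

def forward(w, b, x, t):
    """Hypothetical single neuron: z = w*x + b, y = sigmoid(z), loss = 0.5*(y - t)^2."""
    z = w * x + b
    y = 1.0 / (1.0 + math.exp(-z))
    loss = 0.5 * (y - t) ** 2
    return z, y, loss

def manual_backward(w, b, x, t):
    """Gradients derived by hand with the chain rule."""
    _, y, _ = forward(w, b, x, t)
    dL_dy = y - t                 # derivative of 0.5*(y - t)^2
    dy_dz = y * (1.0 - y)         # derivative of the sigmoid
    dL_dz = dL_dy * dy_dz
    return dL_dz * x, dL_dz * 1.0 # (dL/dw, dL/db) since dz/dw = x, dz/db = 1

w, b, x, t = 0.7, -0.2, 1.5, 1.0
dw, db = manual_backward(w, b, x, t)

# Gradient check with central differences.
h = 1e-6
num_dw = (forward(w + h, b, x, t)[2] - forward(w - h, b, x, t)[2]) / (2 * h)
num_db = (forward(w, b + h, x, t)[2] - forward(w, b - h, x, t)[2]) / (2 * h)
print(dw, num_dw)
print(db, num_db)
```

The two columns should agree to many decimal places; if they don't, the hand derivation (not the finite difference) is usually the culprit.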

### Using Programs

Programming-based solutions obviously save some effort and time, but they have their own share of errors. We can categorize them into three paradigms:

1.  **Symbolic** differentiation
2.  **Numeric** differentiation
3.  **Auto** differentiation

Both symbolic and numeric differentiation have their shortcomings:

*   Calculating higher-order derivatives is tricky
*   Numeric differentiation’s use of discretization results in a loss of accuracy (truncation and round-off errors)
*   Symbolic differentiation can lead to inefficient code (so-called expression swell), because it produces the derivative as a human-readable expression rather than evaluating it at a point
*   Both are slow at computing partial derivatives with respect to many inputs, which is exactly what **gradient-based optimization algorithms** need
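The accuracy problem of numeric differentiation is easy to reproduce. Below is a minimal sketch comparing forward and central differences for the derivative of $\\sin x$ at $x = 1$ (the function and the step sizes are arbitrary choices). As $h$ shrinks, the truncation error falls at first, but eventually floating-point round-off takes over and the error grows again.

```python
import math

def forward_diff(f, x, h):
    """One-sided (forward) finite-difference approximation of f'(x)."""
    return (f(x + h) - f(x)) / h

def central_diff(f, x, h):
    """Central finite-difference approximation of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)

x = 1.0
exact = math.cos(x)  # d/dx sin(x) = cos(x)

for h in (1e-1, 1e-4, 1e-8, 1e-12, 1e-15):
    fwd_err = abs(forward_diff(math.sin, x, h) - exact)
    ctr_err = abs(central_diff(math.sin, x, h) - exact)
    print(f"h={h:.0e}  forward error={fwd_err:.2e}  central error={ctr_err:.2e}")
```

No step size makes the error vanish, which is why autodiff, which has no step size at all, is preferred.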

Automatic Differentiation
-------------------------

Automatic differentiation addresses all the issues above and is a key feature of modern ML/DL libraries. This sweeping statement, however, hides the fact that merely a decade ago autodiff was largely alien to the ML community. It may sound bizarre today, but even this [2009 post](https://justindomke.wordpress.com/2009/02/17/automatic-differentiation-the-most-criminally-underused-tool-in-the-potential-machine-learning-toolbox/) had to clarify how autodiff differs from numeric and symbolic differentiation.

### Chain Rule

Autodiff is built around the **chain rule**, the fundamental rule of calculus for differentiating composite functions.

For example,

\\\[y = 2x^2 +4 ; x = 3w\\\]

Here _y_ is not written directly in terms of _w_, so instead of substituting $x = 3w$ first, we can compute $\\frac{dy}{dw}$ indirectly using the chain rule:

\\\[\\frac {dy}{dw} = \\frac {dy}{dx} \\times \\frac {dx}{dw} = 4x \\times 3 = 12x = 36w\\\]
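A quick numeric sanity check of this result in plain Python: the two local derivatives are multiplied exactly as in the equation and compared with $36w$ (the helper names are hypothetical).

```python
def dy_dx(x):
    return 4 * x                 # y = 2x^2 + 4  =>  dy/dx = 4x

def dx_dw(w):
    return 3                     # x = 3w        =>  dx/dw = 3

def dy_dw(w):
    x = 3 * w                    # compute the intermediate variable first
    return dy_dx(x) * dx_dw(w)   # chain rule: dy/dw = dy/dx * dx/dw

for w in (0.5, 1.0, 2.0):
    print(w, dy_dw(w), 36 * w)   # the last two columns match
```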

There are two ways to accumulate these products with the chain rule, depending on whether we go from the inputs to the outputs or the other way around.

### Forward Accumulation

In forward accumulation, we fix the independent variable we differentiate with respect to and propagate derivatives recursively from the inputs towards the output (similar to how we calculated $\\frac{dy}{dw}$ above).

![Forward Accumulation (taken from Colah's blog)](/public/images/2021-09-01-stochastic-autograd/chain-forward-greek.png)
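Forward accumulation is commonly implemented with dual numbers: every value carries its derivative with respect to the chosen input, and each arithmetic operation updates both. Here is a minimal sketch applied to the chain-rule example $y = 2x^2 + 4$, $x = 3w$. The `Dual` class is a hypothetical toy, not how any particular library implements forward mode.

```python
class Dual:
    """A number value + deriv*eps with eps**2 = 0; deriv carries d(.)/d(input)."""
    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value + other.value, self.deriv + other.deriv)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a + b eps)(c + d eps) = ac + (ad + bc) eps
        return Dual(self.value * other.value,
                    self.value * other.deriv + self.deriv * other.value)
    __rmul__ = __mul__

def f(w):
    x = 3 * w
    y = 2 * x * x + 4
    return y

w = Dual(2.0, 1.0)       # seed dw/dw = 1 to differentiate with respect to w
y = f(w)
print(y.value, y.deriv)  # 76.0 72.0, i.e. y(2) and dy/dw = 36 * 2
```

One forward pass gives the derivative with respect to one input; differentiating with respect to another input needs another pass with a different seed.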

### Reverse Accumulation

Reverse accumulation is better known as backpropagation. It is implemented in all the major frameworks like **PyTorch** and **TensorFlow** (JAX is even better at **putting both options** at the user’s disposal).

![Reverse Accumulation (taken from Colah's blog)](/public/images/2021-09-01-stochastic-autograd/chain-backward-greek.png)

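> The choice between forward and reverse accumulation usually depends on the number of inputs versus outputs: forward mode needs one pass per input, reverse mode one pass per output. Since training minimizes a scalar loss over many parameters, reverse accumulation is the _de facto_ method in deep learning.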
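The reverse sweep can be sketched with a tiny tape-based `Var` class (hypothetical, for illustration only). The forward pass records each node's parents together with the local partial derivatives, and a single backward pass then yields the derivative of the output with respect to every input at once.

```python
class Var:
    """A scalar node in a computational graph that records how it was made."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuples of (parent Var, local partial derivative)
        self.grad = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Var) else Var(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))
    __rmul__ = __mul__

    def backward(self):
        # Topologically order the graph, then sweep it from the output back.
        order, seen = [], set()
        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0                            # d(output)/d(output)
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += local * node.grad   # chain rule, summed over paths

a, b = Var(2.0), Var(5.0)
out = a * b + a          # d/da = b + 1 = 6, d/db = a = 2
out.backward()
print(a.grad, b.grad)    # 6.0 2.0
```

Both partial derivatives come out of one backward pass, which is exactly why reverse mode wins when there are many parameters and a single scalar loss.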

* * *

Automatic Differentiation in Practice
-------------------------------------

So far, **AutoGrad** sounds like a great choice. Let’s explore it a bit more from the neural network point of view.

### Jacobians

Recall that the Jacobian matrix is the counterpart of the gradient for vector-valued functions, and is defined as:

\\\[J = \\begin{bmatrix} \\dfrac{\\partial \\mathbf{f}}{\\partial x\_1} & \\cdots & \\dfrac{\\partial \\mathbf{f}}{\\partial x\_n} \\end{bmatrix} = \\begin{bmatrix} \\nabla^\\mathsf{T} f\_1 \\\\ \\vdots \\\\ \\nabla^\\mathsf{T} f\_m \\end{bmatrix} = \\begin{bmatrix} \\dfrac{\\partial f\_1}{\\partial x\_1} & \\cdots & \\dfrac{\\partial f\_1}{\\partial x\_n}\\\\ \\vdots & \\ddots & \\vdots\\\\ \\dfrac{\\partial f\_m}{\\partial x\_1} & \\cdots & \\dfrac{\\partial f\_m}{\\partial x\_n} \\end{bmatrix}\\\]

Since neural networks are trained in a vectorized way as well, a network (or any of its layers) can be represented as $f:\\mathcal R^n\\to\\mathcal R^m$; consequently, its **differential** (the Jacobian) maps every input point to an $m \\times n$ matrix:

\\\[f^{'}:\\mathcal R^n \\to \\mathcal R^{m\\times n}\\\]
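To make the shapes concrete, here is a sketch with a hypothetical map $f:\\mathcal R^2\\to\\mathcal R^3$. Its Jacobian is built column by column with central differences (one column per input) and compared with the analytical one; either way it is a $3 \\times 2$, i.e. $m \\times n$, matrix.

```python
import math

def f(x):
    """Hypothetical example map f: R^2 -> R^3."""
    x1, x2 = x
    return [x1 * x2, math.sin(x1), x1 + x2 ** 2]

def jacobian_analytic(x):
    x1, x2 = x
    # Row i is the transposed gradient of f_i; column j is df/dx_j.
    return [[x2,           x1],
            [math.cos(x1), 0.0],
            [1.0,          2 * x2]]

def jacobian_numeric(f, x, h=1e-6):
    """Build the m x n Jacobian column by column with central differences."""
    n, m = len(x), len(f(x))
    J = [[0.0] * n for _ in range(m)]
    for j in range(n):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        fp, fm = f(xp), f(xm)
        for i in range(m):
            J[i][j] = (fp[i] - fm[i]) / (2 * h)
    return J

x = [0.5, 2.0]
J = jacobian_numeric(f, x)
print(len(J), "x", len(J[0]))   # 3 x 2, i.e. m x n
```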

### Jacobian Calculation as Computational Graph

Remember, we are using AutoGrad, so there must be a way to compute the _nightmarish_ Jacobians (not [_those Jacobins_](https://en.wikipedia.org/wiki/Jacobin)). A neat approach is to use a **Linearized Computational Graph (LCG)**: a DAG whose vertices are the variables and whose edges are weighted by the local partial derivatives between them. The figure below (taken from the Randomized Automatic Differentiation paper discussed later in this post) illustrates it better:

![LCG](/public/images/2021-09-01-stochastic-autograd/paper-fig1.png)

Since each Jacobian entry is obtained by summing, over all paths from an input to an output, the product of the edge weights along each path, computing it naively takes a lot of time (*did someone say Dynamic Programming?*). The accumulation can be done in either forward or reverse mode (for example, **`jacfwd()`** and **`jacrev()`** in JAX).
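The path-sum view and its dynamic-programming shortcut can be sketched on a hypothetical toy LCG for $y = a \\cdot b$ with $a = x^2$ and $b = \\sin x$. Enumerating paths, accumulating from the input (forward mode) and accumulating from the output (reverse mode) all give $\\frac{dy}{dx} = 2x\\sin x + x^2\\cos x$. This only illustrates the idea; it is not how JAX's `jacfwd()`/`jacrev()` are implemented.

```python
import math

# Hypothetical example: y = a * b with a = x**2 and b = sin(x), evaluated at x0.
x0 = 1.3
a, b = x0 ** 2, math.sin(x0)

# Linearized computational graph: edge (u, v) carries the local partial dv/du.
edges = {
    ("x", "a"): 2 * x0,        # da/dx
    ("x", "b"): math.cos(x0),  # db/dx
    ("a", "y"): b,             # dy/da
    ("b", "y"): a,             # dy/db
}
topo = ["x", "a", "b", "y"]    # a topological order of the vertices

def all_paths(src, dst):
    """Enumerate every path from src to dst (exponentially many in general)."""
    if src == dst:
        return [[dst]]
    return [[src] + rest
            for (u, v) in edges if u == src
            for rest in all_paths(v, dst)]

def path_sum(src, dst):
    """Sum over paths of the product of edge weights along each path."""
    total = 0.0
    for path in all_paths(src, dst):
        prod = 1.0
        for u, v in zip(path, path[1:]):
            prod *= edges[(u, v)]
        total += prod
    return total

def forward_mode(src, dst):
    """DP from the input: d[v] = sum over in-edges (u, v) of d[u] * dv/du."""
    d = {v: 0.0 for v in topo}
    d[src] = 1.0
    for v in topo:
        for (u, head), weight in edges.items():
            if head == v:
                d[v] += d[u] * weight
    return d[dst]

def reverse_mode(src, dst):
    """DP from the output: g[u] = sum over out-edges (u, v) of g[v] * dv/du."""
    g = {v: 0.0 for v in topo}
    g[dst] = 1.0
    for u in reversed(topo):
        for (tail, v), weight in edges.items():
            if tail == u:
                g[u] += g[v] * weight
    return g[src]

exact = 2 * x0 * math.sin(x0) + x0 ** 2 * math.cos(x0)
print(path_sum("x", "y"), forward_mode("x", "y"), reverse_mode("x", "y"), exact)
```

Both DP sweeps visit each edge once, whereas path enumeration can blow up on deep graphs with many shared sub-paths.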

### First-order Optimization Methods

Ever since [Robbins and Monro (1951)](https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-22/issue-3/A-Stochastic-Approximation-Method/10.1214/aoms/1177729586.full) introduced stochastic approximation, optimizing with noisy estimates has been an established idea, and stochastic optimization is now the standard way to train neural networks. It has even been [proven](https://arxiv.org/abs/1902.00247v2) that SGD can escape saddle points as well.

This leads to a seemingly simple but very thought-provoking question: _Why spend resources on **exact gradients** when we’re going to use **stochastic optimization** anyway?_
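To see why inexact gradients can still be good enough, here is a minimal sketch of Robbins–Monro style SGD on a hypothetical one-dimensional quadratic $f(\\theta) = \\frac{1}{2}(\\theta - 3)^2$, where every gradient is corrupted by zero-mean Gaussian noise. The noise level, seed and step schedule are assumptions for illustration. Because the noisy gradient is unbiased and the step sizes decay like $1/t$, the iterate still settles near the minimizer.

```python
import random

def noisy_grad(theta, rng, noise=1.0):
    """Unbiased but noisy gradient of f(theta) = 0.5 * (theta - 3)**2."""
    exact = theta - 3.0
    return exact + rng.gauss(0.0, noise)   # E[noisy_grad] == exact

rng = random.Random(0)
theta = 0.0
for t in range(10_000):
    step = 1.0 / (t + 1)        # decaying step: sum diverges, sum of squares converges
    theta -= step * noisy_grad(theta, rng)

print(theta)   # close to the minimizer 3.0 despite every gradient being noisy
```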

Here Comes the Challenger
-------------------------

This excellent question was raised by the authors of the paper titled **[Randomized Automatic Differentiation](https://openreview.net/pdf?id=xpx9zj7CUlY)** (ICLR Oral!!! $-$ not your everyday NeurIPS paper).

Jokes apart, I came across this impressive paper while learning JAX and its AutoGrad. To be honest, I am still surprised that the authors were either unaware of JAX (I found no mention of it, while TensorFlow and PyTorch got their honorary mentions) or aware of it but simply didn't care to use it (which would be an even bigger surprise).

Choice of implementation framework aside, the paper addresses one of the fundamental issues in ML: saving resources (chiefly memory) by using inexact but unbiased gradients.

### Implementation

The idea is pretty simple; here it is in the authors’ own words:

> Our main idea is to consider an unbiased estimator $\\hat{\\mathcal J}\_θ \[f\]$ such that $\\mathbb E\\, \\hat{\\mathcal J}\_θ \[f\] = \\mathcal J\_θ\[f\]$ which allows us to save memory required for reverse-mode AD. Our approach is to **determine a sparse (but random) linearized computational graph during the forward pass such that reverse-mode AD applied on the sparse graph yields an unbiased estimate of the true gradient**. Note that the original computational graph is used for the forward pass, and randomization is used to determine a LCG to use for the backward pass in place of the original computation graph. We may then decrease memory costs by storing the sparse LCG directly or storing intermediate variables required to compute the sparse LCG.
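A minimal sketch of the core idea, under simplifying assumptions: on a hypothetical toy LCG for $y = x^2 \\sin x$, each edge is kept independently with probability $p$ and rescaled by $1/p$. Since the edges along any single path are distinct and sampled independently, the expected product along each path equals the true product, so the path sum (the gradient) stays unbiased. This independent Bernoulli edge sampling is only an illustration; the paper develops more refined sampling schemes and applies them to real networks.

```python
import math
import random

# Hypothetical toy LCG for y = a * b with a = x**2, b = sin(x), at x0.
x0 = 1.3
a, b = x0 ** 2, math.sin(x0)
edges = {("x", "a"): 2 * x0, ("x", "b"): math.cos(x0),
         ("a", "y"): b, ("b", "y"): a}
topo = ["x", "a", "b", "y"]

def reverse_accumulate(graph, src, dst):
    """Reverse-mode accumulation of d(dst)/d(src) over whatever edges exist."""
    g = {v: 0.0 for v in topo}
    g[dst] = 1.0
    for u in reversed(topo):
        for (tail, head), weight in graph.items():
            if tail == u:
                g[u] += g[head] * weight
    return g[src]

def sparsify(graph, p, rng):
    """Keep each edge with probability p and rescale it by 1/p (unbiased)."""
    return {e: w / p for e, w in graph.items() if rng.random() < p}

rng = random.Random(42)
exact = reverse_accumulate(edges, "x", "y")
samples = [reverse_accumulate(sparsify(edges, 0.5, rng), "x", "y")
           for _ in range(20_000)]
estimate = sum(samples) / len(samples)
print(exact, estimate)   # the Monte Carlo average approaches the exact gradient
```

With $p = 0.5$ only about half of the edges need to be stored per backward pass; the price is a noisier individual gradient, which is exactly the trade-off stochastic optimization is expected to tolerate.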

### Questions

Since this is not meant to be a review or anything of the sort, I will not go further into the details of the paper. Still, a few questions came to mind after reading it:

*   How does _determining the optimal path_ in the LCG affect the overall performance? How long does it take on average?
*   Why only MNIST and CIFAR? Did the authors try other datasets? How well does it scale?
*   Do we have to perform an exhaustive hyperparameter search every time, or will that $10^{-4}$ more or less do?
*   Would a better PRNG (like the Threefry counter-based PRNG used in JAX) lead to better results?
*   Can we compare the performance with existing attempts (like Sun _et al._ or Gomez _et al._)?

(_This post is a bit incomplete due to time constraints._)

Acknowledgments
---------------
We are deeply grateful to Christopher Olah for his kind permission to use the images from his awesome blog.

References
----------

1.  [Calculus on Computational Graphs: Backpropagation](https://colah.github.io/posts/2015-08-Backprop/) $-$ a recommended read (both images on forward and reverse accumulation are taken from the blog)
    
2.  [JAX - AutoDiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
    
3.  [Automatic Differentiation in Machine Learning: a Survey](https://arxiv.org/pdf/1502.05767v4.pdf)