{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LapJAX Tutorial\n",
    "(updated on Oct-12-2023)\n",
    "\n",
    "This tutorial aims to give you a quick understanding of LapJAX. We will cover the following topics:\n",
    "1. How to use LapJAX to accelerate your code simply by changing the import statement.\n",
    "2. How fast LapJAX is compared to standard methods.\n",
    "3. How to build your custom operators in LapJAX.\n",
    "\n",
    "\n",
    "### A Quick Start\n",
    "To use LapJAX, simply replace every `jax` in your import statements with `lapjax`. For example, change\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import vmap\n",
    "```\n",
    "to\n",
    "```python\n",
    "import lapjax\n",
    "import lapjax.numpy as jnp\n",
    "from lapjax import vmap\n",
    "```\n",
    "Without any further change, this code runs *exactly the same* as it does with `jax`. This keeps LapJAX compatible with existing code in most situations where the Laplacian is not needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:LapJAX:When wrapping `jax` modules `tools`, got ImportError:\n",
      "    No module named 'requests'\n",
      "WARNING:LapJAX:This won't affect functions of other modules.\n",
      "WARNING:LapJAX:When wrapping `jax` modules `collect_profile.py`, got ImportError:\n",
      "    This script requires `tensorflow` to be installed.\n",
      "WARNING:LapJAX:This won't affect functions of other modules.\n"
     ]
    }
   ],
   "source": [
    "# WARNING: if you are running on an Ampere-architecture GPU, e.g. A100 or RTX 3090,\n",
    "# please make sure to disable TF32.\n",
    "# Otherwise, you will lose numerical precision.\n",
    "import os\n",
    "os.environ['NVIDIA_TF32_OVERRIDE'] = \"0\"\n",
    "\n",
    "# import jax\n",
    "# import jax.numpy as jnp\n",
    "import lapjax as jax\n",
    "import lapjax.numpy as jnp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Case 1: A standard MLP model\n",
    "Below we consider a standard MLP network. Assume you have written the model using `jax` as follows. All functions remain unchanged when you use `lapjax`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-10-23 17:16:40.424880: W external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 0 and 3; status: INTERNAL: failed to enable peer access from 0x7f936876b260 to 0x7f93648f6e40: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted\n",
      "2023-10-23 17:16:40.431744: W external/org_tensorflow/tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.cc:63] Unable to enable peer access between GPUs 3 and 0; status: INTERNAL: failed to enable peer access from 0x7f93648f6e40 to 0x7f936876b260: CUDA_ERROR_TOO_MANY_PEERS: peer mapping resources exhausted\n"
     ]
    }
   ],
   "source": [
    "# define the hyperparameters\n",
    "input_dim = 64\n",
    "hidden_dim = 256\n",
    "hidden_layer = 4\n",
    "layer_dims = [input_dim,] + [hidden_dim,] * hidden_layer + [1,]\n",
    "\n",
    "# define init function\n",
    "def init_params(key):\n",
    "    params = []\n",
    "    left_dim = input_dim\n",
    "    for right_dim in layer_dims:\n",
    "        key, subkey = jax.random.split(key)\n",
    "        params.append(jax.random.normal(subkey, (left_dim,right_dim)) * 0.1)\n",
    "        left_dim = right_dim\n",
    "    return params\n",
    "\n",
    "# Define the network\n",
    "def MLP(params,x):\n",
    "    for param in params:\n",
    "        # use lapjax.numpy to construct the function\n",
    "        # This function can take both jax.ndarray and lapjax.LapTuple as input\n",
    "        x = jnp.matmul(x, param)\n",
    "        x = jnp.tanh(x)\n",
    "    return x.reshape(-1)\n",
    "\n",
    "key = jax.random.PRNGKey(42)\n",
    "key, subkey = jax.random.split(key)\n",
    "params = init_params(key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we construct [pure functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) that compute the Laplacian of the model in two ways: with the trace of `jax.hessian`, and with a LapJAX `LapTuple`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute laplacian through standard jax method\n",
    "def get_laplacian_function_orig(func):\n",
    "    def lap(data):\n",
    "        # assume data.shape = (N,)\n",
    "        # compute hessian\n",
    "        hess = jax.hessian(func)(data)\n",
    "        return jnp.trace(hess,axis1=-1,axis2=-2)\n",
    "    return lap\n",
    "\n",
    "# compute laplacian through lapjax\n",
    "def get_laplacian_function_lapjax(func):\n",
    "    # ATTENTION: all you need to do is a few changes here!\n",
    "    from lapjax import LapTuple\n",
    "    def lap(data):\n",
    "        input_laptuple = LapTuple(data, is_input=True)\n",
    "        output_laptuple = func(input_laptuple)\n",
    "        # a LapTuple carries value, grad and lap\n",
    "        return output_laptuple.lap\n",
    "    return lap\n",
    "\n",
    "# get laplacian function\n",
    "lap_original = get_laplacian_function_orig(lambda x: MLP(params, x))\n",
    "lap_lapjax = get_laplacian_function_lapjax(lambda x: MLP(params, x))\n",
    "\n",
    "# note that the input and output of `lap_lapjax` are both `jax.numpy.ndarray`\n",
    "# so we can use `jax.vmap` to deal with a batch of data. `jax.jit` also works well\n",
    "vmap_lap_orignal = jax.jit(jax.vmap(lap_original))\n",
    "vmap_lap_lapjax = jax.jit(jax.vmap(lap_lapjax))"
   ]
  },
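  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea behind the second approach more concrete, here is a minimal NumPy sketch of forward Laplacian propagation through a `tanh` MLP. It is only an illustration under simplifying assumptions, not LapJAX's actual implementation: the helper names `forward_laplacian_mlp` and `laplacian_fd` are hypothetical, and the finite-difference routine serves purely as a reference to check against.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def forward_laplacian_mlp(weights, x):\n",
    "    # Push (value, grad, lap) through x -> tanh(x @ W) layers in one pass.\n",
    "    # grad[i, k] = d value[k] / d x[i];  lap[k] = sum_i d^2 value[k] / d x[i]^2\n",
    "    n = x.shape[0]\n",
    "    value = x\n",
    "    grad = np.eye(n)      # d x[k] / d x[i] is the identity\n",
    "    lap = np.zeros(n)     # the inputs themselves have zero Laplacian\n",
    "    for W in weights:\n",
    "        # Linear layer: value, grad and lap are all mapped by the same matrix.\n",
    "        value, grad, lap = value @ W, grad @ W, lap @ W\n",
    "        # Element-wise tanh: chain rule up to second order.\n",
    "        t = np.tanh(value)\n",
    "        d1 = 1.0 - t ** 2       # first derivative of tanh\n",
    "        d2 = -2.0 * t * d1      # second derivative of tanh\n",
    "        lap = d1 * lap + d2 * np.sum(grad ** 2, axis=0)\n",
    "        grad = d1 * grad\n",
    "        value = t\n",
    "    return value, grad, lap\n",
    "\n",
    "def laplacian_fd(func, x, eps=1e-4):\n",
    "    # Central finite-difference Laplacian, used only as a reference.\n",
    "    total = -2.0 * x.shape[0] * func(x)\n",
    "    for i in range(x.shape[0]):\n",
    "        step = np.zeros_like(x)\n",
    "        step[i] = eps\n",
    "        total = total + func(x + step) + func(x - step)\n",
    "    return total / eps ** 2\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "weights = [rng.normal(size=(4, 8)) * 0.5, rng.normal(size=(8, 1)) * 0.5]\n",
    "x = rng.normal(size=4)\n",
    "\n",
    "value, grad, lap = forward_laplacian_mlp(weights, x)\n",
    "reference = laplacian_fd(lambda z: forward_laplacian_mlp(weights, z)[0], x)\n",
    "print(np.allclose(lap, reference, atol=1e-5))\n",
    "```\n",
    "\n",
    "A linear layer maps all three quantities with the same matrix, while an element-wise function only needs its first and second derivatives, so the full Hessian is never materialized."
   ]
  },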
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now check the numerical accuracy of the LapJAX method against the Hessian-trace baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.8146973e-06\n"
     ]
    }
   ],
   "source": [
    "batch_size = 1280\n",
    "\n",
    "key, subkey = jax.random.split(key)\n",
    "data = jax.random.normal(subkey,(batch_size, input_dim))\n",
    "\n",
    "orig_results = vmap_lap_orignal(data)\n",
    "lapjax_results = vmap_lap_lapjax(data)\n",
    "\n",
    "# the maximum difference is at the level of float32 rounding error\n",
    "print(jnp.max(jnp.abs(orig_results-lapjax_results)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now measure the speedup of the LapJAX method. Note that for fully connected networks there is no sparsity to exploit, so the speedup ratio is at most about 2 (roughly 1.7x in the run recorded below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time of hessian-trace: 1.0552332401275635\n",
      "time of forward laplacian: 0.6265256404876709\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "def compute_time(vfunc, key, batch_size, input_dim, iterations=100):\n",
    "    key, subkey = jax.random.split(key)\n",
    "    data_pool = jax.random.normal(subkey,(iterations*2, batch_size, input_dim))\n",
    "    \n",
    "    data_pool1, data_pool2 = jnp.split(data_pool, 2)\n",
    "\n",
    "    # warm up so that JIT compilation is not included in the timing\n",
    "    for data in data_pool1:\n",
    "        val = vfunc(data)\n",
    "    val.block_until_ready()  # JAX dispatches asynchronously; wait for the warm-up to finish\n",
    "\n",
    "    start_time = time.time()\n",
    "    for data in data_pool2:\n",
    "        val = vfunc(data)\n",
    "    val.block_until_ready()  # wait for the last result before stopping the clock\n",
    "    end_time = time.time()\n",
    "    return val, end_time - start_time\n",
    "\n",
    "val, duration = compute_time(vmap_lap_orignal, key, batch_size, input_dim)\n",
    "print('time of hessian-trace:', duration)\n",
    "\n",
    "val, duration = compute_time(vmap_lap_lapjax, key, batch_size, input_dim)\n",
    "# forward laplacian is roughly 2 times faster than the original method\n",
    "print('time of forward laplacian:', duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Case 2: A Slater-determinant-based model\n",
    "Below we consider a Slater-determinant-based model of the kind typically used to represent wave functions. As you will see, the derivatives of this model are very sparse, and thus the acceleration from LapJAX is significant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CASE 2: a Slater-determinant-like wave function\n",
    "# In this case, we can leverage the derivative sparsity to\n",
    "# achieve a much larger speed-up than in Case 1.\n",
    "\n",
    "# construct input params\n",
    "key = jax.random.PRNGKey(42)\n",
    "\n",
    "n_elec = 16 # number of electrons\n",
    "input_dim = 3 # the dimension of electron position\n",
    "hidden_dim = 256\n",
    "hidden_layer = 2\n",
    "layer_dims = [input_dim,] + [hidden_dim,] * hidden_layer + [n_elec,]\n",
    "\n",
    "def init_params(key):\n",
    "    params = []\n",
    "    left_dim = input_dim\n",
    "    for right_dim in layer_dims:\n",
    "        key, subkey = jax.random.split(key)\n",
    "        params.append(jax.random.normal(subkey, (left_dim,right_dim)) * 0.1)\n",
    "        left_dim = right_dim\n",
    "    return params\n",
    "\n",
    "# construct the wave functions. \n",
    "def slater_determinants(params, x):\n",
    "\n",
    "    # x.shape = (n_elec * input_dim,)\n",
    "    x = x.reshape(n_elec, input_dim)\n",
    "    for param in params:\n",
    "        # Each electron is processed by the same MLP function\n",
    "        x = jnp.matmul(x, param)\n",
    "        x = jnp.tanh(x)\n",
    "\n",
    "    x = x + jnp.eye(x.shape[0])\n",
    "\n",
    "    _, lnpsi = jnp.linalg.slogdet(x)\n",
    "    return lnpsi\n",
    "\n",
    "key, subkey = jax.random.split(key)\n",
    "params_sd = init_params(key)"
   ]
  },
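  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough NumPy illustration of where the sparsity comes from, assuming (as in the per-electron MLP above) that each electron's features depend only on that electron's own coordinates. The array names and the small `hidden_dim` are chosen only for this demo and do not reflect how LapJAX stores gradients internally.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "n_elec, input_dim, hidden_dim = 16, 3, 8   # hidden_dim kept small for the demo\n",
    "rng = np.random.default_rng(0)\n",
    "W = rng.normal(size=(input_dim, hidden_dim)) * 0.1\n",
    "x = rng.normal(size=(n_elec, input_dim))\n",
    "\n",
    "# One layer applied to every electron independently.\n",
    "h = np.tanh(x @ W)                          # shape (n_elec, hidden_dim)\n",
    "\n",
    "# Per-electron Jacobian: local[e, k, d] = d h[e, k] / d x[e, d]\n",
    "local = (1.0 - h ** 2)[:, :, None] * W.T[None, :, :]\n",
    "\n",
    "# Dense Jacobian: dense[e, k, f, d] = d h[e, k] / d x[f, d], zero unless e == f\n",
    "dense = np.zeros((n_elec, hidden_dim, n_elec, input_dim))\n",
    "for e in range(n_elec):\n",
    "    dense[e, :, e, :] = local[e]\n",
    "\n",
    "print(dense.size, local.size)               # 6144 vs. 384 stored entries\n",
    "print(np.count_nonzero(dense) / dense.size) # fraction of entries that are non-zero\n",
    "```\n",
    "\n",
    "Only one block in `n_elec` of the dense Jacobian can be non-zero here, which is the kind of structure a sparsity-aware forward pass can avoid storing and multiplying."
   ]
  },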
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, we construct pure functions that compute the kinetic energy in two ways. In Variational Monte Carlo, we need the local kinetic energy, which is defined as:\n",
    "$$\n",
    "E_k = \\frac{-0.5 \\times \\nabla^2 \\psi(\\mathbf x)}{\\psi (\\mathbf x)} = -0.5 \\times \\nabla^2 \\ln \\psi (\\mathbf x) - 0.5 \\times (\\nabla \\ln \\psi (\\mathbf x))^2.\n",
    "$$"
   ]
  },
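  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers who want the missing step, the second equality follows from the chain rule, assuming $\\psi(\\mathbf x) \\neq 0$ (where $\\psi < 0$ the same steps apply with $\\ln |\\psi|$, which is the quantity `slogdet` returns):\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\nabla \\psi &= \\psi \\, \\nabla \\ln \\psi, \\\\\n",
    "\\nabla^2 \\psi &= \\nabla \\psi \\cdot \\nabla \\ln \\psi + \\psi \\, \\nabla^2 \\ln \\psi = \\psi \\left[ (\\nabla \\ln \\psi)^2 + \\nabla^2 \\ln \\psi \\right].\n",
    "\\end{aligned}\n",
    "$$\n",
    "Dividing by $\\psi$ and multiplying by $-0.5$ gives $E_k = -0.5 \\times \\nabla^2 \\ln \\psi - 0.5 \\times (\\nabla \\ln \\psi)^2$, so only the gradient and Laplacian of $\\ln \\psi$ are needed."
   ]
  },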
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_kinetic_function_orig(func):\n",
    "    def kinetic(data):\n",
    "        grad = jax.grad(func)(data)\n",
    "        hess = jax.hessian(func)(data)\n",
    "        return -0.5 * (jnp.trace(hess) + jnp.sum(grad ** 2))\n",
    "        \n",
    "    return kinetic\n",
    "\n",
    "def get_kinetic_function_lapjax(func):\n",
    "    from lapjax import LapTuple\n",
    "    def kinetic(data):\n",
    "        input_laptuple = LapTuple(data, is_input=True)\n",
    "        output_laptuple = func(input_laptuple)\n",
    "\n",
    "        # A LapTuple stores both gradient and laplacian information,\n",
    "        # so we do not need to compute gradient again.\n",
    "        return -0.5 * output_laptuple.lap - 0.5 * jnp.sum(output_laptuple.grad**2)\n",
    "\n",
    "    return kinetic\n",
    "\n",
    "# get kinetic function\n",
    "ke_original = get_kinetic_function_orig(lambda x: slater_determinants(params_sd, x))\n",
    "ke_lapjax = get_kinetic_function_lapjax(lambda x: slater_determinants(params_sd, x))\n",
    "\n",
    "vmap_ke_orignal = jax.jit(jax.vmap(ke_original))\n",
    "vmap_ke_lapjax = jax.jit(jax.vmap(ke_lapjax))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now test the precision of lapjax method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.7683716e-07\n"
     ]
    }
   ],
   "source": [
    "# CASE 2: precision test\n",
    "batch_size = 128\n",
    "\n",
    "key, subkey = jax.random.split(key)\n",
    "data = jax.random.normal(subkey,(batch_size, input_dim*n_elec))\n",
    "\n",
    "orig_results = vmap_ke_orignal(data)\n",
    "lapjax_results = vmap_ke_lapjax(data)\n",
    "\n",
    "# the maximum difference is at the level of float32 rounding error\n",
    "print(jnp.max(jnp.abs(orig_results-lapjax_results)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now measure the speedup of the LapJAX method. Thanks to sparsity, the improvement is significant: roughly 5.6x in the run recorded below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time of hessian-trace: 0.6880514621734619\n",
      "time of forward laplacian: 0.12378358840942383\n"
     ]
    }
   ],
   "source": [
    "val, duration = compute_time(vmap_ke_orignal, \n",
    "                             key, batch_size, input_dim * n_elec)\n",
    "print('time of hessian-trace:', duration)\n",
    "\n",
    "val, duration = compute_time(vmap_ke_lapjax, \n",
    "                             key, batch_size, input_dim * n_elec)\n",
    "# with sparsity, forward laplacian is several times faster than the original method\n",
    "print('time of forward laplacian:', duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Case 3: Customize Operators\n",
    "The operators you need may not have been wrapped by LapJAX yet. In that case, you can wrap the `jax` operators yourself. Suppose we want to use `jax.numpy.isnan` in our model, which does not yet support `LapTuple` input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Lapjax encounters unwrapped function 'isnan'.\n",
      "Please consider using other functions or wrap it yourself.\n",
      "You can refer to README for more information about customized wrap.\n"
     ]
    }
   ],
   "source": [
    "from lapjax import LapTuple\n",
    "import lapjax.numpy as jnp\n",
    "lap = LapTuple(jnp.eye(4), is_input=True) / jnp.eye(4)\n",
    "try:\n",
    "    print(jnp.isnan(lap))\n",
    "except Exception as e:  # should see the unwrapped-function error.\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To wrap a `jax` function, e.g., `f`, we need to:\n",
    "1. specify which class `f` belongs to. For instance, `jax.numpy.exp2` belongs to `FElement`, as the operator is applied to each element of the input.\n",
    "2. bind `f` to the corresponding class, and write a customized function for `LapTuple` (only if needed).\n",
    "\n",
    "In the `jax.numpy.isnan` case, the function only tests whether each input element is NaN, so its output carries no gradient or Laplacian (equivalently, both are zero). When the output of a function `f` contains only arrays with zero gradient and Laplacian, `f` should be bound to the `FConstruction` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully bind function 'isnan' to FType.CONSTRUCTION.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DeviceArray([[False,  True,  True,  True],\n",
       "             [ True, False,  True,  True],\n",
       "             [ True,  True, False,  True],\n",
       "             [ True,  True,  True, False]], dtype=bool)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from lapjax import custom_wrap, FType\n",
    "custom_wrap(jnp.isnan, FType.CONSTRUCTION)\n",
    "\n",
    "jnp.isnan(lap)"
   ]
  },
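  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, we can wrap `jax.numpy.isinf` and `jax.numpy.isfinite` as above. Functions of other classes are bound in the same way; for example, `jax.numpy.exp2` would be bound to the `FElement` class. Below we bind the linear function `jax.numpy.nansum` to the `FLinear` class."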
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:LapJAX:Notice that if custom_type is `FLinear`, the LapTuple might loss the sparsity and cause inefficiency.\n",
      "You can customize the function yourself and bind to `CUSTOMIZED`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Lapjax encounters unwrapped function 'nansum'.\n",
      "Please consider using other functions or wrap it yourself.\n",
      "You can refer to README for more information about customized wrap.\n",
      "\n",
      "Now we wrap it.\n",
      "\n",
      "Successfully bind function 'nansum' to FType.LINEAR.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(DeviceArray([1., 1., 1., 1.], dtype=float32), (16, 4))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f = lambda x: jnp.log(jnp.nansum(jnp.exp(x), axis=-1))\n",
    "try:\n",
    "    f(lap)\n",
    "except Exception as e:\n",
    "    print(e)\n",
    "print(\"\\nNow we wrap it.\\n\")\n",
    "from lapjax import custom_wrap, FType\n",
    "custom_wrap(jnp.nansum, FType.LINEAR)\n",
    "output = f(lap)\n",
    "output.value, output.grad.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should have noticed the warning \"Notice that if custom_type is `FLinear`, the LapTuple might loss the sparsity and cause inefficiency. You can customize the function yourself and bind to `CUSTOMIZED`.\" This is because `lapjax` exploits the sparsity of arrays for acceleration, and binding a function to the `FLinear` class without handling the sparsity of the output LapTuples discards that sparsity. In this case, consider handling the sparsity of the output LapTuples yourself and binding the function to the `CUSTOMIZED` class. For example, below is one way to customize `jax.numpy.nansum` carefully."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully bind function 'nansum' to FType.CUSTOMIZED.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(DeviceArray([1., 1., 1., 1.], dtype=float32), (4, 4))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def cst_nansum(*args, **kwargs):    # same inputs as jnp.nansum\n",
    "    array: LapTuple = args[0]\n",
    "    # Use already-wrapped functions to compose nansum.\n",
    "    # Notice that we have wrapped isnan before. \n",
    "    array = jnp.where(jnp.isnan(array), 0, array) # mask nan to 0.\n",
    "    args = (array,) + args[1:]\n",
    "\n",
    "    return jnp.sum(*args, **kwargs)\n",
    "custom_wrap(jnp.nansum, FType.CUSTOMIZED, cst_f=cst_nansum, overwrite=True)\n",
    "output = f(lap)\n",
    "output.value, output.grad.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you compare the two outputs above, you will find that the second cell has a smaller gradient shape: `(4, 4)` instead of `(16, 4)`. This is because `lapjax.numpy.sum` keeps track of the sparsity, so that the output LapTuple only contains truly non-zero gradient values. Let's compare their efficiency."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully bind function 'nansum' to FType.LINEAR.\n",
      "time of no sparsity wrap: 0.9739506244659424\n",
      "Successfully bind function 'nansum' to FType.CUSTOMIZED.\n",
      "time of sparsity wrap: 0.012285947799682617\n"
     ]
    }
   ],
   "source": [
    "input_dim = 100 # shape = (input_dim, input_dim)\n",
    "batch_size = 16\n",
    "\n",
    "# bind to `FLINEAR` type\n",
    "custom_wrap(jnp.nansum, FType.LINEAR, overwrite=True)\n",
    "def f(x):\n",
    "    sq = x.reshape(input_dim, input_dim) / (1-jnp.eye(input_dim))\n",
    "    return jnp.mean(jnp.log(jnp.nansum(jnp.exp(sq), axis=-1)))\n",
    "lap_original = get_laplacian_function_lapjax(f)\n",
    "vmap_lap_orignal = jax.jit(jax.vmap(lap_original))\n",
    "\n",
    "key, subkey = jax.random.split(key)\n",
    "val, duration = compute_time(vmap_lap_orignal, \n",
    "                             subkey, batch_size, input_dim ** 2)\n",
    "print('time of no sparsity wrap:', duration)\n",
    "\n",
    "# bind to `FCUSTOMIZED` type\n",
    "custom_wrap(jnp.nansum, FType.CUSTOMIZED, cst_f=cst_nansum, overwrite=True)\n",
    "lap_customized = get_laplacian_function_lapjax(f)\n",
    "vmap_lap_orignal = jax.jit(jax.vmap(lap_customized))\n",
    "\n",
    "key, subkey = jax.random.split(subkey)\n",
    "val, duration = compute_time(vmap_lap_orignal, \n",
    "                             key, batch_size, input_dim ** 2)\n",
    "print('time of sparsity wrap:', duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Summary\n",
    "Always keep in mind that the major acceleration of `lapjax` comes from sparsity. Thus, try to compose your (minimal) Laplacian computation from existing functions, and bind custom functions to `FELEMENT`, `FCONSTRUCTION`, `FMERGING`, or `FCUSTOMIZED` whenever possible.\n",
    "\n",
    "If you think some commonly used functions should be wrapped in `lapjax`, please feel free to contact us or open an issue. Enjoy!"
   ]
  }
 ],
 "metadata": {
  "fileId": "ceff7f73-eb57-4e47-aeaf-b87d146db090",
  "kernelspec": {
   "display_name": "jax",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
