{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import psdr, psdr.demos\n",
    "from scipy.linalg import orth\n",
    "from psdr import local_linear_grads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a test function: a quadratic with a random orthogonal eigenbasis\n",
    "np.random.seed(0)\n",
    "m = 5\n",
    "Q = orth(np.random.randn(m,m))\n",
    "# Eigenvalues decay from 1 to 1/100\n",
    "lam = np.linspace(1,10, m)**(-2)\n",
    "A = Q @ np.diag(lam) @ Q.T\n",
    "fun = psdr.demos.QuadraticFunction(quad = A)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Estimating Gradients from Samples Using a Perplexity-based Bandwidth\n",
    "\n",
    "In many cases we seek the gradient of a function $\\nabla f(\\mathbf{x})$, but only have access to samples $f(\\mathbf{x}_i)$ where the sample locations $\\lbrace\\mathbf{x}_i\\rbrace_{i=1}^N$ are pre-determined and may be arbitrary; i.e., we cannot employ the finite difference approaches that are natural if we are free to choose $\\mathbf{x}_i$.  In this situation, we estimate the gradient at each $\\mathbf{x}_j$ using a local linear model centered at $\\mathbf{x}_j$, motivated by the use of such models in the \"Outer Product Gradient\" of Hardle and Stoker (1989).  The same approach appears in textbook treatments of local regression; see, e.g., \"Local Polynomial Modelling and Its Applications\" by Jianqing Fan and Irene Gijbels, Sec. 2.3.\n",
    "\n",
    "The basic premise is as follows.  Suppose we have a kernel function, say $k_\\beta(\\mathbf{x}_1, \\mathbf{x}_2) = e^{-\\beta \\|\\mathbf{x}_1 - \\mathbf{x}_2\\|_2^2}$; we then solve a weighted linear least-squares problem to construct the local linear model $g(\\mathbf{x}) = a_0 + \\mathbf{a}^\\top \\mathbf{x}$\n",
    "around $\\mathbf{x}_j$:\n",
    "$$ \n",
    "\\min_{a_0\\in \\mathbb{R}, \\mathbf{a}\\in \\mathbb{R}^m}\n",
    "\\sum_{i=1}^N [(a_0 + \\mathbf{a}^\\top \\mathbf{x}_i) - f(\\mathbf{x}_i)]^2 k_\\beta(\\mathbf{x}_i, \\mathbf{x}_j).\n",
    "$$\n",
    "The coefficient vector $\\mathbf{a}$ of the minimizer then serves as the estimate of $\\nabla f(\\mathbf{x}_j)$.\n",
    "\n",
    "## The Importance of Bandwidth\n",
    "The kernel and its bandwidth parameter $\\beta$ play a critical role in the efficacy of this approach.  With the kernel above, $\\beta$ acts as an inverse squared width.  We do not want $\\beta$ to be too large (a narrow kernel), as then all the weights away from $\\mathbf{x}_j$ will be small, yielding an ill-conditioned system; nor do we want $\\beta$ to be too small (a wide kernel), as then the weights are all effectively one and the resulting global linear model will ignore the nonlinearities of $f$.  The figure below illustrates this Goldilocks effect."
   ]
  },
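  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see both failure modes is through the weighted normal equations.  This is a sketch: the symbols $\\mathbf{V}$, $\\mathbf{W}_j$, and $\\mathbf{f}$ are notation introduced here for illustration and are not taken from `psdr`.  Stack the rows $\\begin{bmatrix} 1 & \\mathbf{x}_i^\\top \\end{bmatrix}$ into $\\mathbf{V}\\in\\mathbb{R}^{N\\times (m+1)}$, the samples $f(\\mathbf{x}_i)$ into $\\mathbf{f}\\in\\mathbb{R}^N$, and the kernel weights into $\\mathbf{W}_j = \\operatorname{diag}\\big(k_\\beta(\\mathbf{x}_1,\\mathbf{x}_j),\\ldots,k_\\beta(\\mathbf{x}_N,\\mathbf{x}_j)\\big)$.  Assuming $\\mathbf{V}^\\top\\mathbf{W}_j\\mathbf{V}$ is invertible, the coefficients of the local linear model are\n",
    "$$\n",
    "\\begin{bmatrix} a_0 \\\\ \\mathbf{a} \\end{bmatrix} = (\\mathbf{V}^\\top \\mathbf{W}_j \\mathbf{V})^{-1} \\mathbf{V}^\\top \\mathbf{W}_j \\mathbf{f}.\n",
    "$$\n",
    "As $\\beta\\to 0$, $\\mathbf{W}_j \\to \\mathbf{I}$ and every $\\mathbf{x}_j$ receives the same global least-squares fit.  As $\\beta\\to\\infty$, every weight except $k_\\beta(\\mathbf{x}_j,\\mathbf{x}_j)=1$ tends to zero (assuming distinct samples), so $\\mathbf{V}^\\top\\mathbf{W}_j\\mathbf{V}$ approaches a rank-one matrix and the system becomes singular."
   ]
  },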
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = fun.domain.sample(1000)\n",
    "fX = fun(X)\n",
    "grads_true = fun.grad(X)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# Common histogram bins for the per-sample gradient error\n",
    "xx = np.linspace(0,1, 100)\n",
    "for bw in [2, 20, 200]:\n",
    "    grads = local_linear_grads(X, fX, bandwidth = bw)\n",
    "    # Largest componentwise error in each gradient estimate\n",
    "    err = np.max(np.abs(grads - grads_true), axis = 1)\n",
    "    ax.hist(err, xx, label = 'bw=%g' % bw, alpha = 0.5)\n",
    "ax.set_xlabel('error, inf-norm')\n",
    "ax.set_ylabel('count')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are a variety of approaches for computing the bandwidth, including rule-of-thumb estimates and cross-validation.  Here we argue for an approach based on perplexity.\n",
    "\n",
    "## Perplexity-Based Bandwidth\n",
    "In algorithms like SNE and t-SNE, an important step in processing the original high-dimensional similarity between two points is to pick a bandwidth at each sample such that the perplexity (the exponentiated entropy of the weights on the neighboring points, roughly an effective number of neighbors) has a particular value. This ensures that areas of high and low sample density are treated more fairly.  Our approach here is to compute a per-sample bandwidth such that the perplexity is $m+1$.  Although this increases computation time compared to the fixed-bandwidth approach with a value taken from Xia 2007 (see Li 18, eq. 11.5), in the examples below it yields far more accurate gradient estimates than that heuristic.\n"
   ]
  },
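  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the quantity being matched can be sketched as follows, using the SNE convention.  The symbols $p_{i|j}$ and $H_j$ are notation introduced here; SNE itself omits the term $\\ell = j$ from the sum, and whether `psdr` does so is not assumed.  With the kernel $k_\\beta$ from above,\n",
    "$$\n",
    "p_{i|j}(\\beta) = \\frac{k_\\beta(\\mathbf{x}_i,\\mathbf{x}_j)}{\\sum_{\\ell=1}^N k_\\beta(\\mathbf{x}_\\ell,\\mathbf{x}_j)}, \\qquad\n",
    "H_j(\\beta) = -\\sum_{i=1}^N p_{i|j}(\\beta)\\log_2 p_{i|j}(\\beta), \\qquad\n",
    "\\mathrm{Perp}_j(\\beta) = 2^{H_j(\\beta)}.\n",
    "$$\n",
    "For $\\beta \\ge 0$ the perplexity is non-increasing in $\\beta$, ranging from $N$ (uniform weights as $\\beta\\to 0$) down toward $1$ (all weight on a single sample as $\\beta\\to\\infty$), so a target such as $m+1$ can be matched at each $\\mathbf{x}_j$ by a one-dimensional search, e.g. bisection."
   ]
  },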
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example: Random Data\n",
    "With purely random data, we should expect the optimal bandwidth to be roughly the same for each sample point. Here we see that if we avoid the heuristic proposed by Xia and instead use a hand-picked fixed bandwidth (20, labeled 'opg opt bw' below), we can obtain answers comparable to the perplexity-based approach."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = fun.domain.sample(2000)\n",
    "fX = fun(X)\n",
    "grads_true = fun.grad(X)\n",
    "\n",
    "# Fixed bandwidth from Xia's heuristic (timed)\n",
    "print(\"OPG\")\n",
    "%time grads1 = local_linear_grads(X, fX, bandwidth = 'xia')\n",
    "# Hand-picked fixed bandwidth\n",
    "grads2 = local_linear_grads(X, fX, bandwidth = 20)\n",
    "# Per-sample bandwidth chosen to match the target perplexity (timed)\n",
    "print(\"Perplexity OPG\")\n",
    "%time grads3 = local_linear_grads(X, fX, perplexity = m+1)\n",
    "# Largest componentwise error in each gradient estimate\n",
    "err1 = np.max(np.abs(grads1 - grads_true), axis = 1)\n",
    "err2 = np.max(np.abs(grads2 - grads_true), axis = 1)\n",
    "err3 = np.max(np.abs(grads3 - grads_true), axis = 1)\n",
    "\n",
    "# Left: histogram of errors; right: sorted errors on a log scale\n",
    "fig, axes = plt.subplots(1,2, figsize = (14,7))\n",
    "ax = axes[0]\n",
    "xx = np.linspace(0, max(np.max(err1), np.max(err2)), 50)\n",
    "ax.hist(err1, xx, alpha = 0.5, label = 'opg')\n",
    "ax.hist(err2, xx, alpha = 0.5, label = 'opg opt bw')\n",
    "ax.hist(err3, xx, alpha = 0.5, label = 'popg')\n",
    "ax.set_xlabel('error, inf-norm')\n",
    "ax.set_ylabel('count')\n",
    "ax.legend();\n",
    "\n",
    "ax = axes[1]\n",
    "ax.plot(np.sort(err1), label = 'opg')\n",
    "ax.plot(np.sort(err2), label = 'opg opt bw')\n",
    "ax.plot(np.sort(err3), label = 'popg')\n",
    "ax.set_yscale('log')\n",
    "ax.legend()\n",
    "ax.set_xlabel('index')\n",
    "ax.set_ylabel('error, inf-norm');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example: Mixed Data\n",
    "In this example, data is drawn from two distributions to reflect the challenges inherent when we have data that doesn't emerge from a simple random sample approach.  Here, the static bandwidth approach breaks down."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X1 = fun.domain.sample(1000)\n",
    "# Second cluster: domain samples shrunk by 10x and shifted by 0.5 in every coordinate.\n",
    "# We place this off-center so that gradients are large there\n",
    "X2 = 0.1*fun.domain.sample(1000) + 0.5*np.ones(len(fun.domain))\n",
    "X = np.vstack([X1,X2])\n",
    "fX = fun(X)\n",
    "grads_true = fun.grad(X)\n",
    "\n",
    "# Fixed bandwidth from Xia's heuristic (timed)\n",
    "print(\"OPG\")\n",
    "%time grads1 = local_linear_grads(X, fX, bandwidth = 'xia')\n",
    "# Hand-picked fixed bandwidth\n",
    "grads2 = local_linear_grads(X, fX, bandwidth = 20)\n",
    "# Per-sample bandwidth chosen to match the target perplexity (timed)\n",
    "print(\"Perplexity OPG\")\n",
    "%time grads3 = local_linear_grads(X, fX, perplexity = m+1)\n",
    "# Largest componentwise error in each gradient estimate\n",
    "err1 = np.max(np.abs(grads1 - grads_true), axis = 1)\n",
    "err2 = np.max(np.abs(grads2 - grads_true), axis = 1)\n",
    "err3 = np.max(np.abs(grads3 - grads_true), axis = 1)\n",
    "\n",
    "# Left: histogram of errors; right: sorted errors on a log scale\n",
    "fig, axes = plt.subplots(1,2, figsize =(14,7))\n",
    "ax = axes[0]\n",
    "xx = np.linspace(0, max(np.max(err1), np.max(err2)), 200)\n",
    "ax.hist(err1, xx, alpha = 0.5, label = 'opg')\n",
    "ax.hist(err2, xx, alpha = 0.5, label = 'opg opt bw')\n",
    "ax.hist(err3, xx, alpha = 0.5, label = 'popg')\n",
    "ax.set_xlabel('error, inf-norm')\n",
    "ax.set_ylabel('count')\n",
    "ax.legend();\n",
    "\n",
    "ax = axes[1]\n",
    "ax.plot(np.sort(err1), label = 'opg')\n",
    "ax.plot(np.sort(err2), label = 'opg opt bw')\n",
    "ax.plot(np.sort(err3), label = 'popg')\n",
    "ax.set_yscale('log')\n",
    "ax.legend()\n",
    "ax.set_xlabel('index')\n",
    "ax.set_ylabel('error, inf-norm');"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
