{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1b560c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited project modules automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# Expose a single GPU to CUDA. NOTE(review): index 7 is machine-specific - adjust per host.\n",
    "%env CUDA_VISIBLE_DEVICES=7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88273fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as t\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import scipy\n",
    "import scipy.special\n",
    "import scipy.optimize\n",
    "import scipy.stats\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm.auto import trange, tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "302ff44d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gelu(x):\n",
    "    return .5 * x * (1 + scipy.special.erf(x / np.sqrt(2)))\n",
    "\n",
    "\n",
    "def gelu_deriv(x):\n",
    "    return .5 * scipy.special.erf(x / np.sqrt(2)) + x * np.exp(-x**2 / 2) / np.sqrt(2 * np.pi) + .5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8f4bb38",
   "metadata": {},
   "source": [
    "# Approximation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99e660c8",
   "metadata": {},
   "source": [
    "### Dynamic Programming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e450470d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_y_opt(int_w, int_f_w):\n",
    "    num = int_f_w[None, :] - int_f_w[:, None]\n",
    "\n",
    "    den = int_w[None, :] - int_w[:, None]\n",
    "    np.fill_diagonal(den, 1.)\n",
    "    den[den == 0] = 1.\n",
    "\n",
    "    return num / den\n",
    "\n",
    "\n",
    "def calc_trainsition_matrix(int_w, int_f2_w, y_opt):\n",
    "    translation_matrix = int_f2_w[None, :] - int_f2_w[:, None] - y_opt**2 * (int_w[None, :] - int_w[:, None])\n",
    "    translation_matrix = np.where(np.tri(len(int_w), k=-1, dtype=np.bool).T, translation_matrix, np.inf)\n",
    "    return translation_matrix\n",
    "    \n",
    "    \n",
    "def restore_answer(frm):\n",
    "    current = len(frm[-1]) - 1\n",
    "    points = [current]\n",
    "    for frm_curr in frm[::-1]:\n",
    "        current = frm_curr[current]\n",
    "        points.append(current)\n",
    "    points = points[::-1]\n",
    "    \n",
    "    assert points[0] == 0\n",
    "    assert points[-1] == len(frm[-1]) - 1\n",
    "    \n",
    "    return points\n",
    "\n",
    "\n",
    "def dynamic_programming(int_w, int_f_w, int_f2_w, xs, k: int):\n",
    "    \"\"\"Split the grid into exactly k consecutive segments with minimal total cost.\n",
    "\n",
    "    The path starts at grid index 0 and ends at the last index; the cost of a\n",
    "    segment (i, j) is taken from calc_trainsition_matrix, using the per-segment\n",
    "    constants from calc_y_opt.\n",
    "\n",
    "    int_w, int_f_w, int_f2_w: 1-D cumulative arrays, one value per grid node.\n",
    "    xs: NumPy array of grid node coordinates, same length as int_w.\n",
    "    k: number of segments.\n",
    "\n",
    "    Returns (breakpoints, levels): xs at the k + 1 selected nodes and the k\n",
    "    constants of the segments between consecutive breakpoints.\n",
    "    Raises ValueError if k is not between 1 and len(int_w) - 1.\n",
    "    \"\"\"\n",
    "    n_intervals = len(int_w) - 1\n",
    "    # With more segments than grid intervals every path cost stays inf and the\n",
    "    # backtracking would return meaningless breakpoints; k < 1 leaves nothing to backtrack.\n",
    "    if not 1 <= k <= n_intervals:\n",
    "        raise ValueError(f'k must be between 1 and {n_intervals} (number of grid intervals), got {k}')\n",
    "\n",
    "    y_opt = calc_y_opt(int_w, int_f_w)\n",
    "    transition_matrix = calc_trainsition_matrix(int_w, int_f2_w, y_opt)\n",
    "    columns = np.arange(len(int_w))\n",
    "\n",
    "    def one_step(dp):\n",
    "        # dp_new[i, j]: best cost of reaching j when the last segment is (i, j).\n",
    "        dp_new = dp[:, None] + transition_matrix\n",
    "        frm = np.argmin(dp_new, axis=0)\n",
    "        # Read the minimum through the argmin instead of reducing the matrix twice.\n",
    "        return dp_new[frm, columns], frm\n",
    "\n",
    "    dp = np.full(len(int_w), np.inf)\n",
    "    dp[0] = 0\n",
    "    frm = []\n",
    "\n",
    "    for _ in trange(k):\n",
    "        dp, from_curr = one_step(dp)\n",
    "        frm.append(from_curr)\n",
    "\n",
    "    points = restore_answer(frm)\n",
    "\n",
    "    return xs[points], y_opt[points[:-1], points[1:]]\n",
    "\n",
    "\n",
    "def numerical_integral(f_values, delta_x):\n",
    "    return np.cumsum(f_values) * delta_x\n",
    "\n",
    "\n",
    "def uniform_weight(n_segments, integral_n_coeff, L: float, R: float, n_bits: int, act, act_deriv):\n",
    "    r\"\"\"\n",
    "    Piecewise-constant least-squares fit of act_deriv on [L, R] with uniform weight:\n",
    "\n",
    "        \\int_L^R (act_deriv(x) - q(x | xs, ys))^2 dx -> min\n",
    "\n",
    "    n_segments: number of equal grid intervals on [L, R]; breakpoints are picked among the n_segments + 1 grid nodes\n",
    "    integral_n_coeff: how many points to use inside each interval to perform numerical integration\n",
    "    L: left bound\n",
    "    R: right bound\n",
    "    n_bits: number of bits in piecewise-constant approximation (2**n_bits is number of intervals)\n",
    "    act: callable representing activation function (used as the antiderivative of act_deriv)\n",
    "    act_deriv: callable representing derivative of activation function\n",
    "\n",
    "    Returns the result of dynamic_programming: the breakpoints xs and the constant levels ys.\n",
    "    \"\"\"\n",
    "    stride = integral_n_coeff + 1\n",
    "    xs = np.linspace(L, R, n_segments + 1)\n",
    "\n",
    "    # Uniform weight: the integral of w up to x is x - L, and the integral of\n",
    "    # act_deriv is available in closed form as act.\n",
    "    int_w = xs - xs[0]\n",
    "    int_f_w = act(xs) - act(xs[0])\n",
    "\n",
    "    # Fine grid whose every stride-th node coincides with a node of xs.\n",
    "    xs_int = np.linspace(L, R, n_segments * stride + 1)\n",
    "    int_f2_w = numerical_integral(act_deriv(xs_int)**2, xs_int[1] - xs_int[0])[::stride]\n",
    "\n",
    "    return dynamic_programming(int_w, int_f_w, int_f2_w, xs, 2**n_bits)\n",
    "\n",
    "\n",
    "def all_numerical(n_segments, integral_n_coeff, L: float, R: float, n_bits: int, act_deriv, w):\n",
    "    \"\"\"\n",
    "    Piecewise-constant weighted least-squares fit of act_deriv on [L, R], with\n",
    "    every integral computed numerically:\n",
    "\n",
    "        int (act_deriv(x) - q(x | xs, ys))^2 w(x) dx -> min\n",
    "\n",
    "    n_segments: number of equal grid intervals on [L, R]; breakpoints are picked among the n_segments + 1 grid nodes\n",
    "    integral_n_coeff: how many points to use inside each interval to perform numerical integration\n",
    "    L: left bound\n",
    "    R: right bound\n",
    "    n_bits: number of bits in piecewise-constant approximation (2**n_bits is number of intervals)\n",
    "    act_deriv: callable representing derivative of activation function\n",
    "    w: callable weight function, evaluated on a NumPy array of points\n",
    "\n",
    "    Returns the result of dynamic_programming: the breakpoints xs and the constant levels ys.\n",
    "    \"\"\"\n",
    "    stride = integral_n_coeff + 1\n",
    "    xs = np.linspace(L, R, n_segments + 1)\n",
    "    # Fine grid whose every stride-th node coincides with a node of xs.\n",
    "    xs_int = np.linspace(L, R, n_segments * stride + 1)\n",
    "    delta_x = xs_int[1] - xs_int[0]\n",
    "\n",
    "    # Evaluate each callable once on the fine grid and reuse the samples.\n",
    "    w_values = w(xs_int)\n",
    "    f_values = act_deriv(xs_int)\n",
    "\n",
    "    int_w = numerical_integral(w_values, delta_x)[::stride]\n",
    "    int_f_w = numerical_integral(f_values * w_values, delta_x)[::stride]\n",
    "    int_f2_w = numerical_integral(f_values**2 * w_values, delta_x)[::stride]\n",
    "\n",
    "    return dynamic_programming(int_w, int_f_w, int_f2_w, xs, 2**n_bits)\n",
    "\n",
    "\n",
    "N_SEGMENTS = 2**12  # grid intervals on [L, R] among which breakpoints are chosen\n",
    "INTEGRAL_N_COEFF = 2**10  # extra integration points inside each grid interval\n",
    "L = -10\n",
    "R = 10\n",
    "N_BITS = 5  # the fit uses 2**N_BITS constant pieces\n",
    "\n",
    "xs, ys = uniform_weight(N_SEGMENTS, INTEGRAL_N_COEFF, L, R, N_BITS, gelu, gelu_deriv)\n",
    "# Alternative objective with a Gaussian weight instead of the uniform one:\n",
    "# xs, ys = all_numerical(N_SEGMENTS, INTEGRAL_N_COEFF, L, R, N_BITS, gelu_deriv, lambda x: scipy.stats.norm.pdf(x * .5))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "x_plot = np.linspace(L, R, 2**10)\n",
    "ax.plot(x_plot, gelu_deriv(x_plot), label='GELU derivative')\n",
    "# One horizontal line per fitted constant piece.\n",
    "for x1, x2, y in zip(xs[:-1], xs[1:], ys):\n",
    "    ax.plot([x1, x2], [y, y])\n",
    "ax.set(xlabel='x', ylabel='d GELU / dx', title=f'Piecewise-constant fit with {len(ys)} pieces')\n",
    "ax.legend();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
