{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e348d5ee",
   "metadata": {
    "id": "fLuHW0FhrVqC"
   },
   "outputs": [],
   "source": [
    "import torch as T\n",
    "import torch.utils.benchmark as benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d12e74f4",
   "metadata": {
    "id": "OfSpSFdhsPMl"
   },
   "outputs": [],
   "source": [
    "# Benchmark on the GPU when one is available; fall back to CPU so the notebook still runs elsewhere.\n",
    "device = T.device('cuda' if T.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76842e16",
   "metadata": {
    "id": "z6qUKC39rzwR"
   },
   "outputs": [],
   "source": [
    "# Sorted bin edges: searchsorted against these 7 edges yields bucket indices 0..7.\n",
    "bounds = T.tensor(\n",
    "    [\n",
    "        -2.39798704e+00,\n",
    "        -7.11248159e-01,\n",
    "        -3.26290283e-01,\n",
    "        -1.55338428e-04,\n",
    "        3.26182064e-01,\n",
    "        7.10855860e-01,\n",
    "        2.39811567e+00,\n",
    "    ],\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "\n",
    "# Decorate with @T.jit.script to benchmark the TorchScript-compiled variant.\n",
    "def forward1(xs, bounds):\n",
    "    \"\"\"Bucketize `xs` against the sorted 1-D `bounds` using `T.searchsorted`.\n",
    "\n",
    "    For each element x, returns the smallest i with x <= bounds[i], or\n",
    "    len(bounds) when x is greater than every edge, cast to uint8.\n",
    "    \"\"\"\n",
    "    indices = T.searchsorted(bounds, xs)\n",
    "    return indices.to(T.uint8)\n",
    "\n",
    "\n",
    "# Decorate with @T.jit.script to benchmark the TorchScript-compiled variant.\n",
    "def forward2(xs, bounds):\n",
    "    \"\"\"Bucketize `xs` against sorted `bounds` with a broadcast comparison + min.\n",
    "\n",
    "    Along the trailing edge axis, `x > bounds` is a run of True followed by\n",
    "    False (bounds are sorted), so the index `min` reports for the first False\n",
    "    is the bucket; an all-True row means x lies above every edge. Intended to\n",
    "    match `forward1` (NOTE: NaN inputs map to 0 here, unlike searchsorted).\n",
    "    \"\"\"\n",
    "    above = xs.unsqueeze(-1) > bounds\n",
    "    all_above, first_not_above = above.min(dim=-1)\n",
    "    buckets = T.where(all_above, len(bounds), first_not_above)\n",
    "    return buckets.to(T.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aad3ffa",
   "metadata": {
    "id": "3bB41SCWr9LF"
   },
   "outputs": [],
   "source": [
    "T.manual_seed(0)  # fixed seed so the benchmark input and sanity check are reproducible\n",
    "xs = T.randn((128, 128), device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d0723b6",
   "metadata": {
    "id": "_b0skuBgsZgE"
   },
   "outputs": [],
   "source": [
    "# Warm-up: the first calls pay one-off costs (e.g. lazy CUDA initialization, and\n",
    "# TorchScript compilation if @T.jit.script is enabled) outside the timed runs.\n",
    "y1 = forward1(xs, bounds)\n",
    "y2 = forward2(xs, bounds)\n",
    "# T.equal also requires matching shapes, unlike T.all(y1 == y2), which broadcasts.\n",
    "assert T.equal(y1, y2)\n",
    "\n",
    "# Random normals almost never land exactly on an edge, so also check ties\n",
    "# (x == bounds[i]) and values beyond both ends explicitly.\n",
    "edge_xs = T.cat([bounds, bounds.new_tensor([-float('inf'), float('inf')])])\n",
    "assert T.equal(forward1(edge_xs, bounds), forward2(edge_xs, bounds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fab5b77",
   "metadata": {
    "id": "997lStURsJn2",
    "outputId": "6b980cd5-66d8-4444-85d9-2b4cadfa2c24"
   },
   "outputs": [],
   "source": [
    "# CUDA kernels launch asynchronously: without a synchronize, %timeit would only\n",
    "# measure launch overhead, not the time the device spends computing the result.\n",
    "def sync_device():\n",
    "    \"\"\"Block until queued work on xs's device has finished (no-op on CPU).\"\"\"\n",
    "    if xs.is_cuda:\n",
    "        T.cuda.synchronize(xs.device)\n",
    "\n",
    "%timeit forward1(xs, bounds); sync_device()\n",
    "%timeit forward2(xs, bounds); sync_device()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "031f3a6b",
   "metadata": {
    "id": "i0ckMfwxvC_L"
   },
   "outputs": [],
   "source": [
    "# Pass every name the statement needs through `globals` instead of importing it\n",
    "# from __main__, which depends on how the kernel hosts the notebook namespace.\n",
    "# benchmark.Timer synchronizes CUDA itself, so these timings include GPU execution.\n",
    "t1 = benchmark.Timer(\n",
    "    stmt='forward1(xs, bounds)',\n",
    "    globals={'forward1': forward1, 'xs': xs, 'bounds': bounds},\n",
    "    label='bucketize xs',\n",
    "    description='forward1 (searchsorted)')\n",
    "\n",
    "t2 = benchmark.Timer(\n",
    "    stmt='forward2(xs, bounds)',\n",
    "    globals={'forward2': forward2, 'xs': xs, 'bounds': bounds},\n",
    "    label='bucketize xs',\n",
    "    description='forward2 (compare + min)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba7a1264",
   "metadata": {
    "id": "xpCvWpAHvOsU",
    "outputId": "4acfc3a1-7c2e-4429-b821-14fcf631b573"
   },
   "outputs": [],
   "source": [
    "# Time each implementation with a fixed number of calls and show the Measurement.\n",
    "runs = 10000\n",
    "for timer in (t1, t2):\n",
    "    print(timer.timeit(runs))"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
