{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "134e7f9d",
   "metadata": {},
   "source": [
    "# Demo 4: Extracting activation functions\n",
    "\n",
    "KAN diagrams give an intuitive picture of the network, but for quantitative tasks we may also want the numerical values of the activation functions. Using the convention introduced in the indexing notebook, each edge is indexed as $(l,i,j)$, where $l$ is the layer index, $i$ is the input neuron index, and $j$ is the output neuron index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2075ef56",
   "metadata": {},
   "source": [
    "from kan import *\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Create a KAN with 2 inputs, 5 hidden neurons, and 1 output,\n",
    "# using cubic splines (k=3) on 5 grid intervals (grid=5).\n",
    "model = KAN(width=[2,5,1], grid=5, k=3, seed=0, noise_scale_base=1.0)\n",
    "x = torch.normal(0,1,size=(100,2))\n",
    "# Run a forward pass so the model records the activations used below.\n",
    "model(x)\n",
    "model.plot(beta=100)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d3fe2e03",
   "metadata": {},
   "source": [
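    "l = 1  # layer index\n",
    "i = 2  # input neuron index\n",
    "j = 0  # output neuron index\n",
    "\n",
    "# Note the index order: [sample, j, i].\n",
    "inputs = model.spline_preacts[l][:,j,i]\n",
    "outputs = model.spline_postacts[l][:,j,i]\n",
    "# The samples are unordered; sort by input so the curve is drawn left to right.\n",
    "rank = torch.argsort(inputs)\n",
    "inputs = inputs[rank]\n",
    "outputs = outputs[rank]\n",
    "plt.plot(inputs.detach().cpu().numpy(), outputs.detach().cpu().numpy(), marker=\"o\")"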
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "a9e62f17",
   "metadata": {},
   "source": [
    "If we are only interested in the range of a particular activation function, we can call `get_range` with the same $(l,i,j)$ indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1a978202",
   "metadata": {},
   "source": [
    "model.get_range(l,i,j)"
   ],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
