{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "134e7f9d",
   "metadata": {},
   "source": [
    "# Demo 1: Indexing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2075ef56",
   "metadata": {},
   "source": [
    "from kan import KAN\n",
    "import torch\n",
    "model = KAN(width=[2,3,2,1])\n",
    "x = torch.normal(0,1,size=(100,2))\n",
    "model(x);\n",
    "beta = 100\n",
    "model.plot(beta=beta)\n",
    "# [2,3,2,1] means 2 input nodes\n",
    "# 3 neurons in the first hidden layer,\n",
    "# 2 neurons in the second hidden layer,\n",
    "# 1 output node"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "c47ccd2b",
   "metadata": {},
   "source": [
    "### Indexing of edges (activation functions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c30add2",
   "metadata": {},
   "source": [
    "Each activation function is indexed by $(l,i,j)$ where $l$ is the layer index, $i$ is the input neuron index, $j$ is the output neuron index. All of them starts from 0. For example, the one in the bottom left corner is (0, 0, 0). Let's try to make it symbolic and see it turns red."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c95dbc78",
   "metadata": {},
   "source": [
    "model.fix_symbolic(0,0,0,'sin')\n",
    "model.plot(beta=beta)\n",
    "model.unfix_symbolic(0,0,0)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bf721202",
   "metadata": {},
   "source": [
    "model.fix_symbolic(0,0,1,'sin')\n",
    "model.plot(beta=beta)\n",
    "model.unfix_symbolic(0,0,1)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1e7cd4a8",
   "metadata": {},
   "source": [
    "model.fix_symbolic(0,1,0,'sin')\n",
    "model.plot(beta=beta)\n",
    "model.unfix_symbolic(0,1,0)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "18e0baa2",
   "metadata": {},
   "source": [
    "model.fix_symbolic(1,0,0,'sin')\n",
    "model.plot(beta=beta)\n",
    "model.unfix_symbolic(1,0,0)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "50eb8f8c",
   "metadata": {},
   "source": [
    "model.fix_symbolic(2,1,0,'sin')\n",
    "model.plot(beta=beta)\n",
    "model.unfix_symbolic(2,1,0)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "960e5447",
   "metadata": {},
   "source": [
    "### Indexing of nodes (neurons)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4a7880f",
   "metadata": {},
   "source": [
    "Each neuron (node) is indexed by $(l,i)$ where $l$ is the layer index along depth, $i$ is the neuron index along width. In the function remove_node, we use use $(l,i)$ to indicate which node we want to remove."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c9e70d77",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "model.remove_node(1,0)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a22c9e31",
   "metadata": {},
   "source": [
    "model.plot(beta=beta)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "44553b6a",
   "metadata": {},
   "source": [
    "model.remove_node(2,1)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7c9b491a",
   "metadata": {},
   "source": [
    "model.plot(beta=beta)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7b24fcdb",
   "metadata": {},
   "source": [
    "model.remove_node(1,2)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0a7e9373",
   "metadata": {},
   "source": [
    "model.plot(beta=beta)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "9ee64af1",
   "metadata": {},
   "source": [
    "### Indexing of layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4c732dfc",
   "metadata": {},
   "source": [
    "# KAN spline layers are refererred to as act_fun\n",
    "# KAN symbolic layers are referred to as symbolic_fun\n",
    "\n",
    "model = KAN(width=[2,3,2,1])\n",
    "\n",
    "i = 0\n",
    "model.act_fun[i] # => KAN Layer (Spline)\n",
    "model.symbolic_fun[i] # => KAN Layer (Symbolic)\n",
    "\n",
    "for i in range(3):\n",
    "    print(model.act_fun[i].in_dim, model.act_fun[i].out_dim)\n",
    "    print(model.symbolic_fun[i].in_dim, model.symbolic_fun[i].out_dim)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1f0ccc8f",
   "metadata": {},
   "source": [
    "# check model parameters\n",
    "model.act_fun[i].grid\n",
    "model.act_fun[i].coef\n",
    "model.symbolic_fun[i].funs_name\n",
    "model.symbolic_fun[i].mask"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01a2cbe3",
   "metadata": {},
   "source": [],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
