{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "134e7f9d",
   "metadata": {},
   "source": [
    "# Demo 2: Plotting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2571d531",
   "metadata": {},
   "source": [
    "Initialize KAN and create dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2075ef56",
   "metadata": {},
   "source": [
    "from kan import *\n",
    "import torch  # imported explicitly because torch is used directly below\n",
    "# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
    "model = KAN(width=[2,5,1], grid=5, k=3, seed=0)\n",
    "\n",
    "# create dataset f(x,y) = exp(sin(pi*x)+y^2)\n",
    "f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
    "dataset = create_dataset(f, n_var=2)\n",
    "dataset['train_input'].shape, dataset['train_label'].shape"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "8c6add1d",
   "metadata": {},
   "source": [
    "Plot KAN at initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ac76f858",
   "metadata": {},
   "source": [
    "# plot KAN at initialization\n",
    "model(dataset['train_input']);\n",
    "model.plot(beta=100)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8071b133",
   "metadata": {},
   "source": [
    "# if you want to add variable names and title\n",
    "model.plot(beta=100, in_vars=[r'$\\alpha$', 'x'], out_vars=['y'], title = 'My KAN')"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "ddf67e30",
   "metadata": {},
   "source": [
    "Train KAN with sparsity regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "97111d75",
   "metadata": {},
   "source": [
    "# train the model\n",
    "model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "2f30c3ab",
   "metadata": {},
   "source": [
    "$\\beta$ controls the transparency of activations: the larger $\\beta$ is, the more activation functions show up. We usually want to set $\\beta$ so that only the important connections are visually significant. The opacity of each activation is set to ${\\rm tanh}(\\beta |\\phi|_1)$, where $|\\phi|_1$ is the L1 norm of the activation function. By default, $\\beta=3$."
   ]
  },
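  {
   "cell_type": "markdown",
   "id": "beta-worked-example",
   "metadata": {},
   "source": [
    "To get a feel for the formula above, here is a small worked example of ${\\rm tanh}(\\beta |\\phi|_1)$ for the three values of $\\beta$ plotted below. The norms $|\\phi|_1 = 1$ (a strong activation) and $|\\phi|_1 = 0.01$ (a weak one) are hypothetical values chosen for illustration, not taken from the trained model.\n",
    "\n",
    "$$\\begin{aligned}\n",
    "\\beta = 3: &\\quad {\\rm tanh}(3 \\cdot 1) \\approx 0.995, &\\quad {\\rm tanh}(3 \\cdot 0.01) \\approx 0.030 \\\\\\\\\n",
    "\\beta = 100000: &\\quad {\\rm tanh}(10^5 \\cdot 1) \\approx 1, &\\quad {\\rm tanh}(10^5 \\cdot 0.01) \\approx 1 \\\\\\\\\n",
    "\\beta = 0.1: &\\quad {\\rm tanh}(0.1 \\cdot 1) \\approx 0.100, &\\quad {\\rm tanh}(0.1 \\cdot 0.01) \\approx 0.001\n",
    "\\end{aligned}$$\n",
    "\n",
    "Under these assumed norms, the default $\\beta$ separates the two activations clearly, a very large $\\beta$ saturates both near 1 so that everything shows up, and a small $\\beta$ makes both faint."
   ]
  },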
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3f95fcdd",
   "metadata": {},
   "source": [
    "model.plot(beta=3)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6383a22f",
   "metadata": {},
   "source": [
    "model.plot(beta=100000)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9f6ff7e1",
   "metadata": {},
   "source": [
    "model.plot(beta=0.1)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "61d537b7",
   "metadata": {},
   "source": [
    "After pruning, \"mask=True\" removes all connections attached to insignificant neurons. The insignificant neurons themselves are still drawn. If you want those neurons removed as well, see below. A neuron is classified as significant or insignificant based on the L1 norms of its incoming and outgoing functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1269a698",
   "metadata": {},
   "source": [
    "model.prune()\n",
    "model.plot(mask=True)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2249bb17",
   "metadata": {},
   "source": [
    "model.plot(mask=True, beta=100000)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1e5b8732",
   "metadata": {},
   "source": [
    "model.plot(mask=True, beta=0.1)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "9e788b91",
   "metadata": {},
   "source": [
    "Remove insignificant neurons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ed4800ea",
   "metadata": {},
   "source": [
    "model2 = model.prune()\n",
    "model2(dataset['train_input'])  # run a forward pass first so the activations needed for plotting are collected\n",
    "model2.plot()"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "61c8eeb1",
   "metadata": {},
   "source": [
    "Resize the figure with the \"scale\" parameter. The default is 0.5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5cb8d57e",
   "metadata": {},
   "source": [
    "model2.plot(scale=0.5)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "67305f39",
   "metadata": {},
   "source": [
    "model2.plot(scale=0.2)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "51c722ad",
   "metadata": {},
   "source": [
    "model2.plot(scale=2.0)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "03d4bf1b",
   "metadata": {},
   "source": [
    "To see the sample distribution in addition to the curve, set \"sample=True\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c6d24148",
   "metadata": {},
   "source": [
    "model2(dataset['train_input'])\n",
    "model2.plot(sample=True)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "a3fa482a",
   "metadata": {},
   "source": [
    "The samples are more visible if we use a smaller number of samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3856bcb6",
   "metadata": {},
   "source": [
    "model2(dataset['train_input'][:20])\n",
    "model2.plot(sample=True)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "4fa7ca2c",
   "metadata": {},
   "source": [
    "If a function is set to be symbolic, it becomes red"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3d502880",
   "metadata": {},
   "source": [
    "model2.fix_symbolic(0,1,0,'x^2')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f8f93b9c",
   "metadata": {},
   "source": [
    "model2.plot()"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "e75a0760",
   "metadata": {},
   "source": [
    "If a function is set to be both symbolic and numeric (its output is the sum of the symbolic and spline parts), it shows up in purple."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "17df5fed",
   "metadata": {},
   "source": [
    "model2.set_mode(0,1,0,mode='ns')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b5b13363",
   "metadata": {},
   "source": [
    "model2.plot()"
   ],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
