{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "134e7f9d",
   "metadata": {},
   "source": [
    "# Demo 8: Checkpoint\n",
    "\n",
    "It is fun to play with KANs, just as it is fun to play computer games. A common frustration in both is making a mistake and being unable to return to the latest checkpoint. We provide a quick way to save and load checkpoints, so that you do not have to start all over again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2075ef56",
   "metadata": {},
   "source": [
    "from kan import KAN, create_dataset\n",
    "import torch\n",
    "import torch.nn\n",
    "# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
    "model = KAN(width=[2,5,1], grid=5, k=3, seed=0, base_fun=torch.nn.SiLU())\n",
    "f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
    "dataset = create_dataset(f, n_var=2)\n",
    "\n",
    "# run a forward pass first so that plot() has activations to draw\n",
    "model(dataset['train_input'])\n",
    "model.plot()\n",
    "model.save_ckpt('ckpt1')\n",
    "#model.clear_ckpts()\n",
    "# save the initialized model as ckpt1"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ab90723b",
   "metadata": {},
   "source": [
    "model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);\n",
    "model.plot()\n",
    "model.save_ckpt('ckpt2')\n",
    "# save the trained model as ckpt2"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "6b339ba4",
   "metadata": {},
   "source": [
    "The above results look promising! You probably want to simplify the model further, either by training it more or by pruning it. Suppose you decide to increase the regularization strength to make the graph cleaner, but you set the strength too large and training messes the whole thing up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0b580553",
   "metadata": {},
   "source": [
    "model.train(dataset, opt=\"Adam\", steps=20, lamb=100., lamb_entropy=10.);\n",
    "model.plot()\n",
    "model.save_ckpt('ckpt3')"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "4300604e",
   "metadata": {},
   "source": [
    "We want to revert to ckpt2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "705661e5",
   "metadata": {},
   "source": [
    "model.load_ckpt('ckpt2')\n",
    "model(dataset['train_input'])\n",
    "model.plot()"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "686b8fb4",
   "metadata": {},
   "source": [
    "Now we realize that pruning seems to be a better choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0598e1f0",
   "metadata": {},
   "source": [
    "model = model.prune()\n",
    "model(dataset['train_input'])\n",
    "model.plot()"
   ],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
