{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a55c0b7b-3edb-4e8e-bf9e-240425ba988b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "checkpoint directory created: ./model\n",
      "saving model version 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 2.49e-02 | test_loss: 2.49e-02 | reg: 0.00e+00 | :  34%|▎| 34/100 [00:15<00:30,  2.16i\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[7], line 33\u001b[0m\n\u001b[1;32m     30\u001b[0m dataset[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_input\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m inputs\n\u001b[1;32m     31\u001b[0m dataset[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_label\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m targets\n\u001b[0;32m---> 33\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mLBFGS\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/research/pykan/kan/MultKAN.py:1552\u001b[0m, in \u001b[0;36mMultKAN.fit\u001b[0;34m(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, start_grid_update_step, stop_grid_update_step, batch, metrics, metrics_log_freq, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, singularity_avoiding, y_th, reg_metric, display_metrics)\u001b[0m\n\u001b[1;32m   1549\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdate_grid(dataset[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_input\u001b[39m\u001b[38;5;124m'\u001b[39m][train_id])\n\u001b[1;32m   1551\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m opt \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLBFGS\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m-> 1552\u001b[0m     \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1554\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m opt \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAdam\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m   1555\u001b[0m     pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(dataset[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_input\u001b[39m\u001b[38;5;124m'\u001b[39m][train_id], singularity_avoiding\u001b[38;5;241m=\u001b[39msingularity_avoiding, y_th\u001b[38;5;241m=\u001b[39my_th)\n",
      "File \u001b[0;32m/state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/optim/optimizer.py:280\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    276\u001b[0m         \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    277\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs),\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    278\u001b[0m                                \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbut got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 280\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    281\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m    283\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
      "File \u001b[0;32m/state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/research/pykan/kan/LBFGS.py:443\u001b[0m, in \u001b[0;36mLBFGS.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m    441\u001b[0m     \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mobj_func\u001b[39m(x, t, d):\n\u001b[1;32m    442\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_directional_evaluate(closure, x, t, d)\n\u001b[0;32m--> 443\u001b[0m     loss, flat_grad, t, ls_func_evals \u001b[38;5;241m=\u001b[39m \u001b[43m_strong_wolfe\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    444\u001b[0m \u001b[43m        \u001b[49m\u001b[43mobj_func\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx_init\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflat_grad\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgtd\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    445\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_add_grad(t, d)\n\u001b[1;32m    446\u001b[0m opt_cond \u001b[38;5;241m=\u001b[39m flat_grad\u001b[38;5;241m.\u001b[39mabs()\u001b[38;5;241m.\u001b[39mmax() \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m tolerance_grad\n",
      "File \u001b[0;32m~/research/pykan/kan/LBFGS.py:50\u001b[0m, in \u001b[0;36m_strong_wolfe\u001b[0;34m(obj_func, x, t, d, f, g, gtd, c1, c2, tolerance_change, max_ls)\u001b[0m\n\u001b[1;32m     48\u001b[0m g \u001b[38;5;241m=\u001b[39m g\u001b[38;5;241m.\u001b[39mclone(memory_format\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mcontiguous_format)\n\u001b[1;32m     49\u001b[0m \u001b[38;5;66;03m# evaluate objective and gradient using initial step\u001b[39;00m\n\u001b[0;32m---> 50\u001b[0m f_new, g_new \u001b[38;5;241m=\u001b[39m \u001b[43mobj_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     51\u001b[0m ls_func_evals \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m     52\u001b[0m gtd_new \u001b[38;5;241m=\u001b[39m g_new\u001b[38;5;241m.\u001b[39mdot(d)\n",
      "File \u001b[0;32m~/research/pykan/kan/LBFGS.py:442\u001b[0m, in \u001b[0;36mLBFGS.step.<locals>.obj_func\u001b[0;34m(x, t, d)\u001b[0m\n\u001b[1;32m    441\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mobj_func\u001b[39m(x, t, d):\n\u001b[0;32m--> 442\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_directional_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43md\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/research/pykan/kan/LBFGS.py:291\u001b[0m, in \u001b[0;36mLBFGS._directional_evaluate\u001b[0;34m(self, closure, x, t, d)\u001b[0m\n\u001b[1;32m    289\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_directional_evaluate\u001b[39m(\u001b[38;5;28mself\u001b[39m, closure, x, t, d):\n\u001b[1;32m    290\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_add_grad(t, d)\n\u001b[0;32m--> 291\u001b[0m     loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mfloat\u001b[39m(\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    292\u001b[0m     flat_grad \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gather_flat_grad()\n\u001b[1;32m    293\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_set_param(x)\n",
      "File \u001b[0;32m/state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/research/pykan/kan/MultKAN.py:1519\u001b[0m, in \u001b[0;36mMultKAN.fit.<locals>.closure\u001b[0;34m()\u001b[0m\n\u001b[1;32m   1517\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m   1518\u001b[0m pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(dataset[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_input\u001b[39m\u001b[38;5;124m'\u001b[39m][train_id], singularity_avoiding\u001b[38;5;241m=\u001b[39msingularity_avoiding, y_th\u001b[38;5;241m=\u001b[39my_th)\n\u001b[0;32m-> 1519\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m loss_fn(pred, \u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrain_label\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43mtrain_id\u001b[49m\u001b[43m]\u001b[49m)\n\u001b[1;32m   1520\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msave_act:\n\u001b[1;32m   1521\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m reg_metric \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124medge_backward\u001b[39m\u001b[38;5;124m'\u001b[39m:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from kan import *\n",
    "\n",
    "# ---- configuration -------------------------------------------------------\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)\n",
    "\n",
    "# Tag of the reference solution to load. Kept under its own name so that\n",
    "# `model` below only ever refers to the KAN being trained.\n",
    "reference = 'mlp'\n",
    "n = 4\n",
    "steps = 10000\n",
    "n_grid = 201  # sample points per axis on [-1, 1] x [-1, 1]\n",
    "\n",
    "# ---- load the reference solution -----------------------------------------\n",
    "data = np.loadtxt(f'./results/{reference}_sol_n_{n}_steps_{steps}')\n",
    "# Fail early with a readable message if the file does not match the grid.\n",
    "assert data.size == n_grid * n_grid, (\n",
    "    f'expected {n_grid * n_grid} values in the results file, got {data.size}'\n",
    ")\n",
    "\n",
    "# ---- build the (n_grid**2, 2) inputs and (n_grid**2, 1) targets ----------\n",
    "x = torch.linspace(-1, 1, n_grid)\n",
    "y = torch.linspace(-1, 1, n_grid)\n",
    "xx, yy = torch.meshgrid(x, y, indexing='ij')\n",
    "\n",
    "# Stacking along dim=1 gives one (x, y) pair per row without a permute.\n",
    "inputs = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=1).to(device)\n",
    "# np.loadtxt produces float64 by default; cast the labels to the inputs'\n",
    "# dtype so predictions and labels are compared in a single precision.\n",
    "targets = torch.from_numpy(data.reshape(n_grid * n_grid, 1)).to(device=device, dtype=inputs.dtype)\n",
    "\n",
    "# ---- model and dataset ---------------------------------------------------\n",
    "grid = 20\n",
    "model = KAN(width=[2,10,1], grid=grid, k=3, seed=1, device=device)\n",
    "\n",
    "# The test split deliberately reuses the training grid, so the reported\n",
    "# test_loss mirrors train_loss rather than measuring generalization.\n",
    "dataset = {\n",
    "    'train_input': inputs,\n",
    "    'train_label': targets,\n",
    "    'test_input': inputs,\n",
    "    'test_label': targets,\n",
    "}\n",
    "\n",
    "# ---- train ---------------------------------------------------------------\n",
    "# Keep whatever fit() returns in a variable instead of echoing it as the\n",
    "# cell output.\n",
    "results = model.fit(dataset, opt='LBFGS')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "54149feb-abf3-4d8a-b643-732b81208b27",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(40401, 1)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the training cell reshapes the loaded solution to (201*201, 1),\n",
    "# so this should report 40401 rows and 1 column.\n",
    "targets.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56c92149-6df2-4640-86c9-5cec9167bd04",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
