{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "959a4583-0d30-4195-be8a-716cbacf62a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 2.43e-01 | 9.13e+02 :   1%|          | 81/10001 [00:01<02:40, 61.98it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 152\u001b[0m\n\u001b[1;32m    149\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(seed)\n\u001b[1;32m    151\u001b[0m lan \u001b[38;5;241m=\u001b[39m LAN(width\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m2\u001b[39m,\u001b[38;5;241m128\u001b[39m,\u001b[38;5;241m128\u001b[39m,\u001b[38;5;241m128\u001b[39m,\u001b[38;5;241m128\u001b[39m,\u001b[38;5;241m1\u001b[39m], grid\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m, k\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, base_fun\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39msin, w0\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m30\u001b[39m, scale_sp\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m, scale_sp_trainable\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, weight_init_scale\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39msqrt(\u001b[38;5;241m6.\u001b[39m), linear_bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, device\u001b[38;5;241m=\u001b[39mdevice)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m--> 152\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_LAN\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlan\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mAdam\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1024\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10001\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid_update_num\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mstop_grid_update_step\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mswitching\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbase\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mupdate_grid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m;\n\u001b[1;32m    154\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m4096\u001b[39m\n\u001b[1;32m    155\u001b[0m n_batch \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39mbatch \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
      "Cell \u001b[0;32mIn[5], line 105\u001b[0m, in \u001b[0;36mtrain_LAN\u001b[0;34m(model, dataset, opt, steps, log, lamb, act_l1, act_entropy, weight_l1, update_grid, grid_update_num, stop_grid_update_step, batch, switching, name)\u001b[0m\n\u001b[1;32m    102\u001b[0m         objective\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m    103\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m objective\n\u001b[0;32m--> 105\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m loss_fn(\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrain_input\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43mtrain_id\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m, dataset[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_label\u001b[39m\u001b[38;5;124m'\u001b[39m][train_id]\u001b[38;5;241m.\u001b[39mto(device))\n\u001b[1;32m    106\u001b[0m reg_ \u001b[38;5;241m=\u001b[39m reg(model\u001b[38;5;241m.\u001b[39macts_scale)\n\u001b[1;32m    107\u001b[0m loss \u001b[38;5;241m=\u001b[39m train_loss \u001b[38;5;241m+\u001b[39m lamb\u001b[38;5;241m*\u001b[39mreg_\n",
      "File \u001b[0;32m/state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/research/pykan/experiments/baselines/LAN.py:228\u001b[0m, in \u001b[0;36mLAN.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    226\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m l \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdepth \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m    227\u001b[0m     x_postlinear \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mw0\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlinears[l](x)\n\u001b[0;32m--> 228\u001b[0m     x, preacts, postacts, postspline \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mact_fun\u001b[49m\u001b[43m[\u001b[49m\u001b[43ml\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_postlinear\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    229\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    230\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlinears[l](x)\n",
      "File \u001b[0;32m/state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1499\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1500\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/research/pykan/experiments/baselines/LAN.py:89\u001b[0m, in \u001b[0;36mSpline_batch_LAN.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     87\u001b[0m preacts \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mclone()\n\u001b[1;32m     88\u001b[0m base \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_fun(x)\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m0\u001b[39m) \u001b[38;5;66;03m# shape (batch, size)\u001b[39;00m\n\u001b[0;32m---> 89\u001b[0m y \u001b[38;5;241m=\u001b[39m \u001b[43mcoef2curve\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcoef\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcoef\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# shape (size, batch)\u001b[39;00m\n\u001b[1;32m     90\u001b[0m y \u001b[38;5;241m=\u001b[39m y\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m0\u001b[39m) \u001b[38;5;66;03m# shape (batch, size)\u001b[39;00m\n\u001b[1;32m     91\u001b[0m postspline \u001b[38;5;241m=\u001b[39m y\u001b[38;5;241m.\u001b[39mclone()\n",
      "File \u001b[0;32m~/research/pykan/experiments/baselines/LAN.py:52\u001b[0m, in \u001b[0;36mcoef2curve\u001b[0;34m(x_eval, grid, coef, k)\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcoef2curve\u001b[39m(x_eval, grid, coef, k):\n\u001b[1;32m     50\u001b[0m         \u001b[38;5;66;03m# x_eval: (size, batch), grid: (size, grid), coef: (size, coef)\u001b[39;00m\n\u001b[1;32m     51\u001b[0m         \u001b[38;5;66;03m# coef: (size, coef), B_batch: (size, coef, batch), summer over coef\u001b[39;00m\n\u001b[0;32m---> 52\u001b[0m         y_eval \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39meinsum(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mij,ijk->ik\u001b[39m\u001b[38;5;124m'\u001b[39m, coef, \u001b[43mB_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx_eval\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m     53\u001b[0m         \u001b[38;5;66;03m#y_eval = torch.sum(coef.unsqueeze(dim=1)*B_batch(x_eval, grid, k), dim=0)\u001b[39;00m\n\u001b[1;32m     54\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m y_eval\n",
      "File \u001b[0;32m~/research/pykan/experiments/baselines/LAN.py:44\u001b[0m, in \u001b[0;36mB_batch\u001b[0;34m(x, grid, k, extend)\u001b[0m\n\u001b[1;32m     42\u001b[0m     value \u001b[38;5;241m=\u001b[39m (x\u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39mgrid[:,:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\u001b[38;5;241m*\u001b[39m(x\u001b[38;5;241m<\u001b[39mgrid[:,\u001b[38;5;241m1\u001b[39m:])\n\u001b[1;32m     43\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 44\u001b[0m     B_km1 \u001b[38;5;241m=\u001b[39m \u001b[43mB_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43mgrid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrid\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43mk\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mk\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mextend\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m     45\u001b[0m     value \u001b[38;5;241m=\u001b[39m (x\u001b[38;5;241m-\u001b[39mgrid[:,:\u001b[38;5;241m-\u001b[39m(k\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m)])\u001b[38;5;241m/\u001b[39m(grid[:,k:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m-\u001b[39mgrid[:,:\u001b[38;5;241m-\u001b[39m(k\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m)])\u001b[38;5;241m*\u001b[39mB_km1[:,:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m+\u001b[39m 
(grid[:,k\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m:]\u001b[38;5;241m-\u001b[39mx)\u001b[38;5;241m/\u001b[39m(grid[:,k\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m:]\u001b[38;5;241m-\u001b[39mgrid[:,\u001b[38;5;241m1\u001b[39m:(\u001b[38;5;241m-\u001b[39mk)])\u001b[38;5;241m*\u001b[39mB_km1[:,\u001b[38;5;241m1\u001b[39m:]\n\u001b[1;32m     46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value\n",
      "File \u001b[0;32m~/research/pykan/experiments/baselines/LAN.py:45\u001b[0m, in \u001b[0;36mB_batch\u001b[0;34m(x, grid, k, extend)\u001b[0m\n\u001b[1;32m     43\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m     44\u001b[0m     B_km1 \u001b[38;5;241m=\u001b[39m B_batch(x[:,\u001b[38;5;241m0\u001b[39m],grid\u001b[38;5;241m=\u001b[39mgrid[:,:,\u001b[38;5;241m0\u001b[39m],k\u001b[38;5;241m=\u001b[39mk\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,extend\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m---> 45\u001b[0m     value \u001b[38;5;241m=\u001b[39m (\u001b[43mx\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43mgrid\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mk\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m)\u001b[38;5;241m/\u001b[39m(grid[:,k:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m-\u001b[39mgrid[:,:\u001b[38;5;241m-\u001b[39m(k\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m)])\u001b[38;5;241m*\u001b[39mB_km1[:,:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m+\u001b[39m (grid[:,k\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m:]\u001b[38;5;241m-\u001b[39mx)\u001b[38;5;241m/\u001b[39m(grid[:,k\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m:]\u001b[38;5;241m-\u001b[39mgrid[:,\u001b[38;5;241m1\u001b[39m:(\u001b[38;5;241m-\u001b[39mk)])\u001b[38;5;241m*\u001b[39mB_km1[:,\u001b[38;5;241m1\u001b[39m:]\n\u001b[1;32m     46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from experiments.baselines.LAN import *\n",
    "from kan import *\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)\n",
    "\n",
    "# Load the image as 8-bit grayscale and rescale pixel values from [0, 255] to [-1, 1).\n",
    "image = np.array(Image.open('./cameraman.png').convert('L'))\n",
    "image = 2*(image/256 - 0.5)\n",
    "\n",
    "# image.shape is (rows, cols): dimx is the number of rows, dimy the number of columns.\n",
    "dimx, dimy = image.shape\n",
    "# np.meshgrid's default 'xy' indexing returns arrays of shape (len(y_grid), len(x_grid)),\n",
    "# so x has to span the columns and y the rows for the flattened coordinates to line up\n",
    "# with image.reshape(-1,). Square images are unaffected; non-square ones were misaligned.\n",
    "x_grid = np.linspace(-1,1,num=dimy)\n",
    "y_grid = np.linspace(-1,1,num=dimx)\n",
    "xx, yy = np.meshgrid(x_grid, y_grid)\n",
    "assert xx.shape == image.shape, 'coordinate grid must have the same layout as the image'\n",
    "inputs = np.transpose(np.array([xx.reshape(-1,), yy.reshape(-1,)]))  # (num_pixels, 2)\n",
    "labels = image.reshape(-1,)\n",
    "num = labels.shape[0]\n",
    "\n",
    "# Coordinates and pixel targets are fixed data, not trainable parameters, so they are\n",
    "# created without requires_grad: otherwise every backward() also accumulates a gradient\n",
    "# for the whole dataset that nothing reads or resets.\n",
    "dataset = {}\n",
    "dataset['train_input'] = torch.tensor(inputs, dtype=torch.float32)\n",
    "dataset['train_label'] = torch.tensor(labels[:,np.newaxis], dtype=torch.float32)\n",
    "\n",
    "def PSNR(original, compressed, max_pixel=255.0):\n",
    "    \"\"\"Return the peak signal-to-noise ratio (in dB) between two images.\n",
    "\n",
    "    Args:\n",
    "        original: reference image (array-like), on the same scale as max_pixel.\n",
    "        compressed: reconstructed image, broadcastable against original.\n",
    "        max_pixel: largest possible pixel value; defaults to 255.0.\n",
    "\n",
    "    Returns:\n",
    "        20 * log10(max_pixel / sqrt(mse)), or the sentinel 100 when mse is 0\n",
    "        (the ratio is unbounded for identical images).\n",
    "    \"\"\"\n",
    "    # Cast to float before subtracting: unsigned-integer images wrap around instead of\n",
    "    # going negative, which can even turn a real error into mse == 0.\n",
    "    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)\n",
    "    mse = np.mean(diff ** 2)\n",
    "    if mse == 0:\n",
    "        return 100\n",
    "    return 20 * np.log10(max_pixel / np.sqrt(mse))\n",
    "\n",
    "\n",
    "def train_LAN(model, dataset, opt=\"LBFGS\", steps=100, log=1, lamb=0., act_l1 = 1, act_entropy = 1, weight_l1 = 2, update_grid=True, grid_update_num=10, stop_grid_update_step=50, batch=-1, switching=False, name=\"noswitching\"):\n",
    "    \"\"\"Fit `model` to dataset['train_input'] -> dataset['train_label'] with an MSE loss.\n",
    "\n",
    "    Relies on the notebook-level `device` for moving mini-batches.\n",
    "\n",
    "    Args:\n",
    "        model: network exposing .linears, .biases, .act_fun, .acts_scale and\n",
    "            .update_grid_from_samples (all of them are accessed below).\n",
    "        dataset: dict holding the 'train_input' and 'train_label' tensors.\n",
    "        opt: optimizer name, one of 'Adam', 'SGD' or 'LBFGS'.\n",
    "        steps: number of optimization steps.\n",
    "        log: the progress-bar description is refreshed every `log` steps.\n",
    "        lamb: weight of the regularizer added to the MSE loss.\n",
    "        act_l1, act_entropy, weight_l1: coefficients of the regularizer terms.\n",
    "        update_grid: whether to call model.update_grid_from_samples periodically.\n",
    "        grid_update_num: number of grid updates spread over the first\n",
    "            stop_grid_update_step steps.\n",
    "        stop_grid_update_step: no grid update happens at or after this step.\n",
    "        batch: mini-batch size; -1 draws as many samples as the dataset holds.\n",
    "            Indices are drawn with np.random.choice, i.e. with replacement.\n",
    "        switching: unused; kept so existing calls keep working.\n",
    "        name: run label; it selects the learning-rate change applied at step 5000.\n",
    "\n",
    "    Returns:\n",
    "        dict with per-step lists 'train_loss' (square root of the batch MSE) and 'reg';\n",
    "        the keys 'test1_loss', 'test2_loss' and 'psnr' are present but left empty.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `opt` is not a supported optimizer name.\n",
    "    \"\"\"\n",
    "    if opt not in (\"Adam\", \"SGD\", \"LBFGS\"):\n",
    "        # An unknown name used to leave `optimizer` unbound: the loop then ran without\n",
    "        # ever updating the model (or died with a NameError at step 5000).\n",
    "        raise ValueError(\"opt must be 'Adam', 'SGD' or 'LBFGS', got %r\" % (opt,))\n",
    "\n",
    "    def reg(acts_scale):\n",
    "        # L1 + entropy penalty on every activation-scale tensor, plus L1 on linear weights.\n",
    "        reg_ = 0.\n",
    "        for i in range(len(acts_scale)):\n",
    "            vec = acts_scale[i].reshape(-1,)\n",
    "            p = vec/torch.sum(vec)\n",
    "            reg_ += act_l1*torch.sum(vec) - act_entropy*torch.sum(p*torch.log2(p+1e-4)) # both l1 and entropy\n",
    "\n",
    "        for i in range(len(model.linears)):\n",
    "            reg_ += weight_l1 * torch.sum(torch.abs(model.linears[i].weight))\n",
    "\n",
    "        return reg_\n",
    "\n",
    "    def make_adam(base_lr):\n",
    "        # linears and biases use base_lr; the act_fun parameters always use lr=1e-3.\n",
    "        return torch.optim.Adam([{'params': model.linears.parameters()}, {'params': model.biases.parameters()}, {'params': model.act_fun.parameters(), 'lr':1e-3}], lr=base_lr)\n",
    "\n",
    "    pbar = tqdm(range(steps), desc='description')\n",
    "\n",
    "    loss_fn = lambda x,y: torch.mean((x-y)**2)\n",
    "\n",
    "    # Clamp to 1: the period is 0 when grid_update_num > stop_grid_update_step, and a\n",
    "    # zero period would make the modulo in the loop raise ZeroDivisionError.\n",
    "    grid_update_freq = max(1, int(stop_grid_update_step/grid_update_num))\n",
    "\n",
    "    if opt == \"Adam\":\n",
    "        optimizer = make_adam(1e-4)\n",
    "    elif opt == \"SGD\":\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "    else:  # \"LBFGS\"\n",
    "        optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn=\"strong_wolfe\", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)\n",
    "\n",
    "    results = {}\n",
    "    results['train_loss'] = []\n",
    "    results['test1_loss'] = []\n",
    "    results['test2_loss'] = []\n",
    "    results['reg'] = []\n",
    "    results['psnr'] = []\n",
    "\n",
    "    if batch == -1:\n",
    "        batch_size = dataset['train_input'].shape[0]\n",
    "    else:\n",
    "        batch_size = batch\n",
    "\n",
    "    for step in pbar:\n",
    "\n",
    "        # Hard-coded learning-rate schedule: a change at step 5000 that depends on the\n",
    "        # run name, then a 10x decay of every parameter group at step 10000.\n",
    "        if step == 5000:\n",
    "            if name == \"lanstart\" or name == \"lancontinue\":\n",
    "                # These runs restart from a fresh Adam optimizer (its state is discarded).\n",
    "                optimizer = make_adam(1e-5)\n",
    "            if name == \"base\":\n",
    "                for g in optimizer.param_groups:\n",
    "                    g['lr'] *= 0.1\n",
    "\n",
    "        if step == 10000:\n",
    "            for g in optimizer.param_groups:\n",
    "                g['lr'] *= 0.1\n",
    "\n",
    "        train_id = np.random.choice(dataset['train_input'].shape[0], batch_size)\n",
    "\n",
    "        # update_grid is tested first so a disabled grid update never touches the modulo.\n",
    "        if update_grid and step < stop_grid_update_step and step % grid_update_freq == 0:\n",
    "            model.update_grid_from_samples(dataset['train_input'][train_id[:1000]].to(device))\n",
    "\n",
    "        if opt == \"LBFGS\":\n",
    "            # LBFGS re-evaluates the objective several times per step through this closure.\n",
    "            def closure():\n",
    "                optimizer.zero_grad()\n",
    "                pred_loss = loss_fn(model(dataset['train_input'][train_id].to(device)), dataset['train_label'][train_id].to(device))\n",
    "                reg_ = reg(model.acts_scale)\n",
    "                objective = pred_loss + lamb*reg_\n",
    "                objective.backward()\n",
    "                return objective\n",
    "\n",
    "        # Loss at the current parameters: used for logging, and for the Adam/SGD update.\n",
    "        train_loss = loss_fn(model(dataset['train_input'][train_id].to(device)), dataset['train_label'][train_id].to(device))\n",
    "        reg_ = reg(model.acts_scale)\n",
    "        loss = train_loss + lamb*reg_\n",
    "\n",
    "        if step % log == 0:\n",
    "            pbar.set_description(\" %.2e | %.2e \" % (torch.sqrt(train_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))\n",
    "\n",
    "        if opt == \"Adam\" or opt == \"SGD\":\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        elif opt == \"LBFGS\":\n",
    "            optimizer.step(closure)\n",
    "\n",
    "        results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())\n",
    "        results['reg'].append(reg_.cpu().detach().numpy())\n",
    "\n",
    "    return results\n",
    "\n",
    "# Seed both RNGs: numpy draws the mini-batch indices, torch initializes the network.\n",
    "seed = 1\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "lan = LAN(width=[2,128,128,128,128,1], grid=10, k=3, base_fun=torch.sin, w0=30, scale_sp=0.0, scale_sp_trainable=False, weight_init_scale=np.sqrt(6.), linear_bias=True, device=device).to(device)\n",
    "results = train_LAN(lan, dataset, opt=\"Adam\", batch=1024, steps=10001, grid_update_num=100, stop_grid_update_step=5000, switching=False, name=\"base\", update_grid=False);\n",
    "\n",
    "# Reconstruct the image by evaluating the trained network on every pixel coordinate.\n",
    "batch = 4096\n",
    "# Ceiling division: `// batch + 1` appended an empty batch whenever the pixel count\n",
    "# was an exact multiple of the batch size.\n",
    "n_batch = (inputs.shape[0] + batch - 1)//batch\n",
    "out_chunks = []\n",
    "with torch.no_grad():  # inference only, no autograd graph needed\n",
    "    for i in range(n_batch):\n",
    "        if i % 20 == 0:\n",
    "            print(i)\n",
    "        data_batch = torch.tensor(inputs[i*batch:(i+1)*batch], dtype=torch.float32).to(device)\n",
    "        # The network is `lan`; `model` is only a parameter name inside train_LAN and is\n",
    "        # undefined at notebook level.\n",
    "        out_chunks.append(lan(data_batch).cpu().detach())\n",
    "out = torch.cat(out_chunks, dim=0)\n",
    "\n",
    "# Map network outputs and targets from [-1, 1) back to the 0..255 pixel scale for PSNR.\n",
    "compressed = (out[:,0].reshape(dimx,dimy).detach().numpy() + 1)*128\n",
    "original = (image + 1) * 128\n",
    "psnr = PSNR(original, compressed)\n",
    "plt.imshow(out[:,0].reshape(dimx,dimy).detach().numpy(), cmap='gray')\n",
    "plt.title('psnr=%.2f'%psnr, fontsize=15)\n",
    "plt.axis('off');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "659332b1-3115-42ee-a7be-91bfd3c61633",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
