{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "959a4583-0d30-4195-be8a-716cbacf62a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "load dataset\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 2.13e-01 | 9.38e+02 :   2%|▏         | 232/10001 [00:04<03:01, 53.84it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 155\u001b[0m\n\u001b[1;32m    152\u001b[0m torch\u001b[38;5;241m.\u001b[39mmanual_seed(seed)\n\u001b[1;32m    154\u001b[0m lan \u001b[38;5;241m=\u001b[39m LAN(width\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m2\u001b[39m,\u001b[38;5;241m128\u001b[39m,\u001b[38;5;241m128\u001b[39m,\u001b[38;5;241m128\u001b[39m,\u001b[38;5;241m128\u001b[39m,\u001b[38;5;241m1\u001b[39m], grid\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m, k\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m, base_fun\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39msin, w0\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m30\u001b[39m, scale_sp\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.0\u001b[39m, scale_sp_trainable\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, weight_init_scale\u001b[38;5;241m=\u001b[39mnp\u001b[38;5;241m.\u001b[39msqrt(\u001b[38;5;241m6.\u001b[39m), linear_bias\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, device\u001b[38;5;241m=\u001b[39mdevice)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m--> 155\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_LAN\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlan\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mAdam\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1024\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10001\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrid_update_num\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mstop_grid_update_step\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mswitching\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbase\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mupdate_grid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m;\n\u001b[1;32m    157\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m4096\u001b[39m\n\u001b[1;32m    158\u001b[0m n_batch \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39mbatch \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
      "Cell \u001b[0;32mIn[4], line 139\u001b[0m, in \u001b[0;36mtrain_LAN\u001b[0;34m(model, dataset, opt, steps, log, lamb, act_l1, act_entropy, weight_l1, update_grid, grid_update_num, stop_grid_update_step, batch, switching, name)\u001b[0m\n\u001b[1;32m    137\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m opt \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAdam\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m opt \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSGD\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m    138\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m--> 139\u001b[0m     \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    140\u001b[0m     optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m    141\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m opt \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLBFGS\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n",
      "File \u001b[0;32m/state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    477\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    478\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    479\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m    480\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    485\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m    486\u001b[0m     )\n\u001b[0;32m--> 487\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    488\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m    489\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/state/partition1/llgrid/pkg/anaconda/anaconda3-2023a-pytorch/lib/python3.9/site-packages/torch/autograd/__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    195\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m    197\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m    198\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    199\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 200\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    201\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    202\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from experiments.baselines.LAN import *\n",
    "from kan import *\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "# Prefer the GPU when available; train_LAN and the evaluation loop read this module-level `device`.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)\n",
    "\n",
    "# Load the image as grayscale (PIL mode 'L') and rescale pixel values from [0, 255] to [-1, 1).\n",
    "image = np.array(Image.open('./cameraman.png').convert('L'))\n",
    "image = 2*(image/256 - 0.5)\n",
    "\n",
    "# image.shape is (rows, columns): dimx is the number of rows, dimy the number of columns.\n",
    "dimx, dimy = image.shape\n",
    "# Build the coordinate grids so the mesh has the same (dimx, dimy) layout as `image`:\n",
    "# x runs along the columns and y along the rows. With the default 'xy' indexing,\n",
    "# np.meshgrid returns arrays of shape (len(y_grid), len(x_grid)), so the flattened\n",
    "# coordinates line up with image.reshape(-1) for non-square images as well.\n",
    "x_grid = np.linspace(-1,1,num=dimy)\n",
    "y_grid = np.linspace(-1,1,num=dimx)\n",
    "xx, yy = np.meshgrid(x_grid, y_grid)\n",
    "assert xx.shape == image.shape, 'coordinate grid must match the image layout'\n",
    "# One row per pixel: column 0 is the x coordinate, column 1 the y coordinate.\n",
    "inputs = np.transpose(np.array([xx.reshape(-1,), yy.reshape(-1,)]))\n",
    "labels = image.reshape(-1,)\n",
    "num = labels.shape[0]\n",
    "\n",
    "# NOTE(review): requires_grad=True on the inputs and labels looks unnecessary for fitting\n",
    "# the model parameters; kept as-is, confirm before removing.\n",
    "dataset = {}\n",
    "dataset['train_input'] = torch.tensor(inputs, dtype=torch.float32, requires_grad=True)\n",
    "dataset['train_label'] = torch.tensor(labels[:,np.newaxis], dtype=torch.float32, requires_grad=True)\n",
    "\n",
    "def PSNR(original, compressed, max_pixel=255.0):\n",
    "    \"\"\"Peak signal-to-noise ratio, in dB, between two images.\n",
    "\n",
    "    Args:\n",
    "        original: reference image (array-like).\n",
    "        compressed: reconstructed image, broadcastable against `original`.\n",
    "        max_pixel: peak signal value used in the ratio (default 255.0).\n",
    "\n",
    "    Returns:\n",
    "        20*log10(max_pixel / sqrt(mse)), or 100 when the two images are identical.\n",
    "    \"\"\"\n",
    "    # Work in float64: unsigned integer inputs (e.g. uint8 images) would otherwise\n",
    "    # wrap around in the subtraction and in the square.\n",
    "    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)\n",
    "    mse = np.mean(diff ** 2)\n",
    "    if mse == 0:\n",
    "        # No error at all: the ratio is unbounded, so a fixed sentinel is returned.\n",
    "        return 100\n",
    "    return 20 * np.log10(max_pixel / np.sqrt(mse))\n",
    "\n",
    "\n",
    "def train_LAN(model, dataset, opt=\"LBFGS\", steps=100, log=1, lamb=0., act_l1 = 1, act_entropy = 1, weight_l1 = 2, update_grid=True, grid_update_num=10, stop_grid_update_step=50, batch=-1, switching=False, name=\"noswitching\"):\n",
    "    \"\"\"Fit `model` to dataset['train_input'] -> dataset['train_label'] with an MSE objective.\n",
    "\n",
    "    Every step draws a random mini-batch, evaluates loss = MSE + lamb * regularization and\n",
    "    takes one optimizer step. Batches are moved to the module-level `device` before the\n",
    "    forward pass. Batch indices come from np.random.choice with its default settings,\n",
    "    i.e. they are drawn with replacement, also when batch == -1.\n",
    "\n",
    "    Args:\n",
    "        model: network exposing `linears`, `biases`, `act_fun`, `acts_scale` and\n",
    "            `update_grid_from_samples`, as used below.\n",
    "        dataset: dict holding the 'train_input' and 'train_label' tensors.\n",
    "        opt: 'Adam', 'SGD' or 'LBFGS'.\n",
    "        steps: number of training iterations.\n",
    "        log: the progress-bar description is refreshed every `log` steps.\n",
    "        lamb: weight of the regularization term in the objective.\n",
    "        act_l1: coefficient of the L1 term over `model.acts_scale`.\n",
    "        act_entropy: coefficient of the entropy term over `model.acts_scale`.\n",
    "        weight_l1: coefficient of the L1 penalty on the linear-layer weights.\n",
    "        update_grid: whether `model.update_grid_from_samples` is called during training.\n",
    "        grid_update_num: roughly how many grid updates are spread over the first\n",
    "            `stop_grid_update_step` steps.\n",
    "        stop_grid_update_step: step from which the grid is no longer updated.\n",
    "        batch: mini-batch size; -1 uses as many samples as the training set has rows.\n",
    "        switching: accepted for backward compatibility; not used.\n",
    "        name: run label; it also selects the learning-rate change applied at step 5000.\n",
    "\n",
    "    Returns:\n",
    "        dict of lists: 'train_loss' holds the per-step RMSE and 'reg' the per-step\n",
    "        regularization value; 'test1_loss', 'test2_loss' and 'psnr' are returned empty.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `opt` is not one of the supported optimizer names.\n",
    "    \"\"\"\n",
    "    if opt not in (\"Adam\", \"SGD\", \"LBFGS\"):\n",
    "        # An unknown name would otherwise run the whole loop without ever updating the model.\n",
    "        raise ValueError(\"unsupported optimizer %r; expected 'Adam', 'SGD' or 'LBFGS'\" % (opt,))\n",
    "\n",
    "    def reg(acts_scale):\n",
    "        # L1 + entropy penalty on the activation scales, plus L1 on the linear weights.\n",
    "        reg_ = 0.\n",
    "        for scale in acts_scale:\n",
    "            vec = scale.reshape(-1,)\n",
    "            p = vec/torch.sum(vec)\n",
    "            reg_ += act_l1*torch.sum(vec) - act_entropy*torch.sum(p*torch.log2(p+1e-4))\n",
    "\n",
    "        for linear in model.linears:\n",
    "            reg_ += weight_l1 * torch.sum(torch.abs(linear.weight))\n",
    "\n",
    "        return reg_\n",
    "\n",
    "    def make_adam(base_lr):\n",
    "        # Linear weights and biases train at `base_lr`; activation parameters always use 1e-3.\n",
    "        return torch.optim.Adam([{'params': model.linears.parameters()}, {'params': model.biases.parameters()}, {'params': model.act_fun.parameters(), 'lr':1e-3}], lr=base_lr)\n",
    "\n",
    "    pbar = tqdm(range(steps), desc='description')\n",
    "\n",
    "    loss_fn = lambda x,y: torch.mean((x-y)**2)\n",
    "\n",
    "    # Interval between grid updates. It is clamped to at least 1 so the modulo below can\n",
    "    # never divide by zero, and it is only computed when grid updates are actually requested.\n",
    "    if update_grid and grid_update_num > 0:\n",
    "        grid_update_freq = max(1, int(stop_grid_update_step/grid_update_num))\n",
    "    else:\n",
    "        grid_update_freq = None\n",
    "\n",
    "    if opt == \"Adam\":\n",
    "        optimizer = make_adam(1e-4)\n",
    "    elif opt == \"SGD\":\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "    else:\n",
    "        optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn=\"strong_wolfe\", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)\n",
    "\n",
    "    results = {'train_loss': [], 'test1_loss': [], 'test2_loss': [], 'reg': [], 'psnr': []}\n",
    "\n",
    "    n_train = dataset['train_input'].shape[0]\n",
    "    batch_size = n_train if batch == -1 else batch\n",
    "\n",
    "    for step in pbar:\n",
    "\n",
    "        # Hard-coded learning-rate schedule, keyed on the step index and the run name.\n",
    "        if step == 5000:\n",
    "            if name == \"lanstart\" or name == \"lancontinue\":\n",
    "                optimizer = make_adam(1e-5)\n",
    "            if name == \"base\":\n",
    "                for g in optimizer.param_groups:\n",
    "                    g['lr'] *= 0.1\n",
    "\n",
    "        if step == 10000:\n",
    "            for g in optimizer.param_groups:\n",
    "                g['lr'] *= 0.1\n",
    "\n",
    "        train_id = np.random.choice(n_train, batch_size)\n",
    "\n",
    "        if grid_update_freq is not None and step < stop_grid_update_step and step % grid_update_freq == 0:\n",
    "            # At most the first 1000 sampled rows are used to refresh the grid.\n",
    "            model.update_grid_from_samples(dataset['train_input'][train_id[:1000]].to(device))\n",
    "\n",
    "        if opt == \"LBFGS\":\n",
    "            def closure():\n",
    "                optimizer.zero_grad()\n",
    "                pred_loss = loss_fn(model(dataset['train_input'][train_id].to(device)), dataset['train_label'][train_id].to(device))\n",
    "                reg_ = reg(model.acts_scale)\n",
    "                objective = pred_loss + lamb*reg_\n",
    "                objective.backward()\n",
    "                return objective\n",
    "\n",
    "        # This forward pass provides the logged values and, for Adam/SGD, the training loss.\n",
    "        train_loss = loss_fn(model(dataset['train_input'][train_id].to(device)), dataset['train_label'][train_id].to(device))\n",
    "        reg_ = reg(model.acts_scale)\n",
    "        loss = train_loss + lamb*reg_\n",
    "\n",
    "        if step % log == 0:\n",
    "            pbar.set_description(\" %.2e | %.2e \" % (torch.sqrt(train_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))\n",
    "\n",
    "        if opt == \"LBFGS\":\n",
    "            optimizer.step(closure)\n",
    "        else:\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())\n",
    "        results['reg'].append(reg_.cpu().detach().numpy())\n",
    "\n",
    "    return results\n",
    "\n",
    "# Seed both RNGs: numpy draws the mini-batch indices, torch is seeded before the model is built.\n",
    "seed = 1\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "lan = LAN(width=[2,128,128,128,128,1], grid=10, k=3, base_fun=torch.sin, w0=30, scale_sp=0.0, scale_sp_trainable=False, weight_init_scale=np.sqrt(6.), linear_bias=True, device=device).to(device)\n",
    "results = train_LAN(lan, dataset, opt=\"Adam\", batch=1024, steps=10001, grid_update_num=100, stop_grid_update_step=5000, switching=False, name=\"base\", update_grid=False);\n",
    "\n",
    "# Evaluate the trained network on every pixel coordinate, `batch` rows at a time.\n",
    "batch = 4096\n",
    "# Ceiling division, so no empty trailing chunk is produced when the number of pixels\n",
    "# is an exact multiple of `batch`.\n",
    "n_batch = (inputs.shape[0] + batch - 1)//batch\n",
    "chunks = []\n",
    "with torch.no_grad():  # inference only: no autograd graph is needed\n",
    "    for i in range(n_batch):\n",
    "        if i % 20 == 0:\n",
    "            print(i)\n",
    "        data_batch = torch.tensor(inputs[i*batch:(i+1)*batch], dtype=torch.float32).to(device)\n",
    "        # `lan` is the trained network; the name `model` only exists inside train_LAN.\n",
    "        chunks.append(lan(data_batch).cpu().detach())\n",
    "out = torch.cat(chunks, dim=0)\n",
    "\n",
    "# Map predictions and targets back to pixel scale (inverse of the loading normalization).\n",
    "compressed = (out[:,0].reshape(dimx,dimy).detach().numpy() + 1)*128\n",
    "original = (image + 1) * 128\n",
    "psnr = PSNR(original, compressed)\n",
    "plt.imshow(out[:,0].reshape(dimx,dimy).detach().numpy(), cmap='gray')\n",
    "plt.title('psnr=%.2f'%psnr, fontsize=15)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "659332b1-3115-42ee-a7be-91bfd3c61633",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
