{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "21fd1644-b127-484b-b9b4-7389b0561a8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model: KAN, MLP, MLP w/ random fourier feature\n",
    "# dataset: problem 0: exp(sin(pi*x) + y^2) (2 inputs); problem 1: exp(mean_i sin^2(pi*x_i/2)) (100 inputs)\n",
    "# KAN (No grid extension) 900: depth=[2,3,4], width=[1,3,10,30,100], grid = [3,5,10,20,50,100], opt=Adam/LBFGS, lr = [1e-4,3e-4,1e-3,3e-3,1e-2] for Adam, [1e-2,3e-2,1e-1,3e-1,1] for LBFGS\n",
    "# KAN (grid extension) 150: depth=[2,3,4], width=[1,3,10,30,100], grid = [3,5,10,20,50,100], opt=Adam/LBFGS, lr = [1e-4,3e-4,1e-3,3e-3,1e-2] for Adam, [1e-2,3e-2,1e-1,3e-1,1] for LBFGS\n",
    "# MLP 150: depth=[2,3,4], width=[1,3,10,30,100], opt=Adam/LBFGS, lr = [1e-4,3e-4,1e-3,3e-3,1e-2] for Adam, [1e-2,3e-2,1e-1,3e-1,1] for LBFGS\n",
    "# MLP rff 750: depth=[2,3,4], width=[1,3,10,30,100], opt=Adam/LBFGS, lr = [1e-4,3e-4,1e-3,3e-3,1e-2] for Adam, [1e-2,3e-2,1e-1,3e-1,1] for LBFGS, s = [0.3,1,3,10,30]\n",
    "# save: train/test loss evolution, wall time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "280322c6-2217-4bd8-85ab-133eab9f6043",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 1, 1]\n",
      "checkpoint directory created: ./model\n",
      "saving model version 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 2.29e+00 | test_loss: 2.31e+00 | reg: 2.89e+00 | : 100%|█| 1/1 [00:00<00:00, 27.77it/s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.1\n",
      "saving model version 0.2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 2.29e+00 | test_loss: 2.31e+00 | reg: 2.89e+00 | : 100%|█| 1/1 [00:00<00:00, 33.00it/s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.3\n",
      "saving model version 0.4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 2.29e+00 | test_loss: 2.31e+00 | reg: 2.89e+00 | : 100%|█| 1/1 [00:00<00:00, 30.59it/s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.5\n",
      "saving model version 0.6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 2.29e+00 | test_loss: 2.30e+00 | reg: 2.89e+00 | : 100%|█| 1/1 [00:00<00:00, 26.64it/s\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.7\n",
      "saving model version 0.8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 2.29e+00 | test_loss: 2.30e+00 | reg: 2.89e+00 | : 100%|█| 1/1 [00:00<00:00, 25.40it/s"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 2.29e+00 | test_loss: 2.30e+00 | reg: 2.89e+00 | : 100%|█| 1/1 [00:00<00:00, 25.41it/s"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# KAN (w/ grid extension)\n",
    "# One sweep point: train a KAN, refining its grid through `grids`, and save the\n",
    "# concatenated train/test loss curves plus the total wall time.\n",
    "\n",
    "from kan import *\n",
    "import os\n",
    "import time\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# --- sweep point (edit these to select one run) ---\n",
    "problem = 0 # 0 or 1\n",
    "opt_name = 'Adam' # ['Adam', 'LBFGS']\n",
    "width = 1 # [1,3,10,30,100]\n",
    "depth = 2 # [2,3,4]\n",
    "lr_id = 0 # [0,1,2,3,4]\n",
    "smoke_test = False # True: 1 step per fit, only to check the pipeline (results are still saved)\n",
    "\n",
    "lrs_lbfgs = [1e-2,3e-2,1e-1,3e-1,1]\n",
    "lrs_adam = [1e-4,3e-4,1e-3,3e-3,1e-2]\n",
    "# raises KeyError on a mistyped optimizer name instead of silently using the LBFGS grid\n",
    "lr = {'Adam': lrs_adam, 'LBFGS': lrs_lbfgs}[opt_name][lr_id]\n",
    "\n",
    "# create dataset; full_steps is the number of optimizer steps per fit for a real run\n",
    "if problem == 0:\n",
    "    # f(x, y) = exp(sin(pi*x) + y^2)\n",
    "    n_var = 2\n",
    "    f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
    "    full_steps = 200\n",
    "else:\n",
    "    # f(x) = exp(mean_i sin^2(pi*x_i/2)); keepdim=True gives an (N, 1) label like problem 0\n",
    "    n_var = 100\n",
    "    def f(x):\n",
    "        return torch.exp(torch.mean(torch.sin(torch.pi/2*x)**2, dim=1, keepdim=True))\n",
    "    full_steps = 10000\n",
    "dataset = create_dataset(f, n_var=n_var, device=device, train_num=1000)\n",
    "steps = 1 if smoke_test else full_steps\n",
    "\n",
    "grids = np.array([3,5,10,20,50,100])\n",
    "k = 3 # spline order passed to KAN\n",
    "layer_widths = [n_var]+(depth-1)*[width]+[1]\n",
    "print(layer_widths)\n",
    "\n",
    "train_losses = []\n",
    "test_losses = []\n",
    "start_time = time.time()\n",
    "for i, grid in enumerate(grids):\n",
    "    if i == 0:\n",
    "        model = KAN(width=layer_widths, grid=grid, k=k, seed=0, device=device)\n",
    "    else:\n",
    "        # grid extension: continue from the trained model on the next, finer grid\n",
    "        model = model.refine(grid)\n",
    "    results = model.fit(dataset, opt=opt_name, steps=steps, lr=lr)\n",
    "    # losses of all grid stages are concatenated into one curve\n",
    "    train_losses += results['train_loss']\n",
    "    test_losses += results['test_loss']\n",
    "wall_time = time.time() - start_time\n",
    "\n",
    "out_dir = './results/kan_w_ge'\n",
    "os.makedirs(out_dir, exist_ok=True) # np.savetxt does not create missing directories\n",
    "tag = f'problem_{problem}_opt_{opt_name}_width_{width}_depth_{depth}_lrid_{lr_id}'\n",
    "np.savetxt(f'{out_dir}/train_loss_{tag}', train_losses)\n",
    "np.savetxt(f'{out_dir}/test_loss_{tag}', test_losses)\n",
    "np.savetxt(f'{out_dir}/walltime_{tag}', [wall_time])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "802a60fe-f65d-4112-ace6-3a3d5ca8fbef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "checkpoint directory created: ./model\n",
      "saving model version 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 1.65e+00 | test_loss: 1.65e+00 | reg: 1.32e+01 | : 100%|█| 1/1 [00:00<00:00,  1.66it/s"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# KAN (w/o grid extension)\n",
    "# One sweep point: train a KAN on a single fixed grid and save its train/test\n",
    "# loss curves plus the wall time.\n",
    "\n",
    "from kan import *\n",
    "import os\n",
    "import time\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# --- sweep point (edit these to select one run) ---\n",
    "problem = 1 # 0 or 1\n",
    "opt_name = 'Adam' # ['Adam', 'LBFGS']\n",
    "width = 1 # [1,3,10,30,100]\n",
    "depth = 2 # [2,3,4]\n",
    "lr_id = 0 # [0,1,2,3,4]\n",
    "grid = 3 # [3,5,10,20,50,100]\n",
    "smoke_test = False # True: 1 step per fit, only to check the pipeline (results are still saved)\n",
    "\n",
    "lrs_lbfgs = [1e-2,3e-2,1e-1,3e-1,1]\n",
    "lrs_adam = [1e-4,3e-4,1e-3,3e-3,1e-2]\n",
    "# raises KeyError on a mistyped optimizer name instead of silently using the LBFGS grid\n",
    "lr = {'Adam': lrs_adam, 'LBFGS': lrs_lbfgs}[opt_name][lr_id]\n",
    "\n",
    "# create dataset; full_steps is the number of optimizer steps per fit for a real run\n",
    "if problem == 0:\n",
    "    # f(x, y) = exp(sin(pi*x) + y^2)\n",
    "    n_var = 2\n",
    "    f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
    "    full_steps = 200\n",
    "else:\n",
    "    # f(x) = exp(mean_i sin^2(pi*x_i/2)); keepdim=True gives an (N, 1) label like problem 0\n",
    "    n_var = 100\n",
    "    def f(x):\n",
    "        return torch.exp(torch.mean(torch.sin(torch.pi/2*x)**2, dim=1, keepdim=True))\n",
    "    full_steps = 10000\n",
    "dataset = create_dataset(f, n_var=n_var, device=device, train_num=1000)\n",
    "steps = 1 if smoke_test else full_steps\n",
    "\n",
    "k = 3 # spline order passed to KAN\n",
    "\n",
    "start_time = time.time()\n",
    "model = KAN(width=[n_var]+(depth-1)*[width]+[1], grid=grid, k=k, seed=0, device=device)\n",
    "results = model.fit(dataset, opt=opt_name, steps=steps, lr=lr)\n",
    "wall_time = time.time() - start_time\n",
    "train_losses = results['train_loss']\n",
    "test_losses = results['test_loss']\n",
    "\n",
    "# grid is part of the tag: the sweep varies it, and without it the runs for\n",
    "# different grid sizes would overwrite each other's files\n",
    "out_dir = './results/kan_wo_ge'\n",
    "os.makedirs(out_dir, exist_ok=True) # np.savetxt does not create missing directories\n",
    "tag = f'problem_{problem}_opt_{opt_name}_width_{width}_depth_{depth}_lrid_{lr_id}_grid_{grid}'\n",
    "np.savetxt(f'{out_dir}/train_loss_{tag}', train_losses)\n",
    "np.savetxt(f'{out_dir}/test_loss_{tag}', test_losses)\n",
    "np.savetxt(f'{out_dir}/walltime_{tag}', [wall_time])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e843c23d-1c0e-4cc1-96a7-8555de8f4a10",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 1.28e-01 | test_loss: 1.41e-01 | reg: 8.85e+00 | : 100%|█| 10000/10000 [00:27<00:00, 3\n"
     ]
    }
   ],
   "source": [
    "# MLP\n",
    "# One sweep point: train an MLP baseline and save its train/test loss curves\n",
    "# plus the wall time.\n",
    "\n",
    "from kan import *\n",
    "from experiments.baselines.MLP import MLP\n",
    "import os\n",
    "import time\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# --- sweep point (edit these to select one run) ---\n",
    "problem = 1 # 0 or 1\n",
    "opt_name = 'Adam' # ['Adam', 'LBFGS']\n",
    "width = 1 # [1,3,10,30,100]\n",
    "depth = 2 # [2,3,4]\n",
    "lr_id = 0 # [0,1,2,3,4]\n",
    "smoke_test = False # True: 1 step per fit, only to check the pipeline (results are still saved)\n",
    "\n",
    "lrs_lbfgs = [1e-2,3e-2,1e-1,3e-1,1]\n",
    "lrs_adam = [1e-4,3e-4,1e-3,3e-3,1e-2]\n",
    "# raises KeyError on a mistyped optimizer name instead of silently using the LBFGS grid\n",
    "lr = {'Adam': lrs_adam, 'LBFGS': lrs_lbfgs}[opt_name][lr_id]\n",
    "\n",
    "# create dataset; full_steps is the number of optimizer steps per fit for a real run\n",
    "if problem == 0:\n",
    "    # f(x, y) = exp(sin(pi*x) + y^2)\n",
    "    n_var = 2\n",
    "    f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
    "    full_steps = 200\n",
    "else:\n",
    "    # f(x) = exp(mean_i sin^2(pi*x_i/2)); keepdim=True gives an (N, 1) label like problem 0\n",
    "    n_var = 100\n",
    "    def f(x):\n",
    "        return torch.exp(torch.mean(torch.sin(torch.pi/2*x)**2, dim=1, keepdim=True))\n",
    "    full_steps = 10000\n",
    "dataset = create_dataset(f, n_var=n_var, device=device, train_num=1000)\n",
    "steps = 1 if smoke_test else full_steps\n",
    "\n",
    "start_time = time.time()\n",
    "model = MLP(width=[n_var]+(depth-1)*[width]+[1], seed=0, device=device)\n",
    "results = model.fit(dataset, opt=opt_name, steps=steps, lr=lr)\n",
    "wall_time = time.time() - start_time\n",
    "train_losses = results['train_loss']\n",
    "test_losses = results['test_loss']\n",
    "\n",
    "out_dir = './results/mlp'\n",
    "os.makedirs(out_dir, exist_ok=True) # np.savetxt does not create missing directories\n",
    "tag = f'problem_{problem}_opt_{opt_name}_width_{width}_depth_{depth}_lrid_{lr_id}'\n",
    "np.savetxt(f'{out_dir}/train_loss_{tag}', train_losses)\n",
    "np.savetxt(f'{out_dir}/test_loss_{tag}', test_losses)\n",
    "np.savetxt(f'{out_dir}/walltime_{tag}', [wall_time])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "84eee612-bbf2-455b-8413-d3343dd5390f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 7.29e-01 | test_loss: 9.56e-01 | reg: 2.75e+01 | : 100%|█| 10000/10000 [00:29<00:00, 3\n"
     ]
    }
   ],
   "source": [
    "# MLP (random fourier feature)\n",
    "# One sweep point: train an MLP with random Fourier features and save its\n",
    "# train/test loss curves plus the wall time.\n",
    "\n",
    "from kan import *\n",
    "from experiments.baselines.MLP import MLP_RFF\n",
    "import os\n",
    "import time\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# --- sweep point (edit these to select one run) ---\n",
    "problem = 1 # 0 or 1\n",
    "opt_name = 'Adam' # ['Adam', 'LBFGS']\n",
    "width = 1 # [1,3,10,30,100]\n",
    "depth = 2 # [2,3,4]\n",
    "lr_id = 0 # [0,1,2,3,4]\n",
    "# s is passed to MLP_RFF; sweep [1,3,10,30,100]\n",
    "# NOTE(review): the header cell plans s = [0.3,1,3,10,30] - confirm which sweep is intended\n",
    "s = 100\n",
    "smoke_test = False # True: 1 step per fit, only to check the pipeline (results are still saved)\n",
    "\n",
    "lrs_lbfgs = [1e-2,3e-2,1e-1,3e-1,1]\n",
    "lrs_adam = [1e-4,3e-4,1e-3,3e-3,1e-2]\n",
    "# raises KeyError on a mistyped optimizer name instead of silently using the LBFGS grid\n",
    "lr = {'Adam': lrs_adam, 'LBFGS': lrs_lbfgs}[opt_name][lr_id]\n",
    "\n",
    "# create dataset; full_steps is the number of optimizer steps per fit for a real run\n",
    "if problem == 0:\n",
    "    # f(x, y) = exp(sin(pi*x) + y^2)\n",
    "    n_var = 2\n",
    "    f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
    "    full_steps = 200\n",
    "else:\n",
    "    # f(x) = exp(mean_i sin^2(pi*x_i/2)); keepdim=True gives an (N, 1) label like problem 0\n",
    "    n_var = 100\n",
    "    def f(x):\n",
    "        return torch.exp(torch.mean(torch.sin(torch.pi/2*x)**2, dim=1, keepdim=True))\n",
    "    full_steps = 10000\n",
    "dataset = create_dataset(f, n_var=n_var, device=device, train_num=1000)\n",
    "steps = 1 if smoke_test else full_steps\n",
    "\n",
    "start_time = time.time()\n",
    "model = MLP_RFF(width=[n_var]+(depth-1)*[width]+[1], seed=0, s=s, device=device)\n",
    "results = model.fit(dataset, opt=opt_name, steps=steps, lr=lr)\n",
    "wall_time = time.time() - start_time\n",
    "train_losses = results['train_loss']\n",
    "test_losses = results['test_loss']\n",
    "\n",
    "# s is part of the tag: the sweep varies it, and without it the runs for\n",
    "# different s values would overwrite each other's files\n",
    "out_dir = './results/mlp_rff'\n",
    "os.makedirs(out_dir, exist_ok=True) # np.savetxt does not create missing directories\n",
    "tag = f'problem_{problem}_opt_{opt_name}_width_{width}_depth_{depth}_lrid_{lr_id}_s_{s}'\n",
    "np.savetxt(f'{out_dir}/train_loss_{tag}', train_losses)\n",
    "np.savetxt(f'{out_dir}/test_loss_{tag}', test_losses)\n",
    "np.savetxt(f'{out_dir}/walltime_{tag}', [wall_time])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9603c477-1424-4c96-8e94-673108387b38",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
