{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "04dae6e4-0011-49c0-bdcf-1026bf5d2728",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n=1\n",
      "cuda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "pde loss: 9.36e+01 | bc loss: 2.08e-03 | l2: 2.39e-01 : 100%|█████████| 1/1 [00:00<00:00,  2.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n=2\n",
      "cuda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "pde loss: 1.50e+03 | bc loss: 2.08e-03 | l2: 2.40e-01 : 100%|█████████| 1/1 [00:00<00:00, 81.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n=4\n",
      "cuda\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "pde loss: 2.40e+04 | bc loss: 2.08e-03 | l2: 2.40e-01 : 100%|█████████| 1/1 [00:00<00:00, 82.60it/s]\n"
     ]
    }
   ],
   "source": [
     "from kan import *  # star import first so the explicit imports below take precedence\n",
     "import os\n",
     "import time\n",
     "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
     "import torch\n",
     "from torch import autograd\n",
     "from tqdm import tqdm\n",
     "\n",
     "from experiments.baselines.MLP import MLP\n",
    "\n",
    "\n",
    "# high frequency Poisson\n",
    "\n",
    "ns = [1,2,4]\n",
    "\n",
    "steps = 1\n",
    "\n",
    "for n in ns:\n",
    "    start_time = time.time()\n",
    "\n",
    "    print(f'n={n}')\n",
    "\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    print(device)\n",
    "\n",
    "    dim = 2\n",
    "    np_i = 51 # number of interior points (along each dimension)\n",
    "    np_b = 51 # number of boundary points (along each dimension)\n",
    "    ranges = [-1, 1]\n",
    "\n",
    "\n",
    "\n",
    "    def batch_jacobian(func, x, create_graph=False):\n",
    "        # x in shape (Batch, Length)\n",
    "        def _func_sum(x):\n",
    "            return func(x).sum(dim=0)\n",
    "        return autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)\n",
    "\n",
    "    # define solution\n",
    "    sol_fun = lambda x: torch.sin(n*torch.pi*x[:,[0]])*torch.sin(n*torch.pi*x[:,[1]])\n",
    "    source_fun = lambda x: -2*(n*torch.pi)**2 * torch.sin(n*torch.pi*x[:,[0]])*torch.sin(n*torch.pi*x[:,[1]])\n",
    "\n",
    "    # interior\n",
    "    sampling_mode = 'mesh' # 'radnom' or 'mesh'\n",
    "\n",
    "    x_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)\n",
    "    y_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)\n",
    "    X, Y = torch.meshgrid(x_mesh, y_mesh, indexing=\"ij\")\n",
    "    if sampling_mode == 'mesh':\n",
    "        #mesh\n",
    "        x_i = torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)\n",
    "    else:\n",
    "        #random\n",
    "        x_i = torch.rand((np_i**2,2))*2-1\n",
    "\n",
    "    x_i = x_i.to(device)\n",
    "\n",
    "    # boundary, 4 sides\n",
    "    helper = lambda X, Y: torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)\n",
    "    xb1 = helper(X[0], Y[0])\n",
    "    xb2 = helper(X[-1], Y[0])\n",
    "    xb3 = helper(X[:,0], Y[:,0])\n",
    "    xb4 = helper(X[:,0], Y[:,-1])\n",
    "    x_b = torch.cat([xb1, xb2, xb3, xb4], dim=0)\n",
    "\n",
    "    x_b = x_b.to(device)\n",
    "\n",
    "    alpha = 0.01\n",
    "    log = 1\n",
    "\n",
    "    pde_losses = []\n",
    "    bc_losses = []\n",
    "    l2_losses = []\n",
    "\n",
    "\n",
    "    model = MLP(width=[2,128,128,128,1], seed=1, device=device)\n",
    "\n",
    "\n",
    "    def train():\n",
    "        #optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn=\"strong_wolfe\", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)\n",
    "\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "        pbar = tqdm(range(steps), desc='description', ncols=100)\n",
    "\n",
    "        for _ in pbar:\n",
    "            def closure():\n",
    "                global pde_loss, bc_loss\n",
    "                optimizer.zero_grad()\n",
    "                # interior loss\n",
    "                sol = sol_fun(x_i)\n",
    "                sol_D1_fun = lambda x: batch_jacobian(model, x, create_graph=True)[:,0,:]\n",
    "                sol_D1 = sol_D1_fun(x_i)\n",
    "                sol_D2 = batch_jacobian(sol_D1_fun, x_i, create_graph=True)[:,:,:]\n",
    "                lap = torch.sum(torch.diagonal(sol_D2, dim1=1, dim2=2), dim=1, keepdim=True)\n",
    "                source = source_fun(x_i)\n",
    "                pde_loss = torch.mean((lap - source)**2)\n",
    "\n",
    "                # boundary loss\n",
    "                bc_true = sol_fun(x_b)\n",
    "                bc_pred = model(x_b)\n",
    "                bc_loss = torch.mean((bc_pred-bc_true)**2)\n",
    "\n",
    "                loss = alpha * pde_loss + bc_loss\n",
    "                loss.backward()\n",
    "                return loss\n",
    "\n",
    "            optimizer.step(closure)\n",
    "            sol = sol_fun(x_i)\n",
    "            loss = alpha * pde_loss + bc_loss\n",
    "            l2 = torch.mean((model(x_i) - sol)**2)\n",
    "\n",
    "            if _ % log == 0:\n",
    "                pbar.set_description(\"pde loss: %.2e | bc loss: %.2e | l2: %.2e \" % (pde_loss.cpu().detach().numpy(), bc_loss.cpu().detach().numpy(), l2.cpu().detach().numpy()))\n",
    "\n",
    "            pde_losses.append(pde_loss.cpu().detach().numpy())\n",
    "            bc_losses.append(bc_loss.cpu().detach().numpy())\n",
    "            l2_losses.append(l2.cpu().detach().numpy())\n",
    "\n",
    "\n",
    "    train()\n",
    "\n",
    "    end_time = time.time()\n",
    "\n",
    "    np_i = 201\n",
    "    x_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)\n",
    "    y_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)\n",
    "    X, Y = torch.meshgrid(x_mesh, y_mesh, indexing=\"ij\")\n",
    "    x_i_show = torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)\n",
    "\n",
    "    np.savetxt(f'./results/mlp_sol_n_{n}_steps_{steps}', model(x_i_show.to(device)).reshape(np_i, np_i).cpu().detach().numpy())\n",
    "    np.savetxt(f'./results/mlp_walltime_n_{n}_steps_{steps}', [end_time - start_time])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52f43ef3-8263-4417-85b5-fbf2ff0f1c7d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
