{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b6586d35-4153-4c6c-bd10-790ec50ef3ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1037713/4115845983.py:180: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  self.fixed_uniform = torch.tensor(x_t_init, dtype = torch.float32, device=self.device, requires_grad=True)\n",
      "/tmp/ipykernel_1037713/4115845983.py:141: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  self.XTGrid = torch.tensor(copy.deepcopy(fixed_uniform), dtype=torch.float32, requires_grad=True).to(self.device)\n",
      "/tmp/ipykernel_1037713/4115845983.py:85: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  self.XTGrid = torch.tensor(copy.deepcopy(fixed_uniform), dtype=torch.float32, requires_grad=True).to(self.device)\n",
      "/tmp/ipykernel_1037713/4115845983.py:53: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  XTGrid = torch.tensor(x_t_new, dtype = torch.float32, device=self.device, requires_grad=True)\n",
      "/tmp/ipykernel_1037713/4115845983.py:349: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  XTGrid = torch.tensor(XTGrid, dtype=torch.float32, requires_grad=True).to(self.device)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration Number = 1000\n",
      "\tIC Loss = 0.00014109905168879777\n",
      "\tBC Loss = 0.0011528183240443468\n",
      "\tPhysics Loss = 0.44464632868766785\n",
      "\tTraining Loss = 0.459909051656723\n",
      "\tRelative L2 error (test) = 32.89075791835785\n",
      "Iteration Number = 2000\n",
      "\tIC Loss = 0.019209494814276695\n",
      "\tBC Loss = 0.11918403208255768\n",
      "\tPhysics Loss = 0.7007512450218201\n",
      "\tTraining Loss = 2.740884780883789\n",
      "\tRelative L2 error (test) = 63.32663297653198\n",
      "Iteration Number = 3000\n",
      "\tIC Loss = 0.000238136388361454\n",
      "\tBC Loss = 0.00013920996570959687\n",
      "\tPhysics Loss = 0.21139909327030182\n",
      "\tTraining Loss = 0.23535194993019104\n",
      "\tRelative L2 error (test) = 19.18889284133911\n",
      "Iteration Number = 4000\n",
      "\tIC Loss = 0.0002580300497356802\n",
      "\tBC Loss = 0.0006793826469220221\n",
      "\tPhysics Loss = 0.3766280710697174\n",
      "\tTraining Loss = 0.40311044454574585\n",
      "\tRelative L2 error (test) = 31.93422257900238\n",
      "Iteration Number = 5000\n",
      "\tIC Loss = 0.04040122777223587\n",
      "\tBC Loss = 0.8722059726715088\n",
      "\tPhysics Loss = 1.190723180770874\n",
      "\tTraining Loss = 6.103052139282227\n",
      "\tRelative L2 error (test) = 77.70264148712158\n",
      "Iteration Number = 6000\n",
      "\tIC Loss = 0.00028686205041594803\n",
      "\tBC Loss = 0.0005772631266154349\n",
      "\tPhysics Loss = 0.386293888092041\n",
      "\tTraining Loss = 0.4155573546886444\n",
      "\tRelative L2 error (test) = 31.93228840827942\n",
      "Iteration Number = 7000\n",
      "\tIC Loss = 0.0007638204842805862\n",
      "\tBC Loss = 0.00030508378404192626\n",
      "\tPhysics Loss = 0.5645246505737305\n",
      "\tTraining Loss = 0.6412118077278137\n",
      "\tRelative L2 error (test) = 40.432122349739075\n",
      "Iteration Number = 8000\n",
      "\tIC Loss = 0.0005630856612697244\n",
      "\tBC Loss = 0.0016345512121915817\n",
      "\tPhysics Loss = 0.5230473279953003\n",
      "\tTraining Loss = 0.5809904336929321\n",
      "\tRelative L2 error (test) = 39.36311900615692\n",
      "Iteration Number = 9000\n",
      "\tIC Loss = 0.0013275862438604236\n",
      "\tBC Loss = 0.0014761356869712472\n",
      "\tPhysics Loss = 0.24279338121414185\n",
      "\tTraining Loss = 0.37702813744544983\n",
      "\tRelative L2 error (test) = 25.55157244205475\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 452\u001b[0m\n\u001b[1;32m    434\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m repeat:\n\u001b[1;32m    435\u001b[0m     pinn \u001b[38;5;241m=\u001b[39m PINN(  k \u001b[38;5;241m=\u001b[39m k,\n\u001b[1;32m    436\u001b[0m                   c \u001b[38;5;241m=\u001b[39m c,\n\u001b[1;32m    437\u001b[0m                   t\u001b[38;5;241m=\u001b[39m t,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    449\u001b[0m                   model_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m../models/\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m args\u001b[38;5;241m.\u001b[39mmodel_name \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.model_\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m+\u001b[39margs\u001b[38;5;241m.\u001b[39mmethod\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(args\u001b[38;5;241m.\u001b[39mlayers)\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(args\u001b[38;5;241m.\u001b[39mNf)\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(i),\n\u001b[1;32m    450\u001b[0m                   display_freq \u001b[38;5;241m=\u001b[39m args\u001b[38;5;241m.\u001b[39mdisplay_freq, samp \u001b[38;5;241m=\u001b[39m args\u001b[38;5;241m.\u001b[39mmethod )\n\u001b[0;32m--> 452\u001b[0m     Losses_train, Losses_rel_l2 \u001b[38;5;241m=\u001b[39m pinn\u001b[38;5;241m.\u001b[39mTrain(args\u001b[38;5;241m.\u001b[39mepochs, weights \u001b[38;5;241m=\u001b[39m (\u001b[38;5;241m100\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m)) \u001b[38;5;66;03m# initial, boundary, 
residual\u001b[39;00m\n\u001b[1;32m    454\u001b[0m     torch\u001b[38;5;241m.\u001b[39msave(Losses_train, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m../models/\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m args\u001b[38;5;241m.\u001b[39mmodel_name \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.loss_\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m+\u001b[39margs\u001b[38;5;241m.\u001b[39mmethod\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(args\u001b[38;5;241m.\u001b[39mlayers)\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(args\u001b[38;5;241m.\u001b[39mNf)\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(i))\n\u001b[1;32m    455\u001b[0m     torch\u001b[38;5;241m.\u001b[39msave(Losses_rel_l2, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m../models/\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m args\u001b[38;5;241m.\u001b[39mmodel_name \u001b[38;5;241m+\u001b[39m 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.rel_l2_\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m+\u001b[39margs\u001b[38;5;241m.\u001b[39mmethod\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(args\u001b[38;5;241m.\u001b[39mlayers)\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(args\u001b[38;5;241m.\u001b[39mNf)\u001b[38;5;241m+\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mstr\u001b[39m(i))\n",
      "Cell \u001b[0;32mIn[5], line 347\u001b[0m, in \u001b[0;36mPINN.Train\u001b[0;34m(self, n_iters, weights)\u001b[0m\n\u001b[1;32m    344\u001b[0m         XTGrid \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(XTGrid, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32, requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m    346\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmethod \u001b[38;5;241m==\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrad\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m--> 347\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrad\u001b[38;5;241m.\u001b[39mupdate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_nn)\n\u001b[1;32m    348\u001b[0m         XTGrid \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrad\u001b[38;5;241m.\u001b[39mXTGrid\n\u001b[1;32m    349\u001b[0m         XTGrid \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(XTGrid, dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32, requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)   \n",
      "Cell \u001b[0;32mIn[5], line 57\u001b[0m, in \u001b[0;36mRADSampler.update\u001b[0;34m(self, model)\u001b[0m\n\u001b[1;32m     54\u001b[0m XTGrid \u001b[38;5;241m=\u001b[39m XTGrid\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m     56\u001b[0m uf \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mforward(XTGrid)[:,\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 57\u001b[0m uf_x, uf_t \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mgrad(outputs\u001b[38;5;241m=\u001b[39muf\u001b[38;5;241m.\u001b[39mto(device), \n\u001b[1;32m     58\u001b[0m                            inputs\u001b[38;5;241m=\u001b[39mXTGrid, \n\u001b[1;32m     59\u001b[0m                            grad_outputs\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mones(uf\u001b[38;5;241m.\u001b[39mshape)\u001b[38;5;241m.\u001b[39mto(device), \n\u001b[1;32m     60\u001b[0m                            create_graph \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m     61\u001b[0m                            allow_unused\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mT\n\u001b[1;32m     63\u001b[0m uf_xx \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mgrad(outputs\u001b[38;5;241m=\u001b[39muf_x\u001b[38;5;241m.\u001b[39mto(device), \n\u001b[1;32m     64\u001b[0m                            inputs\u001b[38;5;241m=\u001b[39mXTGrid, \n\u001b[1;32m     65\u001b[0m                            grad_outputs\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mones(uf_x\u001b[38;5;241m.\u001b[39mshape)\u001b[38;5;241m.\u001b[39mto(device),\n\u001b[1;32m     66\u001b[0m                            create_graph \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m     67\u001b[0m                            
allow_unused\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m][:,\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m     69\u001b[0m err \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mabs((uf_t \u001b[38;5;241m+\u001b[39m uf\u001b[38;5;241m*\u001b[39muf_x\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m0.01\u001b[39m\u001b[38;5;241m/\u001b[39mnp\u001b[38;5;241m.\u001b[39mpi\u001b[38;5;241m*\u001b[39muf_xx))\n",
      "File \u001b[0;32m~/anaconda3/envs/dj24/lib/python3.12/site-packages/torch/autograd/__init__.py:412\u001b[0m, in \u001b[0;36mgrad\u001b[0;34m(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)\u001b[0m\n\u001b[1;32m    408\u001b[0m     result \u001b[38;5;241m=\u001b[39m _vmap_internals\u001b[38;5;241m.\u001b[39m_vmap(vjp, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m, allow_none_pass_through\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)(\n\u001b[1;32m    409\u001b[0m         grad_outputs_\n\u001b[1;32m    410\u001b[0m     )\n\u001b[1;32m    411\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 412\u001b[0m     result \u001b[38;5;241m=\u001b[39m _engine_run_backward(\n\u001b[1;32m    413\u001b[0m         t_outputs,\n\u001b[1;32m    414\u001b[0m         grad_outputs_,\n\u001b[1;32m    415\u001b[0m         retain_graph,\n\u001b[1;32m    416\u001b[0m         create_graph,\n\u001b[1;32m    417\u001b[0m         inputs,\n\u001b[1;32m    418\u001b[0m         allow_unused,\n\u001b[1;32m    419\u001b[0m         accumulate_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m    420\u001b[0m     )\n\u001b[1;32m    421\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m materialize_grads:\n\u001b[1;32m    422\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m(\n\u001b[1;32m    423\u001b[0m         result[i] \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_tensor_like(inputs[i])\n\u001b[1;32m    424\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(inputs))\n\u001b[1;32m    425\u001b[0m     ):\n",
      "File \u001b[0;32m~/anaconda3/envs/dj24/lib/python3.12/site-packages/torch/autograd/graph.py:744\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m    742\u001b[0m     unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m    743\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 744\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m Variable\u001b[38;5;241m.\u001b[39m_execution_engine\u001b[38;5;241m.\u001b[39mrun_backward(  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m    745\u001b[0m         t_outputs, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    746\u001b[0m     )  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m    747\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    748\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torch.utils.data import DataLoader \n",
    "import numpy as np\n",
    "from scipy import io\n",
    "import matplotlib.pyplot as plt\n",
    "import argparse\n",
    "import os\n",
    "import copy\n",
    "import time\n",
    "\n",
    "def cal_domain_grad(model, XTGrid, device):\n",
    "    XTGrid = XTGrid.to(device)\n",
    "    uf = model.forward(XTGrid)[:,0]\n",
    "    uf_x, uf_t = torch.autograd.grad(outputs=uf.to(device), \n",
    "                               inputs=XTGrid, \n",
    "                               grad_outputs=torch.ones(uf.shape).to(device), \n",
    "                               create_graph = True,\n",
    "                               allow_unused=True)[0].T\n",
    "    \n",
    "    uf_xx = torch.autograd.grad(outputs=uf_x.to(device), \n",
    "                               inputs=XTGrid, \n",
    "                               grad_outputs=torch.ones(uf_x.shape).to(device),\n",
    "                               create_graph = True,\n",
    "                               allow_unused=True)[0][:,0]\n",
    "    \n",
    "    loss =  (uf_t + uf*uf_x-0.01/np.pi*uf_xx)**2\n",
    "\n",
    "    mean_x, mean_t = torch.autograd.grad(outputs=loss.to(device), \n",
    "                               inputs=XTGrid, \n",
    "                               grad_outputs=torch.ones(loss.shape).to(device),\n",
    "                               create_graph = True,\n",
    "                               allow_unused=True)[0].T\n",
    "    grad = torch.concatenate((mean_x.reshape(-1,1), mean_t.reshape(-1,1)), axis = 1)\n",
    "    return grad\n",
    "\n",
    "\n",
    "class RADSampler():\n",
    "    def __init__(self, Nf, device, k, c):    \n",
    "        self.device = device\n",
    "        self.k = k\n",
    "        self.c = c\n",
    "        self.Nf = Nf\n",
    "        self.dense_Nf = Nf*1\n",
    "        \n",
    "    def update(self, model):\n",
    "        \n",
    "        x_new = torch.zeros(self.dense_Nf, 1, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "        t_new = torch.zeros(self.dense_Nf, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "        x_t_new = torch.concatenate((x_new,t_new), axis = 1)\n",
    "        \n",
    "        XTGrid = torch.tensor(x_t_new, dtype = torch.float32, device=self.device, requires_grad=True) \n",
    "        XTGrid = XTGrid.to(self.device)\n",
    "        \n",
    "        uf = model.forward(XTGrid)[:,0]\n",
    "        uf_x, uf_t = torch.autograd.grad(outputs=uf.to(device), \n",
    "                                   inputs=XTGrid, \n",
    "                                   grad_outputs=torch.ones(uf.shape).to(device), \n",
    "                                   create_graph = True,\n",
    "                                   allow_unused=True)[0].T\n",
    "\n",
    "        uf_xx = torch.autograd.grad(outputs=uf_x.to(device), \n",
    "                                   inputs=XTGrid, \n",
    "                                   grad_outputs=torch.ones(uf_x.shape).to(device),\n",
    "                                   create_graph = True,\n",
    "                                   allow_unused=True)[0][:,0]\n",
    "        \n",
    "        err = torch.abs((uf_t + uf*uf_x-0.01/np.pi*uf_xx))\n",
    "        err = (err**self.k)/((err**self.k).mean())+self.c\n",
    "        err_norm = err/(err.sum())\n",
    "        \n",
    "        indice = torch.multinomial(err_norm, self.Nf, replacement = True)\n",
    "        XTGrid = XTGrid[indice]\n",
    "        self.XTGrid = XTGrid\n",
    "        \n",
    "class LASSampler():\n",
    "    def __init__(self, Nf, fixed_uniform, device, L_iter = 1, beta = 0.2, tau = 0.002):\n",
    "        self.Nf = Nf\n",
    "        self.device = device\n",
    "        self.cnt = 0\n",
    "        self.beta = beta\n",
    "        self.tau = tau\n",
    "        self.L_iter = L_iter\n",
    "        self.XTGrid = torch.tensor(copy.deepcopy(fixed_uniform), dtype=torch.float32, requires_grad=True).to(self.device)\n",
    "\n",
    "    def update(self, phy_lf, model):\n",
    "\n",
    "        # x_new = torch.zeros(self.Nf, 1, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "        # t_new = torch.zeros(self.Nf, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "        # x_t_new = torch.concatenate((x_new,t_new), axis = 1)\n",
    "        # self.XTGrid = x_t_new\n",
    "\n",
    "        x_data = self.XTGrid\n",
    "        samples = x_data.clone().detach().requires_grad_(True)\n",
    "        \n",
    "        for t in range(1, self.L_iter + 1):\n",
    "            grad = phy_lf(model, samples, self.device)\n",
    "            scaler = torch.sqrt(torch.sum((grad+1e-16)**2, axis = 1)).reshape(-1,1)\n",
    "            grad = grad/scaler\n",
    "            with torch.no_grad():\n",
    "                samples = samples + self.tau * grad + self.beta*torch.sqrt(torch.tensor(2 * self.tau, device=self.device)) * torch.randn(samples.shape, device=self.device)\n",
    "                samples[:, 0] = torch.clamp(samples[:, 0], min=-1, max=1) \n",
    "                samples[:, 1] = torch.clamp(samples[:, 1], min=0, max=1)   \n",
    "            samples = samples.clone().detach().requires_grad_(True)\n",
    "        self.XTGrid = samples.detach()\n",
    "\n",
    "class L_INFSampler():\n",
    "    def __init__(self, Nf, device, step_size = 0.05 , n_iter = 20):\n",
    "\n",
    "        self.Nf = Nf\n",
    "        self.device = device\n",
    "        self.step_size = step_size\n",
    "        self.n_iter = n_iter\n",
    "\n",
    "    def update(self, phy_lf, model):\n",
    "\n",
    "        x_new = torch.zeros(self.Nf, 1, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "        t_new = torch.zeros(self.Nf, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "        x_t_new = torch.concatenate((x_new,t_new), axis = 1)\n",
    "        self.XTGrid = x_t_new\n",
    "            \n",
    "        x_data = self.XTGrid\n",
    "        samples = x_data.clone().detach().requires_grad_(True)\n",
    "    \n",
    "        for t in range(1, self.n_iter + 1):\n",
    "            grad = phy_lf(model, samples, self.device)\n",
    "            with torch.no_grad():\n",
    "                samples = samples + self.step_size * torch.sign(grad)\n",
    "                samples[:, 0] = torch.clamp(samples[:, 0], min=-1, max=1)  \n",
    "                samples[:, 1] = torch.clamp(samples[:, 1], min=0, max=1)   \n",
    "            samples = samples.clone().detach().requires_grad_(True)\n",
    "        self.XTGrid = samples.detach()        \n",
    "\n",
    "\n",
    "class R3Sampler(nn.Module):\n",
    "    def __init__(self,Nf, fixed_uniform, device):\n",
    "        super(R3Sampler, self).__init__()\n",
    "        self.Nf = Nf\n",
    "        self.device = device\n",
    "        self.XTGrid = torch.tensor(copy.deepcopy(fixed_uniform), dtype=torch.float32, requires_grad=True).to(self.device)\n",
    "    \n",
    "    def update(self, loss_aver, loss_ele):\n",
    "        with torch.no_grad():\n",
    "            cho_i = loss_ele > loss_aver\n",
    "            cho_i = cho_i.to('cpu')\n",
    "            self.XTGrid = self.XTGrid[cho_i].detach()\n",
    "            need_n_sample = self.Nf-self.XTGrid.shape[0]\n",
    "            x_new = torch.zeros(need_n_sample, 1, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "            t_new = torch.zeros(need_n_sample, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "            x_t_new = torch.concatenate((x_new,t_new), axis = 1)\n",
    "            self.XTGrid = torch.concatenate((self.XTGrid, x_t_new), axis = 0)\n",
    "            self.XTGrid = torch.tensor(self.XTGrid, dtype = torch.float32, device=self.device, requires_grad=True)\n",
    "    \n",
    "class PINN(nn.Module):\n",
    "    def __init__(self,k , c , t, X_star, u_star, exact_u, space_domain, time_domain, Layers, N0, Nb, Nf, \n",
    "                 Activation = nn.Tanh(), \n",
    "                 model_name = \"PINN.model\", device = 'cpu',\n",
    "                  display_freq = 100, samp = 'fixed' ):\n",
    "        \"\"\"Physics-informed network for u_t + u*u_x - (0.01/pi)*u_xx = 0 with adaptive collocation.\n",
    "\n",
    "        Args:\n",
    "            k, c: exponent and additive floor of the RAD sampling weights.\n",
    "            t: time grid; its length sets the number of boundary points per side.\n",
    "            X_star, u_star: reference points [x, t] and values; rows with t == 0 give the\n",
    "                initial condition.\n",
    "            exact_u: stored reference solution (presumably used for evaluation in Train).\n",
    "            space_domain, time_domain: (lower, upper) bounds of x and t.\n",
    "            Layers: layer widths, input dimension first and output dimension last.\n",
    "            N0, Nb: stored numbers of initial / boundary points.\n",
    "            Nf: number of collocation points.\n",
    "            Activation: activation module placed between consecutive linear layers.\n",
    "            model_name: stored name/path, presumably used when saving the model.\n",
    "            device: torch device for the data and the network.\n",
    "            display_freq: stored logging interval, presumably used by Train.\n",
    "            samp: collocation strategy ('fixed', 'r3', 'las', 'l_inf', 'rad', 'random-r').\n",
    "        \"\"\"\n",
    "        super(PINN, self).__init__()\n",
    "\n",
    "        LBs = [space_domain[0], time_domain[0]]\n",
    "        UBs = [space_domain[1], time_domain[1]]\n",
    "\n",
    "        self.LBs = torch.tensor(LBs, dtype=torch.float32).to(device)\n",
    "        self.UBs = torch.tensor(UBs, dtype=torch.float32).to(device)\n",
    "\n",
    "        self.Layers = Layers\n",
    "        self.in_dim  = Layers[0]\n",
    "        self.out_dim = Layers[-1]\n",
    "        self.Activation = Activation\n",
    "\n",
    "        self.device = device\n",
    "\n",
    "        # Uniform collocation set on [-1, 1] x [0, 1]; it also seeds the R3 and LAS samplers.\n",
    "        x_init = torch.zeros(Nf, 1, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "        t_init = torch.zeros(Nf, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "        # The concatenation is already a fresh float32 tensor on self.device, so make it a\n",
    "        # grad-tracking leaf in place; copying it via torch.tensor() raised a UserWarning.\n",
    "        self.fixed_uniform = torch.concatenate((x_init, t_init), axis = 1).requires_grad_(True)\n",
    "\n",
    "        self.N0 = N0\n",
    "        self.Nb = Nb\n",
    "        self.Nf = Nf\n",
    "\n",
    "        self.t = t\n",
    "        self.X_star = X_star\n",
    "        self.u_star = u_star\n",
    "        self.exact_u = exact_u\n",
    "\n",
    "        # Initial data (t == 0) and boundary point sets, bounded by the x-domain.\n",
    "        self.XT0, self.u0  = self.InitialCondition(self.LBs[0], self.UBs[0])\n",
    "        self.XTbL, self.XTbU = self.BoundaryCondition( self.LBs[0], self.UBs[0])\n",
    "\n",
    "        self.XT0 = self.XT0.to(device)\n",
    "        self.u0 = self.u0.to(device)\n",
    "\n",
    "        self.XTbL = self.XTbL.to(device)\n",
    "        self.XTbU = self.XTbU.to(device)\n",
    "\n",
    "        self._nn = self.build_model()\n",
    "        self._nn.to(self.device)\n",
    "        self.Loss = torch.nn.MSELoss(reduction='mean')\n",
    "\n",
    "        self.model_name = model_name\n",
    "        self.display_freq = display_freq\n",
    "\n",
    "        self.k = k\n",
    "        self.c = c\n",
    "\n",
    "        self.method = samp\n",
    "\n",
    "        # One instance of every sampler; Train selects the one matching self.method.\n",
    "        self.r3_sample = R3Sampler(self.Nf, self.fixed_uniform, device)\n",
    "        self.las = LASSampler(self.Nf, fixed_uniform=self.fixed_uniform, device=self.device, L_iter = 1, beta = 0.2, tau=2e-3)\n",
    "        self.l_inf = L_INFSampler(self.Nf, device=self.device, step_size = 0.05 , n_iter = 20)\n",
    "        self.rad = RADSampler(Nf = self.Nf, device=self.device, k = self.k, c=self.c)\n",
    "        \n",
    "    \n",
    "    def build_model(self):\n",
    "        Seq = nn.Sequential()\n",
    "        for ii in range(len(self.Layers)-1):\n",
    "            this_module = nn.Linear(self.Layers[ii], self.Layers[ii+1])\n",
    "            nn.init.xavier_normal_(this_module.weight)\n",
    "            Seq.add_module(\"Linear\" + str(ii), this_module)\n",
    "            if not ii == len(self.Layers)-2:\n",
    "                Seq.add_module(\"Activation\" + str(ii), self.Activation)\n",
    "        return Seq\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = x.to(self.device)\n",
    "        x = x.reshape((-1,self.in_dim))  \n",
    "        return torch.reshape(self._nn.forward(x), (-1, self.out_dim))\n",
    "\n",
    "    def InitialCondition(self,LB, UB):\n",
    "        x = torch.tensor([])\n",
    "\n",
    "        if (type(LB) != type(x)):\n",
    "          LB = torch.tensor(LB).cpu()\n",
    "        else:\n",
    "          LB = LB.cpu()\n",
    "        if (type(UB) != type(x)):\n",
    "          UB = torch.tensor(UB).cpu()\n",
    "        else:\n",
    "          UB = UB.cpu()\n",
    "\n",
    "        indices = (self.X_star[:,0] >= LB) & (self.X_star[:,0] < UB) & (self.X_star[:,1] == 0.)\n",
    "        XT0 = self.X_star[indices]\n",
    "        u0 = self.u_star[indices]\n",
    "\n",
    "        return XT0, u0\n",
    "\n",
    "    def BoundaryCondition(self, LB, UB):\n",
    "        x = torch.tensor([])\n",
    "        \n",
    "        if (type(LB) != type(x)):\n",
    "          LB = torch.tensor(LB).cpu()\n",
    "        else:\n",
    "          LB = LB.cpu()\n",
    "        if (type(UB) != type(x)):\n",
    "          UB = torch.tensor(UB).cpu()\n",
    "        else:\n",
    "          UB = UB.cpu()\n",
    "        \n",
    "        tb =  torch.tensor(np.linspace(0, 1, self.t.shape[0], endpoint=False), dtype = torch.float32)\n",
    "        XTL = torch.cat(( LB*torch.ones((self.t.shape[0],1)), tb.reshape(-1,1)), dim = 1)\n",
    "        XTL.requires_grad_()\n",
    "        XTU = torch.cat(( UB*torch.ones((self.t.shape[0],1)), tb.reshape(-1,1)), dim = 1)\n",
    "        XTU.requires_grad_()\n",
    "        \n",
    "        return  XTL, XTU\n",
    "    \n",
    "    def ICLoss(self):\n",
    "        XT0 = self.XT0\n",
    "        u0  = self.u0\n",
    "        UV0_pred = self.forward(XT0)\n",
    "        u0_pred = UV0_pred[:,0].reshape(-1)\n",
    "        return self.Loss(u0_pred, u0)\n",
    "\n",
    "    def BCLoss(self):\n",
    "        ub_l = self.forward(self.XTbL)\n",
    "        ub_u = self.forward(self.XTbU)\n",
    "            \n",
    "        return torch.mean(ub_l**2+ub_u**2)\n",
    "    \n",
    "    def PhysicsLoss(self, XTGrid):\n",
    "        XTGrid = XTGrid.to(self.device)\n",
    "        uf = self.forward(XTGrid)[:,0]\n",
    "        uf_x, uf_t = torch.autograd.grad(outputs=uf.to(self.device), \n",
    "                                   inputs=XTGrid, \n",
    "                                   grad_outputs=torch.ones(uf.shape).to(self.device), \n",
    "                                   create_graph = True,\n",
    "                                   allow_unused=True)[0].T\n",
    "        uf_xx = torch.autograd.grad(outputs=uf_x.to(self.device), \n",
    "                                   inputs=XTGrid, \n",
    "                                   grad_outputs=torch.ones(uf_x.shape).to(self.device),\n",
    "                                   create_graph = True,\n",
    "                                   allow_unused=True)[0][:,0]\n",
    "        loss2 =  (uf_t + uf*uf_x-0.01/np.pi*uf_xx)**2\n",
    "        loss1 = loss2.mean()\n",
    "        \n",
    "        return loss1, loss2 \n",
    "    \n",
    "\n",
    "    def Train(self, n_iters, weights=(1.0,1.0,1.0)):\n",
    "        \"\"\"Train the network with Adam, resampling collocation points every iteration.\n",
    "\n",
    "        Each iteration evaluates the IC, BC and PDE-residual losses, combines them\n",
    "        as weights[0]*IC + weights[1]*BC + weights[2]*residual, and takes one Adam\n",
    "        step followed by one StepLR step. Collocation points come from the sampler\n",
    "        selected by self.method: 'r3', 'las', 'l_inf', 'rad', 'fixed' or 'random-r'.\n",
    "\n",
    "        Reads the notebook-level globals X_star and Exact_u (test data) and\n",
    "        cal_domain_grad (passed to the 'las' and 'l_inf' samplers).\n",
    "\n",
    "        Returns:\n",
    "            Training_Losses: the total loss of every iteration.\n",
    "            rel_error: relative L2 test error in percent; slot k holds the value at\n",
    "                iteration k*self.display_freq (slot 0 keeps the placeholder -10).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if self.method is not a supported sampling method.\n",
    "        \"\"\"\n",
    "        optimizer = optim.Adam(list(self.parameters()), lr=1e-3)\n",
    "        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 5000, gamma=0.9, last_epoch=-1)\n",
    "        min_loss = 999999.0\n",
    "        Training_Losses = [-10]*n_iters\n",
    "        # One slot per evaluation: sized and indexed with the same display_freq,\n",
    "        # so evaluations never overwrite each other for any display frequency.\n",
    "        rel_error = [-10]*(1+n_iters//self.display_freq)\n",
    "        ckpt_path = \"../models/\"+self.method+'_'+str(len(self.Layers)-2)+'_'+str(self.Nf)+'.pt'\n",
    "\n",
    "        def as_leaf(grid):\n",
    "            # Fresh float32 copy of `grid` (tensor or array) on self.device that tracks\n",
    "            # gradients. detach().clone() copies like torch.tensor(grid) did, without\n",
    "            # the copy-construct UserWarning.\n",
    "            return torch.as_tensor(grid, dtype=torch.float32, device=self.device).detach().clone().requires_grad_(True)\n",
    "\n",
    "        for jj in range(n_iters):\n",
    "            Total_ICLoss = self.ICLoss()\n",
    "            Total_BCLoss = self.BCLoss()\n",
    "\n",
    "            if self.method == 'r3':\n",
    "                if jj > 0:\n",
    "                    # Evolve the population using the previous iteration's residuals.\n",
    "                    with torch.no_grad():\n",
    "                        self.r3_sample.update(loss1, loss2)\n",
    "                XTGrid = as_leaf(self.r3_sample.XTGrid)\n",
    "\n",
    "            elif self.method == 'las':\n",
    "                # cnt % 1 is always 0, i.e. update on every iteration after the first;\n",
    "                # raise the modulus to update less often.\n",
    "                if jj > 0 and self.las.cnt % 1 == 0:\n",
    "                    self.las.update(cal_domain_grad, self._nn)\n",
    "                XTGrid = as_leaf(self.las.XTGrid)\n",
    "                self.las.cnt += 1\n",
    "\n",
    "            elif self.method == 'l_inf':\n",
    "                self.l_inf.update(cal_domain_grad, self._nn)\n",
    "                XTGrid = as_leaf(self.l_inf.XTGrid)\n",
    "\n",
    "            elif self.method == 'rad':\n",
    "                self.rad.update(self._nn)\n",
    "                XTGrid = as_leaf(self.rad.XTGrid)\n",
    "\n",
    "            elif self.method == 'fixed':\n",
    "                XTGrid = as_leaf(self.fixed_uniform)\n",
    "\n",
    "            elif self.method == 'random-r':\n",
    "                # Fresh uniform points on [-1, 1] x [0, 1] every iteration.\n",
    "                x_new = torch.zeros(self.Nf, 1, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "                t_new = torch.zeros(self.Nf, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "                XTGrid = torch.cat((x_new, t_new), dim=1).requires_grad_(True)\n",
    "\n",
    "            else:\n",
    "                # Fail fast: otherwise XTGrid is unbound and PhysicsLoss fails obscurely.\n",
    "                raise ValueError(\"Unknown sampling method: {}\".format(self.method))\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss1, loss2 = self.PhysicsLoss(XTGrid) # For r3 method, loss2 contains element-wise errors\n",
    "            Total_PhysicsLoss = loss1\n",
    "            Total_Loss = (weights[0]*Total_ICLoss + weights[1]*Total_BCLoss\n",
    "                          + weights[2]*Total_PhysicsLoss)\n",
    "            Total_Loss.backward()\n",
    "\n",
    "            loss_value = float(Total_Loss)\n",
    "            # Checkpoint before optimizer.step(): loss_value was measured with the\n",
    "            # current parameters, so the saved weights are the ones that achieved it.\n",
    "            if loss_value < min_loss:\n",
    "                torch.save(self._nn.state_dict(), ckpt_path)\n",
    "                min_loss = loss_value\n",
    "            optimizer.step()\n",
    "            scheduler.step()\n",
    "            Training_Losses[jj] = loss_value\n",
    "\n",
    "            if (jj+1) % self.display_freq == 0:\n",
    "                with torch.no_grad():\n",
    "                    # X_star lists x fastest and t slowest, so predictions reshape to Exact_u.T.\n",
    "                    outputs = self.forward(X_star).reshape(Exact_u.T.shape)\n",
    "                    rel_l2 = np.linalg.norm(Exact_u.cpu().T-outputs.cpu().detach()) / np.linalg.norm(Exact_u.cpu().detach().T)\n",
    "                    rel_error[(jj+1)//self.display_freq] = float(rel_l2*100)\n",
    "                print(\"Iteration Number = {}\".format(jj+1))\n",
    "                print(\"\\tIC Loss = {}\".format(float(Total_ICLoss)))\n",
    "                print(\"\\tBC Loss = {}\".format(float(Total_BCLoss)))\n",
    "                print(\"\\tPhysics Loss = {}\".format(float(Total_PhysicsLoss)))\n",
    "                print(\"\\tTraining Loss = {}\".format(loss_value))\n",
    "                print(\"\\tRelative L2 error (test) = {}\".format(float(rel_l2*100)))\n",
    "\n",
    "        return Training_Losses, rel_error\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "        # Command-line configuration. The bare '-f' option presumably absorbs the\n",
    "        # '-f <connection file>' argument the Jupyter kernel is started with, so\n",
    "        # parse_args() also works inside the notebook.\n",
    "        parser = argparse.ArgumentParser()\n",
    "        parser.add_argument('--nodes', type=int, default = 128, help='The number of nodes per hidden layer in the neural network')\n",
    "        parser.add_argument('--layers', type=int, default = 8, help='The number of hidden layers in the neural network')\n",
    "        parser.add_argument('--N0', type=int, default = 100, help='The number of points to use on the initial condition')\n",
    "        parser.add_argument('--Nb', type=int, default = 100, help='The number of points to use on the boundary condition')\n",
    "        parser.add_argument('--Nf', type=int, default = 1000, help='The number of collocation points to use')\n",
    "        parser.add_argument('--epochs', type=int, default = 200000, help='The number of epochs to train the neural network')\n",
    "        # Restrict to the samplers Train() implements so typos fail here, not mid-training.\n",
    "        parser.add_argument('--method', type=str, default='rad',\n",
    "                            choices=['fixed', 'random-r', 'rad', 'r3', 'l_inf', 'las'],\n",
    "                            help='Sampling method')\n",
    "        parser.add_argument('--model-name', type=str, default='PINN_model', help='File name to save the model')\n",
    "        parser.add_argument('--display-freq', type=int, default=1000, help='How often to display loss information')\n",
    "        parser.add_argument('--repeats', type=int, default=5, help='The number of independent training runs')\n",
    "        parser.add_argument('-f')\n",
    "        args = parser.parse_args()\n",
    "\n",
    "        # Reference solution, used for the relative L2 test error.\n",
    "        data = io.loadmat('../data/burgers_shock.mat')\n",
    "        t = torch.tensor(data['t'], dtype = torch.float32)\n",
    "        x = torch.tensor(data['x'], dtype = torch.float32)\n",
    "        Exact_u = torch.tensor(data['usol'], dtype = torch.float32)\n",
    "\n",
    "        # Test inputs: every (x, t) pair of the reference grid, x varying fastest.\n",
    "        X, T = np.meshgrid(x,t)\n",
    "        X_star = torch.tensor(np.hstack((X.flatten()[:,None], T.flatten()[:,None])), dtype = torch.float32)\n",
    "        u_star = torch.flatten(torch.transpose(Exact_u,0,1))\n",
    "\n",
    "        os.makedirs(\"../models/\", exist_ok=True)\n",
    "\n",
    "        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "        NHiddenLayers = args.layers\n",
    "\n",
    "        boundaries = [-1, 1]\n",
    "        t_domain = [0., 1.]\n",
    "\n",
    "        # Input (x, t) -> NHiddenLayers hidden layers of args.nodes units -> scalar output.\n",
    "        Layers = [2] + [args.nodes]*NHiddenLayers + [1]\n",
    "        Activation = nn.Tanh()\n",
    "\n",
    "        k = 1\n",
    "        c = 1\n",
    "\n",
    "        for i in range(args.repeats):\n",
    "            # Suffix shared by the model file name and both saved loss histories.\n",
    "            run_tag = args.method+'_'+str(args.layers)+'_'+str(args.Nf)+'_'+str(i)\n",
    "            pinn = PINN(  k = k,\n",
    "                          c = c,\n",
    "                          t= t,\n",
    "                          X_star = X_star,\n",
    "                          u_star = u_star,\n",
    "                          exact_u = Exact_u,\n",
    "                          space_domain = boundaries,\n",
    "                          time_domain = t_domain,\n",
    "                          Layers = Layers,\n",
    "                          N0 = args.N0,\n",
    "                          Nb = args.Nb,\n",
    "                          Nf = args.Nf,\n",
    "                          Activation = Activation,\n",
    "                          device = device,\n",
    "                          model_name = \"../models/\" + args.model_name + \".model_\"+run_tag,\n",
    "                          display_freq = args.display_freq, samp = args.method )\n",
    "\n",
    "            Losses_train, Losses_rel_l2 = pinn.Train(args.epochs, weights = (100, 1, 1)) # initial, boundary, residual\n",
    "\n",
    "            torch.save(Losses_train, \"../models/\" + args.model_name + \".loss_\"+run_tag)\n",
    "            torch.save(Losses_rel_l2, \"../models/\" + args.model_name + \".rel_l2_\"+run_tag)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
