{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "528d73c4-3f68-4f33-b856-b058bb4542cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1105682/2056753709.py:248: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  self.fixed_uniform = torch.tensor(x_t_init, dtype = torch.float32, device=self.device, requires_grad=True)\n",
      "/tmp/ipykernel_1105682/2056753709.py:212: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  self.XTGrid = torch.tensor(copy.deepcopy(fixed_uniform), dtype=torch.float32, requires_grad=True).to(self.device)\n",
      "/tmp/ipykernel_1105682/2056753709.py:151: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  self.XTGrid = torch.tensor(copy.deepcopy(fixed_uniform), dtype=torch.float32, requires_grad=True).to(self.device)\n",
      "/tmp/ipykernel_1105682/2056753709.py:433: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  XTGrid = torch.tensor(XTGrid, dtype=torch.float32, requires_grad=True).to(self.device)\n",
      "/home/user/anaconda3/envs/dj24/lib/python3.12/site-packages/torch/autograd/graph.py:744: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at ../aten/src/ATen/cuda/CublasHandlePool.cpp:135.)\n",
      "  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n",
      "/tmp/ipykernel_1105682/2056753709.py:438: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  XTGrid = torch.tensor(XTGrid, dtype=torch.float32, requires_grad=True).to(self.device)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration Number = 10\n",
      "\tIC Loss = 1.174581527709961\n",
      "\tBC Loss = 15.379281044006348\n",
      "\tPhysics Loss = 12282.38671875\n",
      "\tTraining Loss = 12298.9404296875\n",
      "\tRelative L2 error (test) = 236.69915199279785\n",
      "Iteration Number = 20\n",
      "\tIC Loss = 1.058443546295166\n",
      "\tBC Loss = 11.960921287536621\n",
      "\tPhysics Loss = 12787.9091796875\n",
      "\tTraining Loss = 12800.9287109375\n",
      "\tRelative L2 error (test) = 200.835919380188\n",
      "Iteration Number = 30\n",
      "\tIC Loss = 0.942874014377594\n",
      "\tBC Loss = 9.681939125061035\n",
      "\tPhysics Loss = 13072.5302734375\n",
      "\tTraining Loss = 13083.1552734375\n",
      "\tRelative L2 error (test) = 180.66638708114624\n",
      "Iteration Number = 40\n",
      "\tIC Loss = 1.0546263456344604\n",
      "\tBC Loss = 15.70188045501709\n",
      "\tPhysics Loss = 13749.173828125\n",
      "\tTraining Loss = 13765.9306640625\n",
      "\tRelative L2 error (test) = 191.39121770858765\n",
      "Iteration Number = 50\n",
      "\tIC Loss = 2.2192153930664062\n",
      "\tBC Loss = 22.06633758544922\n",
      "\tPhysics Loss = 14988.248046875\n",
      "\tTraining Loss = 15012.533203125\n",
      "\tRelative L2 error (test) = 210.18071174621582\n",
      "Iteration Number = 60\n",
      "\tIC Loss = 1.5488046407699585\n",
      "\tBC Loss = 18.591524124145508\n",
      "\tPhysics Loss = 14772.83984375\n",
      "\tTraining Loss = 14792.98046875\n",
      "\tRelative L2 error (test) = 225.21896362304688\n",
      "Iteration Number = 70\n",
      "\tIC Loss = 1.7062350511550903\n",
      "\tBC Loss = 21.475418090820312\n",
      "\tPhysics Loss = 15418.8076171875\n",
      "\tTraining Loss = 15441.9892578125\n",
      "\tRelative L2 error (test) = 239.186692237854\n",
      "Iteration Number = 80\n",
      "\tIC Loss = 2.836808681488037\n",
      "\tBC Loss = 34.97604751586914\n",
      "\tPhysics Loss = 14997.609375\n",
      "\tTraining Loss = 15035.421875\n",
      "\tRelative L2 error (test) = 301.53470039367676\n",
      "Iteration Number = 90\n",
      "\tIC Loss = 1.2185337543487549\n",
      "\tBC Loss = 13.089399337768555\n",
      "\tPhysics Loss = 15756.6708984375\n",
      "\tTraining Loss = 15770.978515625\n",
      "\tRelative L2 error (test) = 154.4187307357788\n",
      "Iteration Number = 100\n",
      "\tIC Loss = 1.1450111865997314\n",
      "\tBC Loss = 11.039051055908203\n",
      "\tPhysics Loss = 16087.88671875\n",
      "\tTraining Loss = 16100.0703125\n",
      "\tRelative L2 error (test) = 154.53120470046997\n",
      "Iteration Number = 110\n",
      "\tIC Loss = 1.1852152347564697\n",
      "\tBC Loss = 12.71418571472168\n",
      "\tPhysics Loss = 16696.517578125\n",
      "\tTraining Loss = 16710.416015625\n",
      "\tRelative L2 error (test) = 166.72742366790771\n",
      "Iteration Number = 120\n",
      "\tIC Loss = 1.6616528034210205\n",
      "\tBC Loss = 23.286603927612305\n",
      "\tPhysics Loss = 16778.494140625\n",
      "\tTraining Loss = 16803.443359375\n",
      "\tRelative L2 error (test) = 238.69125843048096\n",
      "Iteration Number = 130\n",
      "\tIC Loss = 1.2238078117370605\n",
      "\tBC Loss = 13.07340145111084\n",
      "\tPhysics Loss = 16785.845703125\n",
      "\tTraining Loss = 16800.142578125\n",
      "\tRelative L2 error (test) = 173.22516441345215\n",
      "Iteration Number = 140\n",
      "\tIC Loss = 1.7462430000305176\n",
      "\tBC Loss = 20.12973403930664\n",
      "\tPhysics Loss = 16599.79296875\n",
      "\tTraining Loss = 16621.66796875\n",
      "\tRelative L2 error (test) = 200.61564445495605\n",
      "Iteration Number = 150\n",
      "\tIC Loss = 1.6447694301605225\n",
      "\tBC Loss = 18.805086135864258\n",
      "\tPhysics Loss = 16987.109375\n",
      "\tTraining Loss = 17007.55859375\n",
      "\tRelative L2 error (test) = 195.83898782730103\n",
      "Iteration Number = 160\n",
      "\tIC Loss = 1.8977382183074951\n",
      "\tBC Loss = 20.807273864746094\n",
      "\tPhysics Loss = 17113.921875\n",
      "\tTraining Loss = 17136.626953125\n",
      "\tRelative L2 error (test) = 201.42066478729248\n",
      "Iteration Number = 170\n",
      "\tIC Loss = 1.5159810781478882\n",
      "\tBC Loss = 14.032976150512695\n",
      "\tPhysics Loss = 17291.720703125\n",
      "\tTraining Loss = 17307.26953125\n",
      "\tRelative L2 error (test) = 171.64252996444702\n",
      "Iteration Number = 180\n",
      "\tIC Loss = 1.2844312191009521\n",
      "\tBC Loss = 10.539374351501465\n",
      "\tPhysics Loss = 18081.6015625\n",
      "\tTraining Loss = 18093.42578125\n",
      "\tRelative L2 error (test) = 141.01240634918213\n",
      "Iteration Number = 190\n",
      "\tIC Loss = 1.5066149234771729\n",
      "\tBC Loss = 19.291078567504883\n",
      "\tPhysics Loss = 18322.041015625\n",
      "\tTraining Loss = 18342.837890625\n",
      "\tRelative L2 error (test) = 223.54578971862793\n",
      "Iteration Number = 200\n",
      "\tIC Loss = 1.6509650945663452\n",
      "\tBC Loss = 21.902647018432617\n",
      "\tPhysics Loss = 18171.98046875\n",
      "\tTraining Loss = 18195.533203125\n",
      "\tRelative L2 error (test) = 220.20807266235352\n",
      "Iteration Number = 210\n",
      "\tIC Loss = 1.2385976314544678\n",
      "\tBC Loss = 16.493349075317383\n",
      "\tPhysics Loss = 18327.564453125\n",
      "\tTraining Loss = 18345.296875\n",
      "\tRelative L2 error (test) = 202.6618480682373\n",
      "Iteration Number = 220\n",
      "\tIC Loss = 1.7080349922180176\n",
      "\tBC Loss = 21.746122360229492\n",
      "\tPhysics Loss = 18276.8984375\n",
      "\tTraining Loss = 18300.353515625\n",
      "\tRelative L2 error (test) = 224.2788553237915\n",
      "Iteration Number = 230\n",
      "\tIC Loss = 2.352725028991699\n",
      "\tBC Loss = 28.075027465820312\n",
      "\tPhysics Loss = 18365.513671875\n",
      "\tTraining Loss = 18395.94140625\n",
      "\tRelative L2 error (test) = 255.6931495666504\n",
      "Iteration Number = 240\n",
      "\tIC Loss = 1.9951984882354736\n",
      "\tBC Loss = 18.256086349487305\n",
      "\tPhysics Loss = 18543.5625\n",
      "\tTraining Loss = 18563.814453125\n",
      "\tRelative L2 error (test) = 194.9471116065979\n",
      "Iteration Number = 250\n",
      "\tIC Loss = 2.3120884895324707\n",
      "\tBC Loss = 27.018007278442383\n",
      "\tPhysics Loss = 18223.92578125\n",
      "\tTraining Loss = 18253.255859375\n",
      "\tRelative L2 error (test) = 267.08762645721436\n",
      "Iteration Number = 260\n",
      "\tIC Loss = 3.952500343322754\n",
      "\tBC Loss = 46.706356048583984\n",
      "\tPhysics Loss = 17890.376953125\n",
      "\tTraining Loss = 17941.03515625\n",
      "\tRelative L2 error (test) = 309.2762231826782\n",
      "Iteration Number = 270\n",
      "\tIC Loss = 4.002389430999756\n",
      "\tBC Loss = 50.23784637451172\n",
      "\tPhysics Loss = 17502.09765625\n",
      "\tTraining Loss = 17556.337890625\n",
      "\tRelative L2 error (test) = 338.4568452835083\n",
      "Iteration Number = 280\n",
      "\tIC Loss = 3.7709929943084717\n",
      "\tBC Loss = 48.80746841430664\n",
      "\tPhysics Loss = 18173.212890625\n",
      "\tTraining Loss = 18225.791015625\n",
      "\tRelative L2 error (test) = 332.76867866516113\n",
      "Iteration Number = 290\n",
      "\tIC Loss = 3.376677989959717\n",
      "\tBC Loss = 44.24705505371094\n",
      "\tPhysics Loss = 18341.12890625\n",
      "\tTraining Loss = 18388.751953125\n",
      "\tRelative L2 error (test) = 333.278751373291\n",
      "Iteration Number = 300\n",
      "\tIC Loss = 2.8054280281066895\n",
      "\tBC Loss = 32.29450607299805\n",
      "\tPhysics Loss = 18333.984375\n",
      "\tTraining Loss = 18369.083984375\n",
      "\tRelative L2 error (test) = 272.35052585601807\n",
      "Iteration Number = 310\n",
      "\tIC Loss = 2.576961040496826\n",
      "\tBC Loss = 30.281827926635742\n",
      "\tPhysics Loss = 18512.681640625\n",
      "\tTraining Loss = 18545.541015625\n",
      "\tRelative L2 error (test) = 282.5489044189453\n",
      "Iteration Number = 320\n",
      "\tIC Loss = 4.792396545410156\n",
      "\tBC Loss = 53.22669982910156\n",
      "\tPhysics Loss = 18202.560546875\n",
      "\tTraining Loss = 18260.580078125\n",
      "\tRelative L2 error (test) = 335.77868938446045\n",
      "Iteration Number = 330\n",
      "\tIC Loss = 4.664791584014893\n",
      "\tBC Loss = 58.72222900390625\n",
      "\tPhysics Loss = 18208.12109375\n",
      "\tTraining Loss = 18271.5078125\n",
      "\tRelative L2 error (test) = 372.3904609680176\n",
      "Iteration Number = 340\n",
      "\tIC Loss = 3.989516019821167\n",
      "\tBC Loss = 42.45028305053711\n",
      "\tPhysics Loss = 18808.064453125\n",
      "\tTraining Loss = 18854.50390625\n",
      "\tRelative L2 error (test) = 297.38874435424805\n",
      "Iteration Number = 350\n",
      "\tIC Loss = 2.9312100410461426\n",
      "\tBC Loss = 38.40785217285156\n",
      "\tPhysics Loss = 18730.0703125\n",
      "\tTraining Loss = 18771.41015625\n",
      "\tRelative L2 error (test) = 321.2646245956421\n",
      "Iteration Number = 360\n",
      "\tIC Loss = 2.9193596839904785\n",
      "\tBC Loss = 37.91427993774414\n",
      "\tPhysics Loss = 19061.876953125\n",
      "\tTraining Loss = 19102.7109375\n",
      "\tRelative L2 error (test) = 302.0089864730835\n",
      "Iteration Number = 370\n",
      "\tIC Loss = 4.752924919128418\n",
      "\tBC Loss = 63.670265197753906\n",
      "\tPhysics Loss = 18213.216796875\n",
      "\tTraining Loss = 18281.640625\n",
      "\tRelative L2 error (test) = 422.98145294189453\n",
      "Iteration Number = 380\n",
      "\tIC Loss = 4.939449310302734\n",
      "\tBC Loss = 68.78002166748047\n",
      "\tPhysics Loss = 15937.8984375\n",
      "\tTraining Loss = 16011.6181640625\n",
      "\tRelative L2 error (test) = 396.2805986404419\n",
      "Iteration Number = 390\n",
      "\tIC Loss = 4.548062324523926\n",
      "\tBC Loss = 54.501834869384766\n",
      "\tPhysics Loss = 13822.5361328125\n",
      "\tTraining Loss = 13881.5859375\n",
      "\tRelative L2 error (test) = 325.6779909133911\n",
      "Iteration Number = 400\n",
      "\tIC Loss = 4.193028926849365\n",
      "\tBC Loss = 53.03020477294922\n",
      "\tPhysics Loss = 10060.7998046875\n",
      "\tTraining Loss = 10118.0234375\n",
      "\tRelative L2 error (test) = 327.9397964477539\n",
      "Iteration Number = 410\n",
      "\tIC Loss = 3.874753713607788\n",
      "\tBC Loss = 54.7193603515625\n",
      "\tPhysics Loss = 7366.96142578125\n",
      "\tTraining Loss = 7425.5556640625\n",
      "\tRelative L2 error (test) = 379.8252820968628\n",
      "Iteration Number = 420\n",
      "\tIC Loss = 4.446844577789307\n",
      "\tBC Loss = 53.35930633544922\n",
      "\tPhysics Loss = 6418.4736328125\n",
      "\tTraining Loss = 6476.27978515625\n",
      "\tRelative L2 error (test) = 298.514986038208\n",
      "Iteration Number = 430\n",
      "\tIC Loss = 2.9877429008483887\n",
      "\tBC Loss = 39.902034759521484\n",
      "\tPhysics Loss = 3136.058837890625\n",
      "\tTraining Loss = 3178.94873046875\n",
      "\tRelative L2 error (test) = 323.6964464187622\n",
      "Iteration Number = 440\n",
      "\tIC Loss = 3.3687870502471924\n",
      "\tBC Loss = 47.67599868774414\n",
      "\tPhysics Loss = 2792.2060546875\n",
      "\tTraining Loss = 2843.250732421875\n",
      "\tRelative L2 error (test) = 304.43694591522217\n",
      "Iteration Number = 450\n",
      "\tIC Loss = 1.932906150817871\n",
      "\tBC Loss = 29.29632568359375\n",
      "\tPhysics Loss = 2126.42919921875\n",
      "\tTraining Loss = 2157.658447265625\n",
      "\tRelative L2 error (test) = 330.8685302734375\n",
      "Iteration Number = 460\n",
      "\tIC Loss = 1.4891437292099\n",
      "\tBC Loss = 21.87411880493164\n",
      "\tPhysics Loss = 1532.1436767578125\n",
      "\tTraining Loss = 1555.5069580078125\n",
      "\tRelative L2 error (test) = 213.64235877990723\n",
      "Iteration Number = 470\n",
      "\tIC Loss = 1.141794204711914\n",
      "\tBC Loss = 15.450740814208984\n",
      "\tPhysics Loss = 1483.174560546875\n",
      "\tTraining Loss = 1499.76708984375\n",
      "\tRelative L2 error (test) = 195.14213800430298\n",
      "Iteration Number = 480\n",
      "\tIC Loss = 1.0062273740768433\n",
      "\tBC Loss = 13.871877670288086\n",
      "\tPhysics Loss = 1352.467529296875\n",
      "\tTraining Loss = 1367.3455810546875\n",
      "\tRelative L2 error (test) = 237.4415159225464\n",
      "Iteration Number = 490\n",
      "\tIC Loss = 0.7367243766784668\n",
      "\tBC Loss = 11.522472381591797\n",
      "\tPhysics Loss = 956.2362060546875\n",
      "\tTraining Loss = 968.4954223632812\n",
      "\tRelative L2 error (test) = 176.07791423797607\n",
      "Iteration Number = 500\n",
      "\tIC Loss = 0.4941962957382202\n",
      "\tBC Loss = 6.993548393249512\n",
      "\tPhysics Loss = 846.8984375\n",
      "\tTraining Loss = 854.3861694335938\n",
      "\tRelative L2 error (test) = 162.88093328475952\n",
      "Iteration Number = 510\n",
      "\tIC Loss = 0.7464041709899902\n",
      "\tBC Loss = 9.832601547241211\n",
      "\tPhysics Loss = 807.829833984375\n",
      "\tTraining Loss = 818.4088134765625\n",
      "\tRelative L2 error (test) = 176.6731858253479\n",
      "Iteration Number = 520\n",
      "\tIC Loss = 0.8603142499923706\n",
      "\tBC Loss = 11.125469207763672\n",
      "\tPhysics Loss = 590.8509521484375\n",
      "\tTraining Loss = 602.8367309570312\n",
      "\tRelative L2 error (test) = 212.8037691116333\n",
      "Iteration Number = 530\n",
      "\tIC Loss = 0.8301714062690735\n",
      "\tBC Loss = 8.236255645751953\n",
      "\tPhysics Loss = 498.99041748046875\n",
      "\tTraining Loss = 508.0568542480469\n",
      "\tRelative L2 error (test) = 173.22134971618652\n",
      "Iteration Number = 540\n",
      "\tIC Loss = 0.7776778936386108\n",
      "\tBC Loss = 8.132803916931152\n",
      "\tPhysics Loss = 429.4920349121094\n",
      "\tTraining Loss = 438.40252685546875\n",
      "\tRelative L2 error (test) = 164.9608016014099\n",
      "Iteration Number = 550\n",
      "\tIC Loss = 0.6841030716896057\n",
      "\tBC Loss = 8.174239158630371\n",
      "\tPhysics Loss = 498.4867858886719\n",
      "\tTraining Loss = 507.3451232910156\n",
      "\tRelative L2 error (test) = 152.17761993408203\n",
      "Iteration Number = 560\n",
      "\tIC Loss = 0.7410706281661987\n",
      "\tBC Loss = 8.952467918395996\n",
      "\tPhysics Loss = 437.5269775390625\n",
      "\tTraining Loss = 447.22052001953125\n",
      "\tRelative L2 error (test) = 174.39130544662476\n",
      "Iteration Number = 570\n",
      "\tIC Loss = 0.8776964545249939\n",
      "\tBC Loss = 10.808075904846191\n",
      "\tPhysics Loss = 319.70880126953125\n",
      "\tTraining Loss = 331.3945617675781\n",
      "\tRelative L2 error (test) = 133.42759609222412\n",
      "Iteration Number = 580\n",
      "\tIC Loss = 1.0315769910812378\n",
      "\tBC Loss = 12.783060073852539\n",
      "\tPhysics Loss = 324.30438232421875\n",
      "\tTraining Loss = 338.1190185546875\n",
      "\tRelative L2 error (test) = 194.59266662597656\n",
      "Iteration Number = 590\n",
      "\tIC Loss = 0.9660419225692749\n",
      "\tBC Loss = 12.359355926513672\n",
      "\tPhysics Loss = 324.4887390136719\n",
      "\tTraining Loss = 337.81414794921875\n",
      "\tRelative L2 error (test) = 133.45619440078735\n",
      "Iteration Number = 600\n",
      "\tIC Loss = 0.8900982737541199\n",
      "\tBC Loss = 12.06747055053711\n",
      "\tPhysics Loss = 383.29119873046875\n",
      "\tTraining Loss = 396.248779296875\n",
      "\tRelative L2 error (test) = 133.77116918563843\n",
      "Iteration Number = 610\n",
      "\tIC Loss = 1.222153902053833\n",
      "\tBC Loss = 17.58769416809082\n",
      "\tPhysics Loss = 396.88946533203125\n",
      "\tTraining Loss = 415.6993103027344\n",
      "\tRelative L2 error (test) = 156.39183521270752\n",
      "Iteration Number = 620\n",
      "\tIC Loss = 1.5117158889770508\n",
      "\tBC Loss = 22.755189895629883\n",
      "\tPhysics Loss = 255.29757690429688\n",
      "\tTraining Loss = 279.5644836425781\n",
      "\tRelative L2 error (test) = 212.10527420043945\n",
      "Iteration Number = 630\n",
      "\tIC Loss = 1.8516243696212769\n",
      "\tBC Loss = 26.79640007019043\n",
      "\tPhysics Loss = 267.84881591796875\n",
      "\tTraining Loss = 296.496826171875\n",
      "\tRelative L2 error (test) = 214.93024826049805\n",
      "Iteration Number = 640\n",
      "\tIC Loss = 1.249794363975525\n",
      "\tBC Loss = 18.535736083984375\n",
      "\tPhysics Loss = 362.345703125\n",
      "\tTraining Loss = 382.1312255859375\n",
      "\tRelative L2 error (test) = 194.59906816482544\n",
      "Iteration Number = 650\n",
      "\tIC Loss = 0.9734936356544495\n",
      "\tBC Loss = 14.121131896972656\n",
      "\tPhysics Loss = 360.62591552734375\n",
      "\tTraining Loss = 375.7205505371094\n",
      "\tRelative L2 error (test) = 181.94643259048462\n",
      "Iteration Number = 660\n",
      "\tIC Loss = 0.6482346057891846\n",
      "\tBC Loss = 6.682154655456543\n",
      "\tPhysics Loss = 312.9894714355469\n",
      "\tTraining Loss = 320.3198547363281\n",
      "\tRelative L2 error (test) = 125.51735639572144\n",
      "Iteration Number = 670\n",
      "\tIC Loss = 0.3789816200733185\n",
      "\tBC Loss = 4.718526363372803\n",
      "\tPhysics Loss = 313.6690673828125\n",
      "\tTraining Loss = 318.7665710449219\n",
      "\tRelative L2 error (test) = 118.66191625595093\n",
      "Iteration Number = 680\n",
      "\tIC Loss = 0.38609737157821655\n",
      "\tBC Loss = 4.924472808837891\n",
      "\tPhysics Loss = 241.8020782470703\n",
      "\tTraining Loss = 247.11265563964844\n",
      "\tRelative L2 error (test) = 129.4467806816101\n",
      "Iteration Number = 690\n",
      "\tIC Loss = 0.23424050211906433\n",
      "\tBC Loss = 2.5149364471435547\n",
      "\tPhysics Loss = 291.3558349609375\n",
      "\tTraining Loss = 294.1050109863281\n",
      "\tRelative L2 error (test) = 95.05401849746704\n",
      "Iteration Number = 700\n",
      "\tIC Loss = 0.27574223279953003\n",
      "\tBC Loss = 3.915971517562866\n",
      "\tPhysics Loss = 223.32244873046875\n",
      "\tTraining Loss = 227.51416015625\n",
      "\tRelative L2 error (test) = 114.86377716064453\n",
      "Iteration Number = 710\n",
      "\tIC Loss = 0.31067198514938354\n",
      "\tBC Loss = 3.0264892578125\n",
      "\tPhysics Loss = 400.8623962402344\n",
      "\tTraining Loss = 404.1995544433594\n",
      "\tRelative L2 error (test) = 112.08070516586304\n",
      "Iteration Number = 720\n",
      "\tIC Loss = 0.19571365416049957\n",
      "\tBC Loss = 6.031892776489258\n",
      "\tPhysics Loss = 463.2640075683594\n",
      "\tTraining Loss = 469.4916076660156\n",
      "\tRelative L2 error (test) = 113.86771202087402\n",
      "Iteration Number = 730\n",
      "\tIC Loss = 0.4943198263645172\n",
      "\tBC Loss = 4.575027942657471\n",
      "\tPhysics Loss = 650.5740966796875\n",
      "\tTraining Loss = 655.6434326171875\n",
      "\tRelative L2 error (test) = 101.45999193191528\n",
      "Iteration Number = 740\n",
      "\tIC Loss = 0.20179450511932373\n",
      "\tBC Loss = 7.037574291229248\n",
      "\tPhysics Loss = 586.171142578125\n",
      "\tTraining Loss = 593.4105224609375\n",
      "\tRelative L2 error (test) = 113.37074041366577\n",
      "Iteration Number = 750\n",
      "\tIC Loss = 0.43402037024497986\n",
      "\tBC Loss = 3.9606432914733887\n",
      "\tPhysics Loss = 694.858642578125\n",
      "\tTraining Loss = 699.2532958984375\n",
      "\tRelative L2 error (test) = 113.08263540267944\n",
      "Iteration Number = 760\n",
      "\tIC Loss = 0.1622755378484726\n",
      "\tBC Loss = 4.439528942108154\n",
      "\tPhysics Loss = 300.89263916015625\n",
      "\tTraining Loss = 305.49444580078125\n",
      "\tRelative L2 error (test) = 115.68890810012817\n",
      "Iteration Number = 770\n",
      "\tIC Loss = 0.28930461406707764\n",
      "\tBC Loss = 5.857231616973877\n",
      "\tPhysics Loss = 251.24923706054688\n",
      "\tTraining Loss = 257.3957824707031\n",
      "\tRelative L2 error (test) = 124.4570255279541\n",
      "Iteration Number = 780\n",
      "\tIC Loss = 0.3148118257522583\n",
      "\tBC Loss = 5.838131904602051\n",
      "\tPhysics Loss = 224.37356567382812\n",
      "\tTraining Loss = 230.52650451660156\n",
      "\tRelative L2 error (test) = 124.63828325271606\n",
      "Iteration Number = 790\n",
      "\tIC Loss = 0.6212544441223145\n",
      "\tBC Loss = 8.626331329345703\n",
      "\tPhysics Loss = 288.3166809082031\n",
      "\tTraining Loss = 297.56427001953125\n",
      "\tRelative L2 error (test) = 128.5409688949585\n",
      "Iteration Number = 800\n",
      "\tIC Loss = 0.29574960470199585\n",
      "\tBC Loss = 7.645435810089111\n",
      "\tPhysics Loss = 197.9796600341797\n",
      "\tTraining Loss = 205.9208526611328\n",
      "\tRelative L2 error (test) = 120.91372013092041\n",
      "Iteration Number = 810\n",
      "\tIC Loss = 0.3277886211872101\n",
      "\tBC Loss = 7.716475486755371\n",
      "\tPhysics Loss = 235.031982421875\n",
      "\tTraining Loss = 243.0762481689453\n",
      "\tRelative L2 error (test) = 115.54346084594727\n",
      "Iteration Number = 820\n",
      "\tIC Loss = 0.34655359387397766\n",
      "\tBC Loss = 7.08793306350708\n",
      "\tPhysics Loss = 243.9864044189453\n",
      "\tTraining Loss = 251.4208984375\n",
      "\tRelative L2 error (test) = 114.05051946640015\n",
      "Iteration Number = 830\n",
      "\tIC Loss = 0.24955877661705017\n",
      "\tBC Loss = 6.193397045135498\n",
      "\tPhysics Loss = 238.2931671142578\n",
      "\tTraining Loss = 244.7361297607422\n",
      "\tRelative L2 error (test) = 107.34028816223145\n",
      "Iteration Number = 840\n",
      "\tIC Loss = 0.3220839500427246\n",
      "\tBC Loss = 7.036437034606934\n",
      "\tPhysics Loss = 270.37091064453125\n",
      "\tTraining Loss = 277.72943115234375\n",
      "\tRelative L2 error (test) = 116.98201894760132\n",
      "Iteration Number = 850\n",
      "\tIC Loss = 0.3509960174560547\n",
      "\tBC Loss = 8.615981101989746\n",
      "\tPhysics Loss = 299.9997863769531\n",
      "\tTraining Loss = 308.9667663574219\n",
      "\tRelative L2 error (test) = 132.8543186187744\n",
      "Iteration Number = 860\n",
      "\tIC Loss = 0.32937487959861755\n",
      "\tBC Loss = 8.191402435302734\n",
      "\tPhysics Loss = 236.29978942871094\n",
      "\tTraining Loss = 244.82057189941406\n",
      "\tRelative L2 error (test) = 129.42934036254883\n",
      "Iteration Number = 870\n",
      "\tIC Loss = 0.34926268458366394\n",
      "\tBC Loss = 5.734352111816406\n",
      "\tPhysics Loss = 292.389892578125\n",
      "\tTraining Loss = 298.4735107421875\n",
      "\tRelative L2 error (test) = 120.49399614334106\n",
      "Iteration Number = 880\n",
      "\tIC Loss = 1.232295036315918\n",
      "\tBC Loss = 13.708137512207031\n",
      "\tPhysics Loss = 2383.271728515625\n",
      "\tTraining Loss = 2398.212158203125\n",
      "\tRelative L2 error (test) = 141.20237827301025\n",
      "Iteration Number = 890\n",
      "\tIC Loss = 0.474422812461853\n",
      "\tBC Loss = 4.592380046844482\n",
      "\tPhysics Loss = 996.7241821289062\n",
      "\tTraining Loss = 1001.791015625\n",
      "\tRelative L2 error (test) = 95.05355954170227\n",
      "Iteration Number = 900\n",
      "\tIC Loss = 0.14149793982505798\n",
      "\tBC Loss = 2.4800376892089844\n",
      "\tPhysics Loss = 324.4499206542969\n",
      "\tTraining Loss = 327.0714416503906\n",
      "\tRelative L2 error (test) = 115.43859243392944\n",
      "Iteration Number = 910\n",
      "\tIC Loss = 0.4174146056175232\n",
      "\tBC Loss = 8.240485191345215\n",
      "\tPhysics Loss = 528.62939453125\n",
      "\tTraining Loss = 537.2872924804688\n",
      "\tRelative L2 error (test) = 153.23115587234497\n",
      "Iteration Number = 920\n",
      "\tIC Loss = 0.1685989797115326\n",
      "\tBC Loss = 6.650657653808594\n",
      "\tPhysics Loss = 343.9195251464844\n",
      "\tTraining Loss = 350.73876953125\n",
      "\tRelative L2 error (test) = 134.34929847717285\n",
      "Iteration Number = 930\n",
      "\tIC Loss = 0.381915420293808\n",
      "\tBC Loss = 6.52476167678833\n",
      "\tPhysics Loss = 262.10040283203125\n",
      "\tTraining Loss = 269.007080078125\n",
      "\tRelative L2 error (test) = 127.9593825340271\n",
      "Iteration Number = 940\n",
      "\tIC Loss = 0.22590553760528564\n",
      "\tBC Loss = 4.692750453948975\n",
      "\tPhysics Loss = 272.36328125\n",
      "\tTraining Loss = 277.2819519042969\n",
      "\tRelative L2 error (test) = 126.78678035736084\n",
      "Iteration Number = 950\n",
      "\tIC Loss = 0.0672166720032692\n",
      "\tBC Loss = 2.5126073360443115\n",
      "\tPhysics Loss = 305.53839111328125\n",
      "\tTraining Loss = 308.11822509765625\n",
      "\tRelative L2 error (test) = 100.9140133857727\n",
      "Iteration Number = 960\n",
      "\tIC Loss = 0.20517030358314514\n",
      "\tBC Loss = 2.4422082901000977\n",
      "\tPhysics Loss = 267.9516296386719\n",
      "\tTraining Loss = 270.5989990234375\n",
      "\tRelative L2 error (test) = 118.49228143692017\n",
      "Iteration Number = 970\n",
      "\tIC Loss = 0.146121546626091\n",
      "\tBC Loss = 1.9093486070632935\n",
      "\tPhysics Loss = 258.7697448730469\n",
      "\tTraining Loss = 260.8252258300781\n",
      "\tRelative L2 error (test) = 98.20417165756226\n",
      "Iteration Number = 980\n",
      "\tIC Loss = 0.3760570287704468\n",
      "\tBC Loss = 3.848874807357788\n",
      "\tPhysics Loss = 266.6337585449219\n",
      "\tTraining Loss = 270.85870361328125\n",
      "\tRelative L2 error (test) = 120.97076177597046\n",
      "Iteration Number = 990\n",
      "\tIC Loss = 0.606253981590271\n",
      "\tBC Loss = 7.1867523193359375\n",
      "\tPhysics Loss = 231.6094512939453\n",
      "\tTraining Loss = 239.40245056152344\n",
      "\tRelative L2 error (test) = 114.54488039016724\n",
      "Iteration Number = 1000\n",
      "\tIC Loss = 0.9551272392272949\n",
      "\tBC Loss = 10.538870811462402\n",
      "\tPhysics Loss = 247.11630249023438\n",
      "\tTraining Loss = 258.61029052734375\n",
      "\tRelative L2 error (test) = 141.20724201202393\n",
      "Iteration Number = 1010\n",
      "\tIC Loss = 0.9949576258659363\n",
      "\tBC Loss = 10.797430992126465\n",
      "\tPhysics Loss = 235.72557067871094\n",
      "\tTraining Loss = 247.51795959472656\n",
      "\tRelative L2 error (test) = 149.32442903518677\n",
      "Iteration Number = 1020\n",
      "\tIC Loss = 1.2471798658370972\n",
      "\tBC Loss = 15.689887046813965\n",
      "\tPhysics Loss = 204.08477783203125\n",
      "\tTraining Loss = 221.0218505859375\n",
      "\tRelative L2 error (test) = 127.76775360107422\n",
      "Iteration Number = 1030\n",
      "\tIC Loss = 1.1254587173461914\n",
      "\tBC Loss = 11.838226318359375\n",
      "\tPhysics Loss = 187.11032104492188\n",
      "\tTraining Loss = 200.07400512695312\n",
      "\tRelative L2 error (test) = 153.15669775009155\n",
      "Iteration Number = 1040\n",
      "\tIC Loss = 1.4484914541244507\n",
      "\tBC Loss = 17.75830841064453\n",
      "\tPhysics Loss = 181.35943603515625\n",
      "\tTraining Loss = 200.5662384033203\n",
      "\tRelative L2 error (test) = 166.8249487876892\n",
      "Iteration Number = 1050\n",
      "\tIC Loss = 1.220634937286377\n",
      "\tBC Loss = 13.969762802124023\n",
      "\tPhysics Loss = 201.60154724121094\n",
      "\tTraining Loss = 216.7919464111328\n",
      "\tRelative L2 error (test) = 133.3337903022766\n",
      "Iteration Number = 1060\n",
      "\tIC Loss = 1.3171336650848389\n",
      "\tBC Loss = 13.429574012756348\n",
      "\tPhysics Loss = 908.0150756835938\n",
      "\tTraining Loss = 922.7617797851562\n",
      "\tRelative L2 error (test) = 174.15378093719482\n",
      "Iteration Number = 1070\n",
      "\tIC Loss = 0.6855558156967163\n",
      "\tBC Loss = 6.382470607757568\n",
      "\tPhysics Loss = 490.0072021484375\n",
      "\tTraining Loss = 497.0752258300781\n",
      "\tRelative L2 error (test) = 97.28561043739319\n",
      "Iteration Number = 1080\n",
      "\tIC Loss = 1.073340654373169\n",
      "\tBC Loss = 8.916458129882812\n",
      "\tPhysics Loss = 182.69778442382812\n",
      "\tTraining Loss = 192.6875762939453\n",
      "\tRelative L2 error (test) = 131.4048409461975\n",
      "Iteration Number = 1090\n",
      "\tIC Loss = 0.6009460091590881\n",
      "\tBC Loss = 5.397204399108887\n",
      "\tPhysics Loss = 173.60531616210938\n",
      "\tTraining Loss = 179.6034698486328\n",
      "\tRelative L2 error (test) = 90.9814178943634\n",
      "Iteration Number = 1100\n",
      "\tIC Loss = 0.5631096959114075\n",
      "\tBC Loss = 4.094057083129883\n",
      "\tPhysics Loss = 209.8646240234375\n",
      "\tTraining Loss = 214.52178955078125\n",
      "\tRelative L2 error (test) = 102.39189863204956\n",
      "Iteration Number = 1110\n",
      "\tIC Loss = 0.5764089226722717\n",
      "\tBC Loss = 7.929634094238281\n",
      "\tPhysics Loss = 198.2516632080078\n",
      "\tTraining Loss = 206.75770568847656\n",
      "\tRelative L2 error (test) = 102.40890979766846\n",
      "Iteration Number = 1120\n",
      "\tIC Loss = 0.33595556020736694\n",
      "\tBC Loss = 4.130378723144531\n",
      "\tPhysics Loss = 754.303955078125\n",
      "\tTraining Loss = 758.770263671875\n",
      "\tRelative L2 error (test) = 123.70014190673828\n",
      "Iteration Number = 1130\n",
      "\tIC Loss = 0.3931503891944885\n",
      "\tBC Loss = 5.169249534606934\n",
      "\tPhysics Loss = 374.8182373046875\n",
      "\tTraining Loss = 380.3806457519531\n",
      "\tRelative L2 error (test) = 107.8932523727417\n",
      "Iteration Number = 1140\n",
      "\tIC Loss = 0.42356669902801514\n",
      "\tBC Loss = 4.14631462097168\n",
      "\tPhysics Loss = 224.3671112060547\n",
      "\tTraining Loss = 228.93699645996094\n",
      "\tRelative L2 error (test) = 92.65850186347961\n",
      "Iteration Number = 1150\n",
      "\tIC Loss = 0.39516133069992065\n",
      "\tBC Loss = 4.244563579559326\n",
      "\tPhysics Loss = 194.20777893066406\n",
      "\tTraining Loss = 198.84750366210938\n",
      "\tRelative L2 error (test) = 88.5169506072998\n",
      "Iteration Number = 1160\n",
      "\tIC Loss = 0.23663796484470367\n",
      "\tBC Loss = 4.2034101486206055\n",
      "\tPhysics Loss = 168.87692260742188\n",
      "\tTraining Loss = 173.3169708251953\n",
      "\tRelative L2 error (test) = 92.20879673957825\n",
      "Iteration Number = 1170\n",
      "\tIC Loss = 0.2509869933128357\n",
      "\tBC Loss = 3.9004158973693848\n",
      "\tPhysics Loss = 144.18284606933594\n",
      "\tTraining Loss = 148.33424377441406\n",
      "\tRelative L2 error (test) = 104.73182201385498\n",
      "Iteration Number = 1180\n",
      "\tIC Loss = 0.42325782775878906\n",
      "\tBC Loss = 7.438437461853027\n",
      "\tPhysics Loss = 298.720458984375\n",
      "\tTraining Loss = 306.5821533203125\n",
      "\tRelative L2 error (test) = 118.72844696044922\n",
      "Iteration Number = 1190\n",
      "\tIC Loss = 0.3299236595630646\n",
      "\tBC Loss = 6.750854969024658\n",
      "\tPhysics Loss = 2439.267578125\n",
      "\tTraining Loss = 2446.348388671875\n",
      "\tRelative L2 error (test) = 190.22753238677979\n",
      "Iteration Number = 1200\n",
      "\tIC Loss = 0.2227155715227127\n",
      "\tBC Loss = 5.263664245605469\n",
      "\tPhysics Loss = 996.1131591796875\n",
      "\tTraining Loss = 1001.5995483398438\n",
      "\tRelative L2 error (test) = 189.11153078079224\n",
      "Iteration Number = 1210\n",
      "\tIC Loss = 0.492849200963974\n",
      "\tBC Loss = 6.906524658203125\n",
      "\tPhysics Loss = 559.1822509765625\n",
      "\tTraining Loss = 566.5816040039062\n",
      "\tRelative L2 error (test) = 102.97397375106812\n",
      "Iteration Number = 1220\n",
      "\tIC Loss = 0.2812612056732178\n",
      "\tBC Loss = 6.25447416305542\n",
      "\tPhysics Loss = 297.397216796875\n",
      "\tTraining Loss = 303.9329528808594\n",
      "\tRelative L2 error (test) = 122.23798036575317\n",
      "Iteration Number = 1230\n",
      "\tIC Loss = 0.35881996154785156\n",
      "\tBC Loss = 7.360722064971924\n",
      "\tPhysics Loss = 155.60948181152344\n",
      "\tTraining Loss = 163.3290252685547\n",
      "\tRelative L2 error (test) = 145.7055687904358\n",
      "Iteration Number = 1240\n",
      "\tIC Loss = 0.6348321437835693\n",
      "\tBC Loss = 10.731876373291016\n",
      "\tPhysics Loss = 169.67897033691406\n",
      "\tTraining Loss = 181.04568481445312\n",
      "\tRelative L2 error (test) = 135.62183380126953\n",
      "Iteration Number = 1250\n",
      "\tIC Loss = 0.6382217407226562\n",
      "\tBC Loss = 11.00409984588623\n",
      "\tPhysics Loss = 138.30165100097656\n",
      "\tTraining Loss = 149.9439697265625\n",
      "\tRelative L2 error (test) = 134.07535552978516\n",
      "Iteration Number = 1260\n",
      "\tIC Loss = 0.29369765520095825\n",
      "\tBC Loss = 5.837469100952148\n",
      "\tPhysics Loss = 92.17799377441406\n",
      "\tTraining Loss = 98.30915832519531\n",
      "\tRelative L2 error (test) = 122.61766195297241\n",
      "Iteration Number = 1270\n",
      "\tIC Loss = 0.2084575593471527\n",
      "\tBC Loss = 4.338450908660889\n",
      "\tPhysics Loss = 110.89676666259766\n",
      "\tTraining Loss = 115.44367218017578\n",
      "\tRelative L2 error (test) = 100.0750184059143\n",
      "Iteration Number = 1280\n",
      "\tIC Loss = 0.1283242255449295\n",
      "\tBC Loss = 2.191014289855957\n",
      "\tPhysics Loss = 118.71595764160156\n",
      "\tTraining Loss = 121.03529357910156\n",
      "\tRelative L2 error (test) = 57.385969161987305\n",
      "Iteration Number = 1290\n",
      "\tIC Loss = 0.14285612106323242\n",
      "\tBC Loss = 3.3806538581848145\n",
      "\tPhysics Loss = 107.91014862060547\n",
      "\tTraining Loss = 111.43365478515625\n",
      "\tRelative L2 error (test) = 62.22148537635803\n",
      "Iteration Number = 1300\n",
      "\tIC Loss = 0.2783539891242981\n",
      "\tBC Loss = 5.087686538696289\n",
      "\tPhysics Loss = 96.93376922607422\n",
      "\tTraining Loss = 102.29981231689453\n",
      "\tRelative L2 error (test) = 57.296085357666016\n",
      "Iteration Number = 1310\n",
      "\tIC Loss = 0.3941778540611267\n",
      "\tBC Loss = 6.42205810546875\n",
      "\tPhysics Loss = 348.7462158203125\n",
      "\tTraining Loss = 355.56243896484375\n",
      "\tRelative L2 error (test) = 72.49222993850708\n",
      "Iteration Number = 1320\n",
      "\tIC Loss = 0.37282514572143555\n",
      "\tBC Loss = 6.722161769866943\n",
      "\tPhysics Loss = 684.4791870117188\n",
      "\tTraining Loss = 691.5741577148438\n",
      "\tRelative L2 error (test) = 87.88474798202515\n",
      "Iteration Number = 1330\n",
      "\tIC Loss = 0.30429062247276306\n",
      "\tBC Loss = 5.4262776374816895\n",
      "\tPhysics Loss = 158.4158477783203\n",
      "\tTraining Loss = 164.14642333984375\n",
      "\tRelative L2 error (test) = 59.763556718826294\n",
      "Iteration Number = 1340\n",
      "\tIC Loss = 0.30618807673454285\n",
      "\tBC Loss = 4.849880218505859\n",
      "\tPhysics Loss = 100.37113189697266\n",
      "\tTraining Loss = 105.5271987915039\n",
      "\tRelative L2 error (test) = 55.0409197807312\n",
      "Iteration Number = 1350\n",
      "\tIC Loss = 0.2895626425743103\n",
      "\tBC Loss = 4.517267227172852\n",
      "\tPhysics Loss = 111.24758911132812\n",
      "\tTraining Loss = 116.0544204711914\n",
      "\tRelative L2 error (test) = 69.05654072761536\n",
      "Iteration Number = 1360\n",
      "\tIC Loss = 0.2086661159992218\n",
      "\tBC Loss = 3.218863010406494\n",
      "\tPhysics Loss = 194.7784423828125\n",
      "\tTraining Loss = 198.2059783935547\n",
      "\tRelative L2 error (test) = 47.607505321502686\n",
      "Iteration Number = 1370\n",
      "\tIC Loss = 0.20092666149139404\n",
      "\tBC Loss = 3.0682625770568848\n",
      "\tPhysics Loss = 150.4936981201172\n",
      "\tTraining Loss = 153.7628936767578\n",
      "\tRelative L2 error (test) = 45.020389556884766\n",
      "Iteration Number = 1380\n",
      "\tIC Loss = 0.16341903805732727\n",
      "\tBC Loss = 2.2104313373565674\n",
      "\tPhysics Loss = 154.30133056640625\n",
      "\tTraining Loss = 156.67518615722656\n",
      "\tRelative L2 error (test) = 50.64721703529358\n",
      "Iteration Number = 1390\n",
      "\tIC Loss = 0.1444292664527893\n",
      "\tBC Loss = 2.3197383880615234\n",
      "\tPhysics Loss = 136.4329833984375\n",
      "\tTraining Loss = 138.89715576171875\n",
      "\tRelative L2 error (test) = 43.020108342170715\n",
      "Iteration Number = 1400\n",
      "\tIC Loss = 0.16625183820724487\n",
      "\tBC Loss = 2.7303528785705566\n",
      "\tPhysics Loss = 155.3595733642578\n",
      "\tTraining Loss = 158.2561798095703\n",
      "\tRelative L2 error (test) = 47.67087697982788\n",
      "Iteration Number = 1410\n",
      "\tIC Loss = 0.18301299214363098\n",
      "\tBC Loss = 2.32073974609375\n",
      "\tPhysics Loss = 143.43080139160156\n",
      "\tTraining Loss = 145.93455505371094\n",
      "\tRelative L2 error (test) = 45.860275626182556\n",
      "Iteration Number = 1420\n",
      "\tIC Loss = 0.24490340054035187\n",
      "\tBC Loss = 3.755739688873291\n",
      "\tPhysics Loss = 140.70619201660156\n",
      "\tTraining Loss = 144.7068328857422\n",
      "\tRelative L2 error (test) = 48.47252070903778\n",
      "Iteration Number = 1430\n",
      "\tIC Loss = 0.2487156242132187\n",
      "\tBC Loss = 3.7024898529052734\n",
      "\tPhysics Loss = 96.93086242675781\n",
      "\tTraining Loss = 100.88206481933594\n",
      "\tRelative L2 error (test) = 53.33554744720459\n",
      "Iteration Number = 1440\n",
      "\tIC Loss = 0.215035080909729\n",
      "\tBC Loss = 2.9791014194488525\n",
      "\tPhysics Loss = 106.2261962890625\n",
      "\tTraining Loss = 109.42033386230469\n",
      "\tRelative L2 error (test) = 48.14686477184296\n",
      "Iteration Number = 1450\n",
      "\tIC Loss = 0.16178561747074127\n",
      "\tBC Loss = 2.09696626663208\n",
      "\tPhysics Loss = 97.11144256591797\n",
      "\tTraining Loss = 99.37019348144531\n",
      "\tRelative L2 error (test) = 44.57572400569916\n",
      "Iteration Number = 1460\n",
      "\tIC Loss = 0.21764527261257172\n",
      "\tBC Loss = 2.356395721435547\n",
      "\tPhysics Loss = 131.95118713378906\n",
      "\tTraining Loss = 134.5252227783203\n",
      "\tRelative L2 error (test) = 40.92069864273071\n",
      "Iteration Number = 1470\n",
      "\tIC Loss = 0.2401314079761505\n",
      "\tBC Loss = 1.9986878633499146\n",
      "\tPhysics Loss = 572.5198364257812\n",
      "\tTraining Loss = 574.7586669921875\n",
      "\tRelative L2 error (test) = 61.2451434135437\n",
      "Iteration Number = 1480\n",
      "\tIC Loss = 0.16397014260292053\n",
      "\tBC Loss = 2.24172043800354\n",
      "\tPhysics Loss = 259.86627197265625\n",
      "\tTraining Loss = 262.27197265625\n",
      "\tRelative L2 error (test) = 58.5396945476532\n",
      "Iteration Number = 1490\n",
      "\tIC Loss = 0.2233792245388031\n",
      "\tBC Loss = 2.751345634460449\n",
      "\tPhysics Loss = 249.3301239013672\n",
      "\tTraining Loss = 252.3048553466797\n",
      "\tRelative L2 error (test) = 45.125359296798706\n",
      "Iteration Number = 1500\n",
      "\tIC Loss = 0.1458239108324051\n",
      "\tBC Loss = 1.793060541152954\n",
      "\tPhysics Loss = 236.10104370117188\n",
      "\tTraining Loss = 238.03993225097656\n",
      "\tRelative L2 error (test) = 50.15682578086853\n",
      "Iteration Number = 1510\n",
      "\tIC Loss = 0.13649329543113708\n",
      "\tBC Loss = 1.0801153182983398\n",
      "\tPhysics Loss = 161.4810791015625\n",
      "\tTraining Loss = 162.69769287109375\n",
      "\tRelative L2 error (test) = 42.83679127693176\n",
      "Iteration Number = 1520\n",
      "\tIC Loss = 0.06760691106319427\n",
      "\tBC Loss = 0.9853248596191406\n",
      "\tPhysics Loss = 519.6000366210938\n",
      "\tTraining Loss = 520.6529541015625\n",
      "\tRelative L2 error (test) = 50.02940893173218\n",
      "Iteration Number = 1530\n",
      "\tIC Loss = 0.05744342878460884\n",
      "\tBC Loss = 0.7299901247024536\n",
      "\tPhysics Loss = 203.42544555664062\n",
      "\tTraining Loss = 204.21287536621094\n",
      "\tRelative L2 error (test) = 52.06073522567749\n",
      "Iteration Number = 1540\n",
      "\tIC Loss = 0.1354268193244934\n",
      "\tBC Loss = 1.664383053779602\n",
      "\tPhysics Loss = 204.70176696777344\n",
      "\tTraining Loss = 206.50157165527344\n",
      "\tRelative L2 error (test) = 41.13565385341644\n",
      "Iteration Number = 1550\n",
      "\tIC Loss = 0.07306578010320663\n",
      "\tBC Loss = 1.1770249605178833\n",
      "\tPhysics Loss = 162.07591247558594\n",
      "\tTraining Loss = 163.3260040283203\n",
      "\tRelative L2 error (test) = 39.677366614341736\n",
      "Iteration Number = 1560\n",
      "\tIC Loss = 0.0907014012336731\n",
      "\tBC Loss = 1.0363103151321411\n",
      "\tPhysics Loss = 157.86688232421875\n",
      "\tTraining Loss = 158.993896484375\n",
      "\tRelative L2 error (test) = 42.9566353559494\n",
      "Iteration Number = 1570\n",
      "\tIC Loss = 0.08701921254396439\n",
      "\tBC Loss = 1.1068609952926636\n",
      "\tPhysics Loss = 152.19068908691406\n",
      "\tTraining Loss = 153.3845672607422\n",
      "\tRelative L2 error (test) = 42.48059093952179\n",
      "Iteration Number = 1580\n",
      "\tIC Loss = 0.09148050844669342\n",
      "\tBC Loss = 1.0071507692337036\n",
      "\tPhysics Loss = 115.82858276367188\n",
      "\tTraining Loss = 116.92721557617188\n",
      "\tRelative L2 error (test) = 40.83191156387329\n",
      "Iteration Number = 1590\n",
      "\tIC Loss = 0.06956478953361511\n",
      "\tBC Loss = 0.9773608446121216\n",
      "\tPhysics Loss = 116.31867980957031\n",
      "\tTraining Loss = 117.36560821533203\n",
      "\tRelative L2 error (test) = 40.75123369693756\n",
      "Iteration Number = 1600\n",
      "\tIC Loss = 0.07334621250629425\n",
      "\tBC Loss = 0.8061811923980713\n",
      "\tPhysics Loss = 138.6912841796875\n",
      "\tTraining Loss = 139.57081604003906\n",
      "\tRelative L2 error (test) = 36.24107539653778\n",
      "Iteration Number = 1610\n",
      "\tIC Loss = 0.07287081331014633\n",
      "\tBC Loss = 1.3667658567428589\n",
      "\tPhysics Loss = 62.43264389038086\n",
      "\tTraining Loss = 63.87228012084961\n",
      "\tRelative L2 error (test) = 45.53243815898895\n",
      "Iteration Number = 1620\n",
      "\tIC Loss = 0.13950756192207336\n",
      "\tBC Loss = 1.8336551189422607\n",
      "\tPhysics Loss = 1087.923095703125\n",
      "\tTraining Loss = 1089.896240234375\n",
      "\tRelative L2 error (test) = 54.35447096824646\n",
      "Iteration Number = 1630\n",
      "\tIC Loss = 0.09753668308258057\n",
      "\tBC Loss = 2.030322790145874\n",
      "\tPhysics Loss = 357.6218566894531\n",
      "\tTraining Loss = 359.7497253417969\n",
      "\tRelative L2 error (test) = 93.23155283927917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torch.utils.data import DataLoader \n",
    "import numpy as np\n",
    "from scipy import io\n",
    "import matplotlib.pyplot as plt\n",
    "import argparse\n",
    "import os\n",
    "import copy\n",
    "import time\n",
    "\n",
    "\n",
    "def inital_soln(xt_grid):\n",
    "    x_grid = xt_grid[:,:-1]\n",
    "    t = xt_grid[:,-1].reshape(-1,1)\n",
    "\n",
    "    d = x_grid.shape[1]\n",
    "    u = torch.cos((40/d)*x_grid.sum(axis = 1)).reshape(-1,1)\n",
    "\n",
    "    return u.flatten()\n",
    "\n",
    "\n",
    "def exact_soln(xt_grid):\n",
    "    \n",
    "    x_grid = xt_grid[:,:-1]\n",
    "    t = xt_grid[:,-1].reshape(-1,1)\n",
    "\n",
    "    d = x_grid.shape[1]\n",
    "    u = torch.cos((40/d)*x_grid.sum(axis = 1)).reshape(-1,1)*torch.exp(-t)\n",
    "    \n",
    "    return u.flatten()\n",
    "\n",
    "\n",
    "def f_function(xt_grid):\n",
    "    \n",
    "    x_grid = xt_grid[:,:-1]\n",
    "    t = xt_grid[:,-1].reshape(-1,1)\n",
    "\n",
    "    d = x_grid.shape[1]\n",
    "    u = (1600/d-1)*torch.cos((40/d)*x_grid.sum(axis = 1)).reshape(-1,1)*torch.exp(-t)\n",
    "    \n",
    "    return u.flatten()\n",
    "\n",
    "\n",
    "def stacked_grid(*args):\n",
    "    # Generate meshgrid dynamically\n",
    "    grids = torch.meshgrid(*args, indexing='ij')\n",
    "    # Flatten and stack the grids into a single tensor\n",
    "    stacked = torch.hstack([g.flatten()[:, None] for g in grids])\n",
    "    return stacked.float()\n",
    "\n",
    "\n",
    "def cal_domain_grad(model, XTGrid, device):\n",
    "    \"\"\"Gradient of the mean-squared PDE residual w.r.t. the collocation points.\n",
    "\n",
    "    The residual is u_t - Laplacian_x(u) - f, with the last column of XTGrid\n",
    "    treated as time and all other columns as space.\n",
    "\n",
    "    Args:\n",
    "        model: object whose ``forward`` maps the points to one output per row\n",
    "            (shape (N,) or (N, 1); both are accepted).\n",
    "        XTGrid: (N, d+1) tensor of collocation points.\n",
    "        device: device the points are moved to before evaluation.\n",
    "\n",
    "    Returns:\n",
    "        Tensor shaped like XTGrid holding d(loss)/d(XTGrid); it is built with\n",
    "        ``create_graph=True`` so it stays attached to the autograd graph.\n",
    "    \"\"\"\n",
    "    Loss = torch.nn.MSELoss(reduction='mean')\n",
    "\n",
    "    XTGrid = XTGrid.requires_grad_(True).to(device)\n",
    "    # Flatten so (N,) and (N, 1) outputs behave alike; indexing `u[:, 0]` would\n",
    "    # raise IndexError on the 1-D tensor returned by PINN.forward.\n",
    "    u = model.forward(XTGrid).reshape(-1)\n",
    "\n",
    "    # First derivatives du/d(x_1..x_d, t), kept differentiable for the Laplacian.\n",
    "    u_grad = torch.autograd.grad(outputs=u,\n",
    "                                 inputs=XTGrid,\n",
    "                                 grad_outputs=torch.ones_like(u),\n",
    "                                 create_graph=True,\n",
    "                                 allow_unused=True)[0]\n",
    "\n",
    "    # Laplacian: sum of the pure second derivatives over the spatial columns.\n",
    "    u_laplacian = torch.zeros_like(u)\n",
    "    for d in range(XTGrid.shape[1] - 1):\n",
    "        ux = u_grad[:, d]\n",
    "        uxx = torch.autograd.grad(outputs=ux,\n",
    "                                  inputs=XTGrid,\n",
    "                                  grad_outputs=torch.ones_like(ux),\n",
    "                                  create_graph=True,\n",
    "                                  allow_unused=True)[0][:, d]\n",
    "        u_laplacian = u_laplacian + uxx\n",
    "\n",
    "    f = f_function(XTGrid).flatten()\n",
    "    ut = u_grad[:, -1]  # last column is time\n",
    "\n",
    "    # Mean-squared residual of u_t - Laplacian(u) = f.\n",
    "    lossf = Loss(ut - u_laplacian, f)\n",
    "\n",
    "    # Sensitivity of the loss to each point's coordinates; the LAS / L_INF\n",
    "    # samplers consume a gradient of this form through their `phy_lf` argument.\n",
    "    grad = torch.autograd.grad(outputs=lossf,\n",
    "                               inputs=XTGrid,\n",
    "                               grad_outputs=torch.ones_like(lossf),\n",
    "                               create_graph=True,\n",
    "                               allow_unused=True)[0]\n",
    "    return grad\n",
    "\n",
    "\n",
    "class RADSampler():\n",
    "    \"\"\"Residual-weighted resampling of collocation points (RAD).\n",
    "\n",
    "    Each update draws ``dense_Nf`` uniform candidates on [-1, 1]^dim x [0, 1],\n",
    "    evaluates the pointwise residual err = |u_t - Laplacian(u) - f| there and\n",
    "    draws ``Nf`` points with replacement with probability proportional to\n",
    "    err**k / mean(err**k) + c. The result is stored in ``self.XTGrid``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, Nf, device, k, c, dim = 1):\n",
    "        self.device = device\n",
    "        self.k = k  # residual exponent; larger k concentrates the samples\n",
    "        self.c = c  # additive floor keeping some mass on low-residual points\n",
    "        self.Nf = Nf\n",
    "        self.dense_Nf = Nf*1  # number of candidate points drawn per update\n",
    "        self.dim = dim  # number of spatial dimensions\n",
    "\n",
    "    def _residual(self, model, XTGrid):\n",
    "        \"\"\"Pointwise |u_t - Laplacian(u) - f| at XTGrid, detached from autograd.\"\"\"\n",
    "        # Flatten so (N,) and (N, 1) model outputs are both supported.\n",
    "        u = model.forward(XTGrid).reshape(-1)\n",
    "        u_grad = torch.autograd.grad(outputs=u,\n",
    "                                     inputs=XTGrid,\n",
    "                                     grad_outputs=torch.ones_like(u),\n",
    "                                     create_graph=True,\n",
    "                                     allow_unused=True)[0]\n",
    "        u_laplacian = torch.zeros_like(u)\n",
    "        for d in range(XTGrid.shape[1] - 1):  # last column is time\n",
    "            ux = u_grad[:, d]\n",
    "            uxx = torch.autograd.grad(outputs=ux,\n",
    "                                      inputs=XTGrid,\n",
    "                                      grad_outputs=torch.ones_like(ux),\n",
    "                                      create_graph=True,\n",
    "                                      allow_unused=True)[0][:, d]\n",
    "            u_laplacian = u_laplacian + uxx\n",
    "        f = f_function(XTGrid)\n",
    "        ut = u_grad[:, -1]\n",
    "        return torch.abs(ut - u_laplacian - f.flatten()).detach()\n",
    "\n",
    "    def update(self, model):\n",
    "        \"\"\"Resample ``self.Nf`` collocation points into ``self.XTGrid``.\n",
    "\n",
    "        Args:\n",
    "            model: object whose ``forward`` maps (n, dim+1) points to n outputs.\n",
    "        \"\"\"\n",
    "        x_new = torch.zeros(self.dense_Nf, self.dim, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "        t_new = torch.zeros(self.dense_Nf, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "        # torch.cat already yields a fresh leaf on self.device; wrapping it in\n",
    "        # torch.tensor(...) would only add a copy and a UserWarning.\n",
    "        XTGrid = torch.cat((x_new, t_new), dim=1).requires_grad_(True)\n",
    "\n",
    "        err = self._residual(model, XTGrid)\n",
    "        err = (err**self.k)/((err**self.k).mean())+self.c\n",
    "        err_norm = err/(err.sum())\n",
    "\n",
    "        indice = torch.multinomial(err_norm, self.Nf, replacement = True)\n",
    "        # Store a detached leaf so later autograd.grad calls can differentiate\n",
    "        # w.r.t. these points without a link back to the candidate tensor.\n",
    "        self.XTGrid = XTGrid[indice].detach().requires_grad_(True)\n",
    "        \n",
    "class LASSampler():\n",
    "    def __init__(self, Nf, fixed_uniform, device, L_iter = 1, beta = 0.2, tau = 0.002):\n",
    "        self.Nf = Nf\n",
    "        self.device = device\n",
    "        self.cnt = 0\n",
    "        self.beta = beta\n",
    "        self.tau = tau\n",
    "        self.L_iter = L_iter\n",
    "        self.XTGrid = torch.tensor(copy.deepcopy(fixed_uniform), dtype=torch.float32, requires_grad=True).to(self.device)\n",
    "\n",
    "    def update(self, phy_lf, model):\n",
    "\n",
    "        x_data = self.XTGrid\n",
    "        samples = x_data.clone().detach().requires_grad_(True)\n",
    "        \n",
    "        for t in range(1, self.L_iter + 1):\n",
    "            grad = phy_lf(model, samples, self.device)\n",
    "            scaler = torch.sqrt(torch.sum((grad+1e-16)**2, axis = 1)).reshape(-1,1)\n",
    "            grad = grad/scaler\n",
    "\n",
    "            with torch.no_grad():\n",
    "                samples = samples + self.tau * grad + self.beta*torch.sqrt(torch.tensor(2 * self.tau, device=self.device)) * torch.randn(samples.shape, device=self.device)\n",
    "                for i in range(samples.shape[1]):\n",
    "                    if i < int(samples.shape[1]-1):\n",
    "                        samples[:, i] = torch.clamp(samples[:, i], min=-1, max=1)  # x축 클램핑\n",
    "                    else:\n",
    "                        samples[:, i] = torch.clamp(samples[:, i], min=0, max=1)   # t축 클램핑\n",
    "                        \n",
    "            samples = samples.clone().detach().requires_grad_(True)\n",
    "        self.XTGrid = samples.detach()\n",
    "\n",
    "class L_INFSampler():\n",
    "    def __init__(self, Nf, device, step_size = 0.05 , n_iter = 20, dim = 1):\n",
    "\n",
    "        self.Nf = Nf\n",
    "        self.device = device\n",
    "        self.step_size = step_size\n",
    "        self.n_iter = n_iter\n",
    "        self.dim = dim\n",
    "\n",
    "    def update(self, phy_lf, model):\n",
    "\n",
    "        x_new = torch.zeros(self.Nf, self.dim, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "        t_new = torch.zeros(self.Nf, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "        x_t_new = torch.concatenate((x_new,t_new), axis = 1)\n",
    "        self.XTGrid = x_t_new\n",
    "            \n",
    "        x_data = self.XTGrid\n",
    "        samples = x_data.clone().detach().requires_grad_(True)\n",
    "    \n",
    "        for t in range(1, self.n_iter + 1):\n",
    "            grad = phy_lf(model, samples, self.device)\n",
    "            with torch.no_grad():\n",
    "                samples = samples + self.step_size * torch.sign(grad)\n",
    "                for i in range(samples.shape[1]):\n",
    "                    if i < int(samples.shape[1]-1):\n",
    "                        samples[:, i] = torch.clamp(samples[:, i], min=-1, max=1)  # x축 클램핑\n",
    "                    else:\n",
    "                        samples[:, i] = torch.clamp(samples[:, i], min=0, max=1)   # t축 클램핑\n",
    "                        \n",
    "            samples = samples.clone().detach().requires_grad_(True)\n",
    "        self.XTGrid = samples.detach()           \n",
    "\n",
    "\n",
    "class R3Sampler(nn.Module):\n",
    "    def __init__(self,Nf, fixed_uniform, device):\n",
    "        super(R3Sampler, self).__init__()\n",
    "        self.Nf = Nf\n",
    "        self.device = device\n",
    "        self.XTGrid = torch.tensor(copy.deepcopy(fixed_uniform), dtype=torch.float32, requires_grad=True).to(self.device)\n",
    "    \n",
    "    def update(self, loss_aver, loss_ele):\n",
    "        with torch.no_grad():\n",
    "            cho_i = loss_ele > loss_aver\n",
    "            cho_i = cho_i.to('cpu')\n",
    "            self.XTGrid = self.XTGrid[cho_i].detach()\n",
    "            need_n_sample = self.Nf-self.XTGrid.shape[0]\n",
    "            x_new = torch.zeros(need_n_sample, self.XTGrid.shape[1]-1, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "            t_new = torch.zeros(need_n_sample, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "            x_t_new = torch.concatenate((x_new,t_new), axis = 1)\n",
    "            self.XTGrid = torch.concatenate((self.XTGrid, x_t_new), axis = 0)\n",
    "            self.XTGrid = torch.tensor(self.XTGrid, dtype = torch.float32, device=self.device, requires_grad=True)\n",
    "    \n",
    "class PINN(nn.Module):\n",
    "    def __init__(self,k , c , t, exact_XYT, exact_u, LBs, UBs, Layers, N0, Nb, Nf, \n",
    "                 Activation = nn.Tanh(), \n",
    "                 model_name = \"PINN.model\", device = 'cpu',\n",
    "                  display_freq = 100, samp = 'fixed' ):\n",
    "        \"\"\"Fully connected PINN for u_t - Laplacian(u) = f on the box [LBs, UBs].\n",
    "\n",
    "        Args:\n",
    "            k, c: exponent and additive constant handed to RADSampler.\n",
    "            t: stored as ``self.t`` (not used by the constructor itself).\n",
    "            exact_XYT, exact_u: reference points and values, stored as\n",
    "                ``self.exact_XYT`` / ``self.X_star`` and ``self.exact_u`` (N, 1);\n",
    "                presumably used for the test error reported during training.\n",
    "            LBs, UBs: per-column lower/upper bounds; the last entry is time.\n",
    "            Layers: layer widths; ``Layers[0]`` is the input size.\n",
    "            N0, Nb, Nf: initial-condition, boundary and collocation point counts.\n",
    "            Activation: module inserted between consecutive Linear layers.\n",
    "            model_name, display_freq: stored for use by other methods.\n",
    "            device: device for the network, data and samplers.\n",
    "            samp: sampling strategy name, stored as ``self.method``.\n",
    "        \"\"\"\n",
    "        super(PINN, self).__init__()\n",
    "\n",
    "        self.LBs = torch.tensor(LBs, dtype=torch.float32)\n",
    "        self.UBs = torch.tensor(UBs, dtype=torch.float32)\n",
    "\n",
    "        self.Layers = Layers\n",
    "        self.in_dim  = Layers[0]\n",
    "        self.out_dim = Layers[-1]\n",
    "        self.Activation = Activation\n",
    "\n",
    "        self.device = device\n",
    "\n",
    "        # Uniform collocation points over the whole box: column i is drawn from\n",
    "        # [LBs[i], UBs[i]]. (Drawing every spatial column from [0, UBs[0]] would\n",
    "        # leave [LBs[i], 0) empty whenever LBs[i] < 0.) The result is built\n",
    "        # directly as a float32 leaf, so no torch.tensor(<tensor>) copy is needed.\n",
    "        lower = self.LBs.to(self.device)\n",
    "        upper = self.UBs.to(self.device)\n",
    "        unit = torch.rand(Nf, len(self.LBs), dtype=torch.float32, device=self.device)\n",
    "        self.fixed_uniform = (lower + (upper - lower) * unit).requires_grad_(True)\n",
    "\n",
    "        self.N0 = N0\n",
    "        self.Nb = Nb\n",
    "        self.Nf = Nf\n",
    "\n",
    "        self.t = t\n",
    "        self.exact_XYT = exact_XYT\n",
    "        self.exact_u = exact_u.reshape(-1,1)\n",
    "\n",
    "        self.XT0, self.u0  = self.InitialCondition(self.N0, self.LBs, self.UBs)\n",
    "        self.boundary_set = self.BoundaryCondition(self.Nb, self.LBs, self.UBs)\n",
    "\n",
    "        self.XT0 = self.XT0.to(device)\n",
    "        self.u0 = self.u0.to(device) \n",
    "\n",
    "        self._nn = self.build_model()\n",
    "        self._nn.to(self.device)\n",
    "        self.Loss = torch.nn.MSELoss(reduction='mean')\n",
    "\n",
    "        self.model_name = model_name\n",
    "        self.display_freq = display_freq\n",
    "\n",
    "        self.k = k\n",
    "        self.c = c\n",
    "        self.X_star = self.exact_XYT\n",
    "        self.method = samp\n",
    "\n",
    "        self.r3_sample = R3Sampler(self.Nf, self.fixed_uniform, device)\n",
    "        self.las = LASSampler(self.Nf, fixed_uniform=self.fixed_uniform, device=self.device, L_iter = 1, beta = 0.2, tau=1e-2)\n",
    "        self.l_inf = L_INFSampler(self.Nf, device=self.device, step_size = 0.05 , n_iter = 20, dim = len(self.LBs) - 1)\n",
    "        self.rad = RADSampler(Nf = self.Nf, device=self.device, k = self.k, c=self.c, dim = len(self.LBs) - 1)\n",
    "        \n",
    "    \n",
    "    def build_model(self):\n",
    "        Seq = nn.Sequential()\n",
    "        for ii in range(len(self.Layers)-1):\n",
    "            this_module = nn.Linear(self.Layers[ii], self.Layers[ii+1])\n",
    "            nn.init.xavier_normal_(this_module.weight)\n",
    "            Seq.add_module(\"Linear\" + str(ii), this_module)\n",
    "            if not ii == len(self.Layers)-2:\n",
    "                Seq.add_module(\"Activation\" + str(ii), self.Activation)\n",
    "        return Seq\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self._nn.forward(x.to(self.device)).reshape(-1)\n",
    "\n",
    "    def InitialCondition(self, N0, LB, UB):\n",
    "        \"\"\"Tensor-product grid on the t = LB[-1] slice and its initial values.\n",
    "\n",
    "        Each spatial axis i gets round(N0**(1/d)) evenly spaced nodes in\n",
    "        [LB[i], UB[i]] (d = number of spatial axes), so about N0 points in total.\n",
    "\n",
    "        Returns:\n",
    "            (XT_in, u0): grid points (time in the last column) and the values\n",
    "            of ``inital_soln`` at those points.\n",
    "        \"\"\"\n",
    "        n_space = len(LB[:-1])\n",
    "        n_per_dim = int(np.round(N0 ** (1 / n_space)))\n",
    "        axes = [torch.linspace(LB[i], UB[i], n_per_dim).float() for i in range(n_space)]\n",
    "        axes.append(torch.tensor([LB[-1]]).float())\n",
    "        XT_in = stacked_grid(*axes)\n",
    "        return XT_in, inital_soln(XT_in)\n",
    "\n",
    "    def ICLoss(self):\n",
    "        uv0_pred = self.forward(self.XT0)\n",
    "        loss = self.Loss(uv0_pred, self.u0.to(self.device))\n",
    "        return loss \n",
    "\n",
    "    def BoundaryPoints(self, nb, boundary_index, LBs, UBs, boundary_type):\n",
    "        \"\"\"\n",
    "        Grid of space-time points on one face of the domain box.\n",
    "        - `nb`: number of time nodes; each free spatial axis gets\n",
    "          round(nb ** (1/d)) nodes, with d the number of spatial axes.\n",
    "        - `boundary_index`: spatial axis pinned to one of its bounds.\n",
    "        - `LBs`, `UBs`: lower/upper bounds per column; the last one is time.\n",
    "        - `boundary_type`: 'lower' pins the axis at LBs[boundary_index]; any\n",
    "          other value pins it at UBs[boundary_index].\n",
    "        \"\"\"\n",
    "        n_space = len(LBs) - 1\n",
    "        n_per_dim = int(np.round(nb ** (1 / n_space)))\n",
    "\n",
    "        def axis_nodes(i):\n",
    "            # The pinned axis contributes a single node, the others a linspace.\n",
    "            if i != boundary_index:\n",
    "                return torch.linspace(LBs[i], UBs[i], n_per_dim).float().to(self.device)\n",
    "            pinned = LBs[i] if boundary_type == 'lower' else UBs[i]\n",
    "            return torch.tensor([pinned]).float().to(self.device)\n",
    "\n",
    "        grid_ranges = [axis_nodes(i) for i in range(n_space)]\n",
    "        grid_ranges.append(torch.linspace(LBs[-1], UBs[-1], nb).float().to(self.device))\n",
    "        return stacked_grid(*grid_ranges)\n",
    "\n",
    "    \n",
    "    def BoundaryCondition(self, Nb, LBs, UBs):\n",
    "        \"\"\"\n",
    "        Generalized function to get all boundary conditions in an N-dimensional space.\n",
    "        - `Nb`: Total number of boundary points.\n",
    "        - `LBs`, `UBs`: Lower and upper bounds for each dimension.\n",
    "        \"\"\"\n",
    "        nb = int(np.round(Nb / (2 * (len(LBs) - 1))))  # Divide points across boundary faces\n",
    "        boundary_sets = []\n",
    "        \n",
    "        for d in range(len(LBs) - 1):  # Iterate over all spatial dimensions\n",
    "            lower_boundary = self.BoundaryPoints(nb, d, LBs, UBs, 'lower')\n",
    "            upper_boundary = self.BoundaryPoints(nb, d, LBs, UBs, 'upper')\n",
    "            boundary_sets.append((lower_boundary, upper_boundary))\n",
    "\n",
    "        return boundary_sets  # Returns pairs of boundaries per dimension\n",
    "\n",
    "    def BCLoss(self):\n",
    "        \"\"\"\n",
    "        Sum over spatial axes of the MSE between the network and exact_soln on\n",
    "        the lower and upper faces stored in self.boundary_set.\n",
    "        \"\"\"\n",
    "        total_loss = 0\n",
    "        for lower, upper in self.boundary_set:\n",
    "            pred_lo = self.forward(lower)\n",
    "            pred_hi = self.forward(upper)\n",
    "            ref_lo = exact_soln(lower)\n",
    "            ref_hi = exact_soln(upper)\n",
    "            total_loss += self.Loss(pred_lo, ref_lo) + self.Loss(pred_hi, ref_hi)\n",
    "        return total_loss\n",
    "    \n",
    "    def PhysicsLoss(self, XTGrid):\n",
    "        \"\"\"\n",
    "        PDE residual loss at the collocation points `XTGrid`.\n",
    "\n",
    "        Columns of `XTGrid` are the spatial coordinates followed by time (last\n",
    "        column). The pointwise residual is\n",
    "            r = Laplacian(u) - u_t + f,\n",
    "        where `u = self.forward(XTGrid)` and `f` comes from the module-level\n",
    "        `f_function`.\n",
    "\n",
    "        Returns `(loss1, loss2)`: `loss1` is the mean of `loss2`, and `loss2`\n",
    "        holds the per-point squared residuals (consumed by the r3 sampler).\n",
    "        \"\"\"\n",
    "        # Use the model's own device instead of a notebook-global `device`,\n",
    "        # which only exists once the __main__ block has run.\n",
    "        XTGrid = XTGrid.requires_grad_(True).to(self.device)\n",
    "        u = self.forward(XTGrid)\n",
    "\n",
    "        # First derivatives w.r.t. every input column (space and time).\n",
    "        # create_graph=True keeps them differentiable for the second derivatives\n",
    "        # and for backpropagating the loss into the network weights.\n",
    "        u_grad = torch.autograd.grad(outputs=u,\n",
    "                                     inputs=XTGrid,\n",
    "                                     grad_outputs=torch.ones_like(u),\n",
    "                                     create_graph=True,\n",
    "                                     allow_unused=True)[0]\n",
    "\n",
    "        # Laplacian = sum of pure second derivatives over the spatial columns.\n",
    "        # Shaped like one gradient column so adding `uxx` never broadcasts,\n",
    "        # whether `u` comes back as (N,) or (N, 1).\n",
    "        u_laplacian = torch.zeros_like(u_grad[:, 0])\n",
    "        for d in range(XTGrid.shape[1] - 1):  # last column is time\n",
    "            ux = u_grad[:, d]\n",
    "            uxx = torch.autograd.grad(outputs=ux,\n",
    "                                      inputs=XTGrid,\n",
    "                                      grad_outputs=torch.ones_like(ux),\n",
    "                                      create_graph=True,\n",
    "                                      allow_unused=True)[0][:, d]\n",
    "            u_laplacian = u_laplacian + uxx\n",
    "\n",
    "        f = f_function(XTGrid)\n",
    "\n",
    "        # Time derivative (last column).\n",
    "        ut = u_grad[:, -1]\n",
    "\n",
    "        loss2 = (u_laplacian - ut + f.flatten())**2\n",
    "        loss1 = loss2.mean()\n",
    "\n",
    "        return loss1, loss2\n",
    "\n",
    "    def Train(self, n_iters, weights=(1.0,1.0,1.0)):\n",
    "        \"\"\"\n",
    "        Train the network with Adam for `n_iters` iterations.\n",
    "\n",
    "        Each iteration minimises\n",
    "            weights[0]*IC loss + weights[1]*BC loss + weights[2]*physics loss\n",
    "        on collocation points chosen by `self.method` ('r3', 'las', 'l_inf',\n",
    "        'rad', 'fixed' or 'random-r'; anything else raises ValueError).\n",
    "        Whenever the total loss hits a new minimum, the weights of `self._nn`\n",
    "        that produced it are saved under ../models/.\n",
    "\n",
    "        Returns `(Training_Losses, rel_error)`:\n",
    "          - `Training_Losses[jj]`: total loss at iteration jj.\n",
    "          - `rel_error[k]`: relative L2 test error (percent) measured after\n",
    "            iteration k * self.display_freq; -10 marks slots never written.\n",
    "        \"\"\"\n",
    "        optimizer = optim.Adam(self.parameters(), lr=1e-3)\n",
    "        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 5000, gamma=0.9, last_epoch=-1)\n",
    "        # NOTE(review): the path ignores the run's model_name / repeat index, so\n",
    "        # repeated runs with identical settings overwrite a single checkpoint.\n",
    "        ckpt_path = \"../models/\"+self.method+'_'+str(len(self.Layers)-2)+'_'+str(self.Nf)+'.pt'\n",
    "        min_loss = float('inf')\n",
    "        Training_Losses = [-10]*n_iters\n",
    "        # One slot per evaluation, indexed by evaluation count so any\n",
    "        # display_freq fills the list densely (same layout as the old\n",
    "        # (jj+1)/10 indexing when display_freq == 10).\n",
    "        rel_error = [-10]*(1 + n_iters//self.display_freq)\n",
    "        prev_losses = None  # (loss1, loss2) of the previous iteration, used by r3\n",
    "\n",
    "        for jj in range(n_iters):\n",
    "            ic_loss = self.ICLoss()\n",
    "            bc_loss = self.BCLoss()\n",
    "            XTGrid = self._next_collocation_grid(jj, prev_losses)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss1, loss2 = self.PhysicsLoss(XTGrid) # loss2 holds the per-point residuals (used by r3)\n",
    "            # Only read next iteration under no_grad; detaching lets this graph be freed.\n",
    "            prev_losses = (loss1.detach(), loss2.detach())\n",
    "\n",
    "            total_loss = weights[0]*ic_loss + weights[1]*bc_loss + weights[2]*loss1\n",
    "            total_loss.backward()\n",
    "            loss_value = float(total_loss)\n",
    "\n",
    "            # Checkpoint BEFORE optimizer.step() so the saved weights are the\n",
    "            # ones that actually produced `loss_value`.\n",
    "            if loss_value < min_loss:\n",
    "                torch.save(self._nn.state_dict(), ckpt_path)\n",
    "                min_loss = loss_value\n",
    "\n",
    "            optimizer.step()\n",
    "            scheduler.step()\n",
    "            Training_Losses[jj] = loss_value\n",
    "\n",
    "            if (jj+1) % self.display_freq == 0:\n",
    "                with torch.no_grad():\n",
    "                    outputs = self.forward(self.X_star)\n",
    "                    outputs = outputs.reshape(-1,1)\n",
    "                    re = np.linalg.norm(self.exact_u.cpu()-outputs.cpu().detach()) / np.linalg.norm(self.exact_u.cpu().detach())\n",
    "                rel_error[(jj+1)//self.display_freq] = float(re*100)\n",
    "\n",
    "                print(\"Iteration Number = {}\".format(jj+1))\n",
    "                print(\"\\tIC Loss = {}\".format(float(ic_loss)))\n",
    "                print(\"\\tBC Loss = {}\".format(float(bc_loss)))\n",
    "                print(\"\\tPhysics Loss = {}\".format(float(loss1)))\n",
    "                print(\"\\tTraining Loss = {}\".format(loss_value))\n",
    "                print(\"\\tRelative L2 error (test) = {}\".format(float(re*100)))\n",
    "\n",
    "        return Training_Losses, rel_error\n",
    "\n",
    "    def _as_leaf_grid(self, grid):\n",
    "        \"\"\"Return a fresh float32 copy of `grid` on `self.device` as a grad-requiring leaf.\n",
    "\n",
    "        Equivalent to `torch.tensor(grid, ..., requires_grad=True)` but without\n",
    "        PyTorch's copy-construct UserWarning when `grid` is already a tensor.\n",
    "        \"\"\"\n",
    "        return torch.as_tensor(grid, dtype=torch.float32, device=self.device).detach().clone().requires_grad_(True)\n",
    "\n",
    "    def _next_collocation_grid(self, jj, prev_losses):\n",
    "        \"\"\"Return the collocation grid for iteration `jj` according to `self.method`.\n",
    "\n",
    "        `prev_losses` is the previous iteration's `(loss1, loss2)` (None when\n",
    "        jj == 0); only the 'r3' sampler uses it. Raises ValueError for an\n",
    "        unrecognised method.\n",
    "        \"\"\"\n",
    "        method = self.method\n",
    "        if method == 'r3':\n",
    "            if jj > 0:\n",
    "                with torch.no_grad():\n",
    "                    self.r3_sample.update(*prev_losses)\n",
    "            return self._as_leaf_grid(self.r3_sample.XTGrid)\n",
    "        if method == 'las':\n",
    "            # The old `cnt % 1 == 0` guard was always true: update on every jj > 0.\n",
    "            if jj > 0:\n",
    "                self.las.update(cal_domain_grad, self._nn)\n",
    "            grid = self._as_leaf_grid(self.las.XTGrid)\n",
    "            self.las.cnt += 1\n",
    "            return grid\n",
    "        if method == 'l_inf':\n",
    "            self.l_inf.update(cal_domain_grad, self._nn)\n",
    "            return self._as_leaf_grid(self.l_inf.XTGrid)\n",
    "        if method == 'rad':\n",
    "            self.rad.update(self._nn)\n",
    "            return self._as_leaf_grid(self.rad.XTGrid)\n",
    "        if method == 'fixed':\n",
    "            return self._as_leaf_grid(self.fixed_uniform)\n",
    "        if method == 'random-r':\n",
    "            # NOTE(review): bounds are hard-coded to [-1, 1] in space and [0, 1]\n",
    "            # in time (matching __main__) rather than read from LBs/UBs.\n",
    "            x_new = torch.zeros(self.Nf, len(self.LBs) - 1, dtype = torch.float32, device=self.device).uniform_(-1, 1)\n",
    "            t_new = torch.zeros(self.Nf, 1, dtype = torch.float32, device=self.device).uniform_(0, 1)\n",
    "            return torch.cat((x_new, t_new), dim=1).requires_grad_(True)\n",
    "        raise ValueError(f\"Unknown sampling method {method!r}; expected one of 'r3', 'las', 'l_inf', 'rad', 'fixed', 'random-r'\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Train the N-dimensional heat-equation PINN several times with the chosen\n",
    "    # sampling method and save per-run loss / relative-L2 histories to ../models/.\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument('--nodes', type=int, default = 128, help='The number of nodes per hidden layer in the neural network')\n",
    "    parser.add_argument('--layers', type=int, default = 8, help='The number of hidden layers in the neural network')\n",
    "    parser.add_argument('--N0', type=int, default = 450, help='The number of points to use on the initial condition')\n",
    "    parser.add_argument('--Nb', type=int, default = 450, help='The number of points to use on the boundary condition')\n",
    "    parser.add_argument('--Nf', type=int, default = 1000, help='The number of collocation points to use')\n",
    "    parser.add_argument('--epochs', type=int, default = 10000, help='The number of epochs to train the neural network')\n",
    "    parser.add_argument('--method', type=str, default='random-r', help='Sampling method') # fixed, random-r, rad, r3, l_inf, las \n",
    "    parser.add_argument('--model-name', type=str, default='PINN_model', help='File name to save the model')\n",
    "    parser.add_argument('--display-freq', type=int, default=10, help='How often to display loss information')\n",
    "    parser.add_argument('--dim', type=int, default=8, help='The number of spatial dimensions (time is appended as the last axis)')\n",
    "    parser.add_argument('--repeats', type=int, default=5, help='How many independent training runs to perform')\n",
    "    parser.add_argument('--seed', type=int, default=0, help='Base random seed; run i uses seed + i')\n",
    "    # Presumably absorbs the `-f <connection-file>` argument a Jupyter kernel receives.\n",
    "    parser.add_argument('-f')\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.\n",
    "    os.makedirs(\"../models/\", exist_ok=True)\n",
    "\n",
    "    # Kept module-level: code elsewhere in this cell may read the global `device`.\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    NHiddenLayers = args.layers\n",
    "    desired_dim = args.dim\n",
    "\n",
    "    # Spatial box [-1, 1]^desired_dim with the time interval [0, 1] as the last axis.\n",
    "    LBs = [-1]*desired_dim + [0]\n",
    "    UBs = [1]*desired_dim + [1]\n",
    "\n",
    "    # Tensor-product reference mesh: Nx nodes per spatial axis, Nt in time.\n",
    "    Nx = 5\n",
    "    Nt = 5\n",
    "    mesh_nodes = [torch.linspace(LBs[i], UBs[i], Nx).float() for i in range(desired_dim)]\n",
    "    T = torch.linspace(LBs[-1], UBs[-1], Nt).float()\n",
    "    mesh_nodes.append(T)\n",
    "\n",
    "    XYT = stacked_grid(*mesh_nodes)\n",
    "    Exact_U = exact_soln(XYT)\n",
    "    Activation = nn.Tanh()\n",
    "\n",
    "    Layers = [desired_dim+1] + [args.nodes]*NHiddenLayers + [1]\n",
    "\n",
    "    k = 1\n",
    "    c = 1\n",
    "\n",
    "    for i in range(args.repeats):\n",
    "        # Seed every run so results are reproducible yet differ between runs.\n",
    "        torch.manual_seed(args.seed + i)\n",
    "        np.random.seed(args.seed + i)\n",
    "\n",
    "        run_tag = args.method+'_'+str(args.layers)+'_'+str(args.Nf)+'_'+str(i)\n",
    "        model_prefix = \"../models/\" + args.model_name\n",
    "        pinn = PINN(  k = k,\n",
    "                      c = c,\n",
    "                      t= T,\n",
    "                      exact_XYT = XYT,\n",
    "                      exact_u = Exact_U,\n",
    "                      Layers = Layers,\n",
    "                      LBs = LBs,\n",
    "                      UBs = UBs,\n",
    "                      N0 = args.N0,\n",
    "                      Nb = args.Nb,\n",
    "                      Nf = args.Nf,\n",
    "                      Activation = Activation,\n",
    "                      device = device,\n",
    "                      model_name = model_prefix + \".model_\" + run_tag,\n",
    "                      display_freq = args.display_freq, samp = args.method )\n",
    "\n",
    "        Losses_train, Losses_rel_l2 = pinn.Train(args.epochs, weights = (1, 1, 1)) # initial, boundary, residual\n",
    "\n",
    "        torch.save(Losses_train, model_prefix + \".loss_\" + run_tag)\n",
    "        torch.save(Losses_rel_l2, model_prefix + \".rel_l2_\" + run_tag)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
