{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PINN Solution of the Allen-Cahn PDE\n",
    "\n",
    "This PyTorch notebook demonstrates how physics-informed neural networks (PINNs) can solve the well-known Allen-Cahn PDE with a periodic boundary condition\n",
    "$$\n",
    "\\begin{aligned}\n",
    "  &u_t = \\epsilon\\Delta u + \\gamma(u - u^3), \\quad (t, x) \\in [0, T]\\times[-L, L],\\\\\n",
    "  &u(0, x) = u_0(x), \\quad \\forall x \\in [-L, L], \\\\\n",
    "  &u(t, -L) = u(t, L), \\quad \\forall t \\in [0, T],\n",
    "\\end{aligned}\n",
    "$$\n",
    "where $\\epsilon > 0$ is the diffusion coefficient (in one dimension $\\Delta u = u_{xx}$), $\\gamma > 0$ is the reaction rate, $T$ is the final time, and the interval $[-L, L]$ covers exactly one spatial period of length $2L$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Libraries and Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from the system\n",
    "from itertools import chain\n",
    "from collections import OrderedDict\n",
    "import time\n",
    "# additional computing packages\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import scipy.io\n",
    "from scipy.interpolate import griddata\n",
    "# for collocation points\n",
    "from pyDOE import lhs\n",
    "# for DNN training\n",
    "import torch\n",
    "import torch.optim\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "# for plotting\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import matplotlib.gridspec as gridspec\n",
    "# seed both RNGs: pyDOE's lhs (collocation points) draws from numpy's global RNG,\n",
    "# while torch.nn.Linear weight initialisation draws from torch's RNG\n",
    "np.random.seed(1234)\n",
    "torch.manual_seed(1234);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def figsize(scale, nplots = 1):\n",
    "    # Return [width, height] in inches for a LaTeX document whose \\textwidth is 390 pt.\n",
    "    # width = scale * textwidth; height = nplots * width * (sqrt(5) - 1) / 2.\n",
    "    text_width_in = 390.0 * (1.0 / 72.27)          # \\the\\textwidth in pt -> inches (72.27 pt per inch)\n",
    "    aspect = (np.sqrt(5.0) - 1.0) / 2.0            # golden-ratio aspect, about 0.618\n",
    "    width = text_width_in * scale\n",
    "    return [width, nplots * width * aspect]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "# Matplotlib settings that typeset figure text with LaTeX so figures match the document.\n",
    "# text.usetex needs a LaTeX installation (`latex`, plus `dvipng` for the inline/Agg backend);\n",
    "# without one every later plot raises, so fall back to matplotlib's mathtext instead.\n",
    "USE_TEX = all(shutil.which(exe) is not None for exe in (\"latex\", \"dvipng\"))\n",
    "pgf_with_latex = {                      # setup matplotlib to use latex for output\n",
    "    \"pgf.texsystem\": \"pdflatex\",        # change this if using xelatex or lualatex\n",
    "    \"text.usetex\": USE_TEX,             # use LaTeX to write all text when it is available\n",
    "    \"font.family\": \"serif\",\n",
    "    \"font.serif\": [],                   # blank entries should cause plots to inherit fonts from the document\n",
    "    \"font.sans-serif\": [],\n",
    "    \"font.monospace\": [],\n",
    "    \"axes.labelsize\": 10,               # LaTeX default is 10pt font.\n",
    "    \"font.size\": 10,\n",
    "    \"legend.fontsize\": 8,               # Make the legend/label fonts a little smaller\n",
    "    \"xtick.labelsize\": 8,\n",
    "    \"ytick.labelsize\": 8,\n",
    "    \"figure.figsize\": figsize(1.0),     # default fig size: the full text width\n",
    "    # preamble used by the pgf backend when exporting figures\n",
    "    \"pgf.preamble\": r\"\\usepackage[utf8x]{inputenc} \\usepackage[T1]{fontenc}\"\n",
    "    }\n",
    "mpl.rcParams.update(pgf_with_latex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working on mps\n"
     ]
    }
   ],
   "source": [
    "# Pick the compute backend: Apple-silicon MPS first, then CUDA, otherwise the CPU.\n",
    "device = torch.device(\n",
    "    'mps' if torch.backends.mps.is_available()\n",
    "    else 'cuda' if torch.cuda.is_available()\n",
    "    else 'cpu'\n",
    ")\n",
    "print(f\"Working on {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Problem definition ----\n",
    "# Allen-Cahn: u_t = epsilon * u_xx + gamma * (u - u^3) on [tlo, thi] x [xlo, xhi], periodic in x.\n",
    "# NOTE(review): these parameters appear to match the Allen-Cahn benchmark of Raissi et al. (2019);\n",
    "# chebfun's spin('GS') is presumably the Gray-Scott demo (spin('AC') is Allen-Cahn) -- confirm the source.\n",
    "epsilon = 1e-4                 # diffusion coefficient\n",
    "gamma = 5.0                    # reaction rate\n",
    "L = 1.0                        # half-width of the spatial domain\n",
    "xlo, xhi = -L, L\n",
    "period = xhi - xlo             # one full spatial period\n",
    "tlo, thi = 0.0, 1.0            # time window\n",
    "# float32 scalar constants on the compute device\n",
    "pi_ten = torch.tensor(np.pi, dtype=torch.float32, device=device)\n",
    "L_ten = torch.tensor(L, dtype=torch.float32, device=device)\n",
    "\n",
    "# initial condition u0(x) = x^2 cos(pi x): numpy version for the training data ...\n",
    "def u0(x):\n",
    "    return np.power(x, 2.0) * np.cos(np.pi * x)\n",
    "\n",
    "# ... and torch version for tensors\n",
    "def u0_ten(x):\n",
    "    return torch.pow(x, 2.0) * torch.cos(pi_ten * x)\n",
    "\n",
    "# TODO: no reference solution is loaded yet; solving the same problem with chebfun would provide one."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Physics-informed Neural Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the deep neural network\n",
    "class DNN(torch.nn.Module):\n",
    "    # Fully connected network: (Linear -> Tanh) blocks followed by a final Linear layer.\n",
    "    # `layers` lists the layer widths, e.g. [2, 32, ..., 32, 1] for (t, x) -> u.\n",
    "    def __init__(self, layers):\n",
    "        super().__init__()\n",
    "        # number of Linear layers\n",
    "        self.depth = len(layers) - 1\n",
    "        self.activation = torch.nn.Tanh\n",
    "        # hidden blocks: one (width_in, width_out) pair per Linear layer except the last\n",
    "        blocks = []\n",
    "        for idx, (n_in, n_out) in enumerate(zip(layers[:-2], layers[1:-1])):\n",
    "            blocks.append(('layer_%d' % idx, torch.nn.Linear(n_in, n_out)))\n",
    "            blocks.append(('activation_%d' % idx, self.activation()))\n",
    "        # linear output layer, no activation\n",
    "        blocks.append(('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1])))\n",
    "        # names stay layer_i / activation_i so state_dict keys are unchanged\n",
    "        self.layers = torch.nn.Sequential(OrderedDict(blocks))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.layers(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the physics-informed neural network\n",
    "# NOTE: periodicity is not built into the architecture; the periodic boundary\n",
    "# condition is imposed softly by a loss term matching u(t, x_BC) with\n",
    "# u(t, x_BC + period). Every IC/BC/PDE point also carries a trainable\n",
    "# self-adaptive weight that is updated by gradient ascent during the Adam phase.\n",
    "class PhysicsInformedNN():\n",
    "    def __init__(self, period, epsilon, gamma, X_IC, u_IC, X_BC, X_PDE, layers):\n",
    "        # X_IC, X_BC and X_PDE have columns (t, x); u_IC holds u0 at the IC points\n",
    "        # IC data points\n",
    "        self.t_IC = self._to_tensor(X_IC[:, 0:1])\n",
    "        self.x_IC = self._to_tensor(X_IC[:, 1:2])\n",
    "        self.u_IC = self._to_tensor(u_IC)\n",
    "        self.LW_IC = self._adaptive_weights(X_IC.shape[0])\n",
    "        # BC data points; their periodic partners are x_BC + period\n",
    "        self.t_BC = self._to_tensor(X_BC[:, 0:1])\n",
    "        self.x_BC = self._to_tensor(X_BC[:, 1:2])\n",
    "        self.period = self._to_tensor(period)\n",
    "        self.LW_BC = self._adaptive_weights(X_BC.shape[0])\n",
    "        # PDE collocation points; autograd differentiates w.r.t. these, so they are\n",
    "        # created directly on the device as float32 leaf tensors with requires_grad=True.\n",
    "        # (Calling .float().to(device) on a requires_grad float64 CPU tensor yields a\n",
    "        # non-leaf, and every backward pass then also pushes gradients to that CPU copy.)\n",
    "        self.t_PDE = self._to_tensor(X_PDE[:, 0:1], requires_grad=True)\n",
    "        self.x_PDE = self._to_tensor(X_PDE[:, 1:2], requires_grad=True)\n",
    "        self.LW_PDE = self._adaptive_weights(X_PDE.shape[0])\n",
    "        # equation related parameters\n",
    "        self.epsilon = self._to_tensor(epsilon)\n",
    "        self.gamma = self._to_tensor(gamma)\n",
    "        # layers to build the neural net\n",
    "        self.layers = layers\n",
    "        # deep neural network\n",
    "        self.dnn = DNN(layers).to(device)\n",
    "        # Adam for the network, separate Adam optimizers for the adaptive weights\n",
    "        self.optimizer_Adam = torch.optim.Adam(self.dnn.parameters(), lr=1e-3)\n",
    "        self.optimizer_LW_IC = torch.optim.Adam(self.LW_IC.parameters(), lr=5e-3)\n",
    "        self.optimizer_LW_BC = torch.optim.Adam(self.LW_BC.parameters(), lr=5e-3)\n",
    "        self.optimizer_LW_PDE = torch.optim.Adam(self.LW_PDE.parameters(), lr=5e-3)\n",
    "        # exponential learning-rate decay for the network's Adam optimizer\n",
    "        self.scheduler = lr_scheduler.ExponentialLR(self.optimizer_Adam, gamma=0.999)\n",
    "        # L-BFGS refinement acts on the network parameters only\n",
    "        self.optimizer_LBFGS = torch.optim.LBFGS(\n",
    "            self.dnn.parameters(),\n",
    "            lr=1.0,\n",
    "            max_iter=50000,\n",
    "            max_eval=50000,\n",
    "            history_size=50,\n",
    "            tolerance_grad=1e-8,\n",
    "            tolerance_change=1.0 * np.finfo(float).eps,\n",
    "            line_search_fn='strong_wolfe'\n",
    "        )\n",
    "        self.iter = 0\n",
    "\n",
    "    # convert array-like / scalar data to a float32 tensor on the compute device\n",
    "    @staticmethod\n",
    "    def _to_tensor(data, requires_grad=False):\n",
    "        return torch.tensor(data, dtype=torch.float32, device=device, requires_grad=requires_grad)\n",
    "\n",
    "    # one trainable weight per point (initialised to 1), created directly on the device\n",
    "    @staticmethod\n",
    "    def _adaptive_weights(n_points):\n",
    "        return torch.nn.ParameterList([torch.nn.Parameter(torch.ones(n_points, 1, device=device))])\n",
    "\n",
    "    # evaluate the neural network, no transformation\n",
    "    def NN_eval(self, t, x):\n",
    "        return self.dnn(torch.cat([t, x], dim=1))\n",
    "\n",
    "    # PDE residual u_t - epsilon * u_xx - gamma * (u - u^3) computed with autograd\n",
    "    def pde_eval(self, t, x):\n",
    "        u = self.NN_eval(t, x)\n",
    "        ones = torch.ones_like(u)\n",
    "        u_t = torch.autograd.grad(u, t, grad_outputs=ones, retain_graph=True, create_graph=True)[0]\n",
    "        u_x = torch.autograd.grad(u, x, grad_outputs=ones, retain_graph=True, create_graph=True)[0]\n",
    "        u_xx = torch.autograd.grad(u_x, x, grad_outputs=ones, retain_graph=True, create_graph=True)[0]\n",
    "        return u_t - self.epsilon * u_xx - self.gamma * (u - torch.pow(u, 3.0))\n",
    "\n",
    "    # weighted IC, periodic-BC and PDE losses shared by the Adam and L-BFGS phases;\n",
    "    # returns (total, IC, BC, PDE)\n",
    "    def _losses(self):\n",
    "        IC_pred = self.NN_eval(self.t_IC, self.x_IC)\n",
    "        loss_IC = torch.mean(torch.square(self.LW_IC[0] * (IC_pred - self.u_IC)))\n",
    "        BC_pred_left = self.NN_eval(self.t_BC, self.x_BC)\n",
    "        BC_pred_right = self.NN_eval(self.t_BC, self.x_BC + self.period)\n",
    "        loss_BC = torch.mean(torch.square(self.LW_BC[0] * (BC_pred_left - BC_pred_right)))\n",
    "        pde_pred = self.pde_eval(self.t_PDE, self.x_PDE)\n",
    "        loss_PDE = torch.mean(torch.square(self.LW_PDE[0] * pde_pred))\n",
    "        return loss_IC + loss_BC + loss_PDE, loss_IC, loss_BC, loss_PDE\n",
    "\n",
    "    # closure for L-BFGS: recompute the total loss and its gradient\n",
    "    def loss_func(self):\n",
    "        self.optimizer_LBFGS.zero_grad()\n",
    "        loss, loss_IC, loss_BC, loss_PDE = self._losses()\n",
    "        loss.backward()\n",
    "        self.iter += 1\n",
    "        # output the progress\n",
    "        if self.iter % 1000 == 0:\n",
    "            end_time = time.time()\n",
    "            print('Iter %5d, Total: %10.4e, Time: %.2f secs' % (self.iter, loss.item(), end_time - self.start_time))\n",
    "            print('IC: %10.4e, PBC: %10.4e, PDE: %10.4e' % (loss_IC.item(), loss_BC.item(), loss_PDE.item()))\n",
    "            self.start_time = end_time\n",
    "        return loss\n",
    "\n",
    "    # nIter Adam epochs (network descent + adaptive-weight ascent), then L-BFGS\n",
    "    def train(self, nIter):\n",
    "        start_time = time.time()\n",
    "        self.dnn.train()\n",
    "        print('Starting with Adam')\n",
    "        weight_lists = (self.LW_IC, self.LW_BC, self.LW_PDE)\n",
    "        weight_optimizers = (self.optimizer_LW_IC, self.optimizer_LW_BC, self.optimizer_LW_PDE)\n",
    "        for epoch in range(nIter):\n",
    "            loss, loss_IC, loss_BC, loss_PDE = self._losses()\n",
    "            self.optimizer_Adam.zero_grad()\n",
    "            for opt in weight_optimizers:\n",
    "                opt.zero_grad()\n",
    "            loss.backward()\n",
    "            # gradient descent for the network ...\n",
    "            self.optimizer_Adam.step()\n",
    "            # ... and gradient ascent for the self-adaptive weights (flip the gradient sign)\n",
    "            for lw in weight_lists:\n",
    "                lw[0].grad.neg_()\n",
    "            for opt in weight_optimizers:\n",
    "                opt.step()\n",
    "            # output the progress\n",
    "            if (epoch + 1) % 1000 == 0:\n",
    "                end_time = time.time()\n",
    "                print('Iter %5d, Total: %10.4e, Time: %.2f secs' % (epoch + 1, loss.item(), end_time - start_time))\n",
    "                print('IC: %10.4e, PBC: %10.4e, PDE: %10.4e' % (loss_IC.item(), loss_BC.item(), loss_PDE.item()))\n",
    "                print('For IC,  min LW: %10.4e, max LW: %10.4e' %(torch.min(self.LW_IC[0]).item(), torch.max(self.LW_IC[0]).item()))\n",
    "                print('For BC,  min LW: %10.4e, max LW: %10.4e' %(torch.min(self.LW_BC[0]).item(), torch.max(self.LW_BC[0]).item()))\n",
    "                print('For PDE, min LW: %10.4e, max LW: %10.4e' %(torch.min(self.LW_PDE[0]).item(), torch.max(self.LW_PDE[0]).item()))\n",
    "                start_time = end_time\n",
    "                # NOTE(review): the scheduler only steps once per 1000 epochs, so with\n",
    "                # gamma=0.999 the Adam learning rate barely decays -- confirm this is intended\n",
    "                self.scheduler.step()\n",
    "        # Using the second-order L-BFGS optimizer (adaptive weights stay fixed)\n",
    "        print('Starting with L-BFGS')\n",
    "        self.start_time = time.time()\n",
    "        self.optimizer_LBFGS.step(self.loss_func)\n",
    "\n",
    "    # evaluate the trained network at the rows (t, x) of X; returns a numpy array\n",
    "    def predict(self, X):\n",
    "        t = self._to_tensor(X[:, 0:1])\n",
    "        x = self._to_tensor(X[:, 1:2])\n",
    "        self.dnn.eval()\n",
    "        # inference only: no autograd graph is needed\n",
    "        with torch.no_grad():\n",
    "            u = self.NN_eval(t, x)\n",
    "        return u.cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# construct the training points\n",
    "def get_points_for_training(tlo, thi, xlo, xhi, N_IC, N_BC, N_PDE):\n",
    "    # Returns (ptsIC, u_IC, ptsBC, ptsPDE); each pts* array has columns (t, x).\n",
    "    # IC points: (t, x) in {tlo} x [xlo, xhi], N_IC equispaced x values\n",
    "    x_IC = np.expand_dims(np.linspace(xlo, xhi, N_IC), axis = 1)\n",
    "    t_IC = tlo * np.ones_like(x_IC)\n",
    "    ptsIC = np.hstack((t_IC, x_IC))\n",
    "    u_IC = u0(x_IC)\n",
    "    # periodic-BC points: (t, x) in (tlo, thi] x {xlo}; the model pairs each with x + period.\n",
    "    # The time samples must come from the time window [tlo, thi], not from [xlo, xhi]\n",
    "    # (sampling the spatial interval here placed BC points at negative times).\n",
    "    t_BC = np.linspace(tlo, thi, N_BC + 1)\n",
    "    t_BC = np.expand_dims(t_BC[1:], axis = 1)\n",
    "    x_BC = xlo * np.ones_like(t_BC)\n",
    "    ptsBC = np.hstack((t_BC, x_BC))\n",
    "    # collocation points: Latin hypercube sample of [tlo, thi] x [xlo, xhi]\n",
    "    pts_rand = lhs(2, N_PDE)\n",
    "    t_PDE = tlo + (thi - tlo) * pts_rand[:, 0:1]\n",
    "    x_PDE = xlo + (xhi - xlo) * pts_rand[:, 1:2]\n",
    "    ptsPDE = np.hstack((t_PDE, x_PDE))\n",
    "    return ptsIC, u_IC, ptsBC, ptsPDE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting with Adam\n",
      "Iter  1000, Total: 3.3162e-01, Time: 18.38 secs\n",
      "IC: 3.0308e-01, PBC: 6.3171e-05, PDE: 2.8479e-02\n",
      "For IC,  min LW: 1.0151e+00, max LW: 1.5264e+00\n",
      "For BC,  min LW: 1.0187e+00, max LW: 1.4496e+00\n",
      "For PDE, min LW: 1.0104e+00, max LW: 1.6535e+00\n",
      "Iter  2000, Total: 6.3492e-01, Time: 18.32 secs\n",
      "IC: 5.7609e-01, PBC: 2.9124e-05, PDE: 5.8802e-02\n",
      "For IC,  min LW: 1.0601e+00, max LW: 2.1194e+00\n",
      "For BC,  min LW: 1.1797e+00, max LW: 1.7936e+00\n",
      "For PDE, min LW: 1.0308e+00, max LW: 2.3812e+00\n",
      "Iter  3000, Total: 1.0454e+00, Time: 18.62 secs\n",
      "IC: 9.4666e-01, PBC: 5.3909e-05, PDE: 9.8649e-02\n",
      "For IC,  min LW: 1.1245e+00, max LW: 2.7397e+00\n",
      "For BC,  min LW: 1.2594e+00, max LW: 2.2634e+00\n",
      "For PDE, min LW: 1.0522e+00, max LW: 3.0962e+00\n",
      "Iter  4000, Total: 1.5475e+00, Time: 18.57 secs\n",
      "IC: 1.3934e+00, PBC: 1.1227e-04, PDE: 1.5399e-01\n",
      "For IC,  min LW: 1.1496e+00, max LW: 3.3433e+00\n",
      "For BC,  min LW: 1.3785e+00, max LW: 2.7871e+00\n",
      "For PDE, min LW: 1.0794e+00, max LW: 3.8355e+00\n",
      "Iter  5000, Total: 2.0934e+00, Time: 18.38 secs\n",
      "IC: 1.8587e+00, PBC: 3.6479e-03, PDE: 2.3104e-01\n",
      "For IC,  min LW: 1.2829e+00, max LW: 3.9188e+00\n",
      "For BC,  min LW: 1.8025e+00, max LW: 3.2336e+00\n",
      "For PDE, min LW: 1.1383e+00, max LW: 4.7438e+00\n",
      "Iter  6000, Total: 2.7933e+00, Time: 18.67 secs\n",
      "IC: 2.5115e+00, PBC: 4.1096e-04, PDE: 2.8138e-01\n",
      "For IC,  min LW: 1.4370e+00, max LW: 4.4938e+00\n",
      "For BC,  min LW: 2.1005e+00, max LW: 3.4028e+00\n",
      "For PDE, min LW: 1.2433e+00, max LW: 5.1472e+00\n",
      "Iter  7000, Total: 3.5455e+00, Time: 18.70 secs\n",
      "IC: 3.1871e+00, PBC: 6.8279e-04, PDE: 3.5774e-01\n",
      "For IC,  min LW: 1.7766e+00, max LW: 5.0467e+00\n",
      "For BC,  min LW: 2.3878e+00, max LW: 3.6234e+00\n",
      "For PDE, min LW: 1.4277e+00, max LW: 5.4944e+00\n",
      "Iter  8000, Total: 4.3529e+00, Time: 18.61 secs\n",
      "IC: 3.8968e+00, PBC: 8.9119e-04, PDE: 4.5519e-01\n",
      "For IC,  min LW: 2.0070e+00, max LW: 5.6042e+00\n",
      "For BC,  min LW: 2.5942e+00, max LW: 3.8345e+00\n",
      "For PDE, min LW: 1.5202e+00, max LW: 5.9471e+00\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[36], line 7\u001b[0m\n\u001b[1;32m      5\u001b[0m ptsIC, u_IC, ptsBC, ptsPDE \u001b[38;5;241m=\u001b[39m get_points_for_training(tlo, thi, xlo, xhi, N_IC, N_BC, N_PDE)\n\u001b[1;32m      6\u001b[0m model \u001b[38;5;241m=\u001b[39m PhysicsInformedNN(period, epsilon, gamma, ptsIC, u_IC, ptsBC, ptsPDE, layers)\n\u001b[0;32m----> 7\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m10000\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[34], line 115\u001b[0m, in \u001b[0;36mPhysicsInformedNN.train\u001b[0;34m(self, nIter)\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer_LW_PDE\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m    114\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_Adam\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m \n\u001b[1;32m    116\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLW_IC[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mgrad\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLW_IC[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mgrad\u001b[38;5;241m.\u001b[39mdata\n\u001b[1;32m    117\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLW_BC[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mgrad\u001b[38;5;241m.\u001b[39mdata \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLW_BC[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mgrad\u001b[38;5;241m.\u001b[39mdata\n",
      "File \u001b[0;32m~/venv-metal/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:68\u001b[0m, in \u001b[0;36mLRScheduler.__init__.<locals>.with_counter.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     66\u001b[0m instance\u001b[38;5;241m.\u001b[39m_step_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m     67\u001b[0m wrapped \u001b[38;5;241m=\u001b[39m func\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__get__\u001b[39m(instance, \u001b[38;5;28mcls\u001b[39m)\n\u001b[0;32m---> 68\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mwrapped\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/venv-metal/lib/python3.9/site-packages/torch/optim/optimizer.py:373\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    368\u001b[0m         \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    369\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    370\u001b[0m                 \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    371\u001b[0m             )\n\u001b[0;32m--> 373\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    374\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m    376\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
      "File \u001b[0;32m~/venv-metal/lib/python3.9/site-packages/torch/optim/optimizer.py:76\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     74\u001b[0m     torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m     75\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 76\u001b[0m     ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     77\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m     78\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
      "File \u001b[0;32m~/venv-metal/lib/python3.9/site-packages/torch/optim/adam.py:138\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m    130\u001b[0m \u001b[38;5;129m@_use_grad_for_differentiable\u001b[39m\n\u001b[1;32m    131\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mstep\u001b[39m(\u001b[38;5;28mself\u001b[39m, closure\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m    132\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"Performs a single optimization step.\u001b[39;00m\n\u001b[1;32m    133\u001b[0m \n\u001b[1;32m    134\u001b[0m \u001b[38;5;124;03m    Args:\u001b[39;00m\n\u001b[1;32m    135\u001b[0m \u001b[38;5;124;03m        closure (Callable, optional): A closure that reevaluates the model\u001b[39;00m\n\u001b[1;32m    136\u001b[0m \u001b[38;5;124;03m            and returns the loss.\u001b[39;00m\n\u001b[1;32m    137\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 138\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cuda_graph_capture_health_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    140\u001b[0m     loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m    141\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/venv-metal/lib/python3.9/site-packages/torch/optim/optimizer.py:309\u001b[0m, in \u001b[0;36mOptimizer._cuda_graph_capture_health_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    306\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m format_string\n\u001b[1;32m    308\u001b[0m \u001b[38;5;66;03m# Currently needed by Adam and AdamW\u001b[39;00m\n\u001b[0;32m--> 309\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_cuda_graph_capture_health_check\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    310\u001b[0m     \u001b[38;5;66;03m# Note [torch.compile x capturable]\u001b[39;00m\n\u001b[1;32m    311\u001b[0m     \u001b[38;5;66;03m# If we are compiling, we try to take the capturable path automatically by\u001b[39;00m\n\u001b[1;32m    312\u001b[0m     \u001b[38;5;66;03m# setting the flag to True during tracing. Due to this, we skip all the checks\u001b[39;00m\n\u001b[1;32m    313\u001b[0m     \u001b[38;5;66;03m# normally required for determining whether we can use CUDA graphs and\u001b[39;00m\n\u001b[1;32m    314\u001b[0m     \u001b[38;5;66;03m# shunt the responsibility to torch.inductor. 
This saves time during tracing\u001b[39;00m\n\u001b[1;32m    315\u001b[0m     \u001b[38;5;66;03m# since the checks are slow without sacrificing UX since inductor will warn\u001b[39;00m\n\u001b[1;32m    316\u001b[0m     \u001b[38;5;66;03m# later if CUDA graphs cannot be enabled, e.g.,\u001b[39;00m\n\u001b[1;32m    317\u001b[0m     \u001b[38;5;66;03m# https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.\u001b[39;00m\n\u001b[1;32m    318\u001b[0m     \u001b[38;5;66;03m# Thus, when compiling, inductor will determine if cudagraphs\u001b[39;00m\n\u001b[1;32m    319\u001b[0m     \u001b[38;5;66;03m# can be enabled based on whether there is input mutation or CPU tensors.\u001b[39;00m\n\u001b[1;32m    320\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_compiling() \u001b[38;5;129;01mand\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mbackends\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_built() \u001b[38;5;129;01mand\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_available():\n\u001b[1;32m    321\u001b[0m         capturing \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mis_current_stream_capturing()\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# network widths: input (t, x), seven hidden layers of 32 units, output u\n",
    "layers = [2] + [32] * 7 + [1]\n",
    "# number of IC points, BC points and interior collocation points\n",
    "N_IC = 64\n",
    "N_BC = 64\n",
    "N_PDE = N_IC * N_BC\n",
    "ptsIC, u_IC, ptsBC, ptsPDE = get_points_for_training(tlo, thi, xlo, xhi, N_IC, N_BC, N_PDE)\n",
    "model = PhysicsInformedNN(period, epsilon, gamma, ptsIC, u_IC, ptsBC, ptsPDE, layers)\n",
    "# 10000 Adam epochs, then L-BFGS until its own stopping criteria are met\n",
    "model.train(10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the PINN on a regular 101 x 101 (t, x) grid for plotting / comparison\n",
    "t = np.linspace(tlo, thi, 101)\n",
    "x = np.linspace(xlo, xhi, 101)\n",
    "# T[i, j] = t[j] and X[i, j] = x[i]: rows index x, columns index t\n",
    "T, X = np.meshgrid(t, x)\n",
    "pts_flat = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))\n",
    "u_pred = model.predict(pts_flat)\n",
    "# The prediction points ARE the grid nodes, so reshaping is exact; interpolating them back\n",
    "# onto the same nodes with griddata only costs a triangulation and may yield NaN on hull edges.\n",
    "u_pred = u_pred.reshape(T.shape)\n",
    "# TODO: load a reference solution (e.g. from chebfun) as `Exact` with the same (x, t) layout,\n",
    "# then report the relative L2 error:\n",
    "# Exact_vec = Exact.flatten()[:, None]\n",
    "# error_u = np.linalg.norm(Exact_vec - u_pred.flatten()[:, None], 2) / np.linalg.norm(Exact_vec, 2)\n",
    "# print('Error u: %e' % (error_u))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter plot of the training points in the (t, x) plane.\n",
    "fig, ax = plt.subplots(figsize=(11, 5))\n",
    "ax.plot(ptsPDE[:, 0], ptsPDE[:, 1],\n",
    "    'rd', label = 'PDE Data (%d points)' % (ptsPDE.shape[0]),\n",
    "    markersize = 4,\n",
    "    clip_on = False,\n",
    "    alpha=1.0\n",
    ")\n",
    "ax.plot(ptsIC[:, 0], ptsIC[:, 1],\n",
    "    'bx', label = 'IC Data (%d points)' % (ptsIC.shape[0]),\n",
    "    markersize = 6,\n",
    "    clip_on = False\n",
    ")\n",
    "ax.plot(ptsBC[:, 0], ptsBC[:, 1],\n",
    "    'gs', label = 'BC Data (%d points)' % (ptsBC.shape[0]),\n",
    "    markersize = 4,\n",
    "    clip_on = False\n",
    ")\n",
    "ax.set_xlabel('$t$', size=15)\n",
    "ax.set_ylabel('$x$', size=15)\n",
    "# Only one legend call: a second bare ax.legend() would replace this legend and\n",
    "# discard its placement and font settings.\n",
    "ax.legend(\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.5, -0.15),\n",
    "    ncol=5,\n",
    "    frameon=False,\n",
    "    prop={'size': 15}\n",
    ")\n",
    "ax.set_title('Points', fontsize = 15)\n",
    "ax.tick_params(labelsize=12)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Slices of the PINN prediction u(t[k], x) at three times.\n",
    "# u_pred has rows indexed by x and columns indexed by t, so u_pred[:, k] is u(t[k], x).\n",
    "slice_indices = (0, 50, 100)\n",
    "fig, axes = plt.subplots(1, 3, figsize=(14, 10))\n",
    "for panel, (ax, k) in enumerate(zip(axes, slice_indices)):\n",
    "    # TODO: overlay the reference solution once available, e.g. ax.plot(x, Exact[:, k], 'b-', label='Exact')\n",
    "    ax.plot(x, u_pred[:, k], 'rx--', linewidth = 2, label = 'u')\n",
    "    ax.set_xlabel('$x$', fontsize = 15)\n",
    "    ax.set_ylabel('$u(t,x)$', fontsize = 15)\n",
    "    ax.set_title('$t = %.1f$' % (t[k]), fontsize = 15)\n",
    "    ax.locator_params(axis = 'y', nbins = 5)\n",
    "    ax.locator_params(axis = 'x', nbins = 5)\n",
    "    # tick_params keeps the label size for ticks that are (re)created later\n",
    "    ax.tick_params(labelsize = 15)\n",
    "    # a single legend, placed below the middle panel\n",
    "    if panel == 1:\n",
    "        ax.legend(\n",
    "            loc='upper center',\n",
    "            bbox_to_anchor=(0.5, -0.15),\n",
    "            ncol=5,\n",
    "            frameon=False,\n",
    "            prop={'size': 15}\n",
    "        )\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
