{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PINN Solution of the Allen Cahn PDE\n",
    "\n",
    "This PyTorch code demonstrates the application of physically-informed neural networks (PINN) in the solution of a well-known Allen Cahn PDE with periodic boundary condition\n",
    "\\begin{aligned}\n",
    "  &u_t = \\epsilon\\Delta u + \\gamma(u - u^3), \\quad (t, x) \\in [0, T]\\times[-L, L]\\\\\n",
    "  &u(0, x) = u_0(x), \\quad \\forall x \\in [-L, L] \\\\\n",
    "  &u(t, -L) = u(t, L), \\quad \\forall t \\in [0, T]\n",
    "\\end{aligned}\n",
    "where $\\epsilon > 0$ is the defintion , and $[-L, L]$ covers one full period, i.e. $T = 2L$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Libraries and Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from the system\n",
    "from itertools import chain\n",
    "from collections import OrderedDict\n",
    "import time\n",
    "# additional computing packages\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import scipy.io\n",
    "from scipy.interpolate import griddata\n",
    "# for collocation points\n",
    "from pyDOE import lhs\n",
    "# for DNN training\n",
    "import torch\n",
    "import torch.optim\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "# for plotting\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import matplotlib.gridspec as gridspec\n",
    "# set the random seed\n",
    "np.random.seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working on mps\n"
     ]
    }
   ],
   "source": [
    "# MPS or CUDA or CPU\n",
    "if torch.backends.mps.is_available():\n",
    "    device = torch.device('mps')\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "#\n",
    "print(f\"Working on {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "epsilon = 1e-4\n",
    "gamma = 5.0\n",
    "L = 1.0\n",
    "xlo = -L\n",
    "xhi = L\n",
    "period = xhi - xlo\n",
    "tlo = 0.0\n",
    "thi = 1.0\n",
    "pi_ten = torch.tensor(np.pi).float().to(device)\n",
    "L_ten = torch.tensor(L).float().to(device)\n",
    "u0 = lambda x: np.power(x, 2.0) * np.cos(np.pi * x)\n",
    "u0_ten = lambda x: torch.pow(x, 2.0) * torch.cos(pi_ten * x)\n",
    "# please use chebfun to solve for the true solution (super easy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Physics-informed Neural Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the deep neural network\n",
    "class DNN(torch.nn.Module):\n",
    "    def __init__(self, layers):\n",
    "        super(DNN, self).__init__()\n",
    "        # parameters\n",
    "        self.depth = len(layers) - 1\n",
    "        # set up layer order dict\n",
    "        self.activation = torch.nn.Tanh\n",
    "        layer_list = list()\n",
    "        for i in range(self.depth - 1): \n",
    "            layer_list.append(\n",
    "                ('layer_%d' % i, torch.nn.Linear(layers[i], layers[i+1]))\n",
    "            )\n",
    "            layer_list.append(('activation_%d' % i, self.activation()))\n",
    "            \n",
    "        layer_list.append(\n",
    "            ('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1]))\n",
    "        )\n",
    "        layerDict = OrderedDict(layer_list)\n",
    "        # deploy layers\n",
    "        self.layers = torch.nn.Sequential(layerDict)\n",
    "    def forward(self, x):\n",
    "        out = self.layers(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the physics-guided neural network\n",
    "# we will use built-in transformation to make it \n",
    "# periodic by default\n",
    "class PhysicsInformedNN():\n",
    "    def __init__(self, period, epsilon, gamma, X_IC, u_IC, X_BC, X_PDE, layers):\n",
    "        # IC data point\n",
    "        self.t_IC = torch.tensor(X_IC[:, 0:1]).float().to(device)\n",
    "        self.x_IC = torch.tensor(X_IC[:, 1:2]).float().to(device)\n",
    "        self.u_IC = torch.tensor(u_IC).float().to(device)\n",
    "        # BC data point\n",
    "        self.t_BC = torch.tensor(X_BC[:, 0:1]).float().to(device)\n",
    "        self.x_BC = torch.tensor(X_BC[:, 1:2]).float().to(device)\n",
    "        self.period = torch.tensor(period).float().to(device)\n",
    "        # PDE data, gradients will be computed on these points so requires_grad = True\n",
    "        self.t_PDE = torch.tensor(X_PDE[:, 0:1], requires_grad=True).float().to(device)\n",
    "        self.x_PDE = torch.tensor(X_PDE[:, 1:2], requires_grad=True).float().to(device)\n",
    "        # equation related parameters\n",
    "        self.epsilon = torch.tensor(epsilon).float().to(device)\n",
    "        self.gamma = torch.tensor(gamma).float().to(device)\n",
    "        # layers to build Neural Net\n",
    "        self.layers = layers\n",
    "        # deep neural networks\n",
    "        self.dnn = DNN(layers).to(device)    \n",
    "        # prepare the optimizer\n",
    "        self.optimizer_Adam = torch.optim.Adam(self.dnn.parameters(), lr = 1e-3)\n",
    "        # add a learning rate scheduler\n",
    "        self.scheduler = lr_scheduler.ExponentialLR(self.optimizer_Adam, gamma=0.99)\n",
    "        self.optimizer_LBFGS = torch.optim.LBFGS(\n",
    "            self.dnn.parameters(), \n",
    "            lr=1.0, \n",
    "            max_iter=50000, \n",
    "            max_eval=50000, \n",
    "            history_size=50,\n",
    "            tolerance_grad=1e-8, \n",
    "            tolerance_change=1.0 * np.finfo(float).eps,\n",
    "            line_search_fn=\"strong_wolfe\"       # can be \"strong_wolfe\"\n",
    "        )        \n",
    "        self.iter = 0\n",
    "    # evaluater neural network, no transformation\n",
    "    def NN_eval(self, t, x):  \n",
    "        return self.dnn(torch.cat([t, x], dim = 1))\n",
    "    # compute the PDE\n",
    "    def pde_eval(self, t, x):\n",
    "        \"\"\" The pytorch autograd version of calculating residual \"\"\"\n",
    "        u = self.NN_eval(t, x)\n",
    "        # compute the derivatives for u\n",
    "        u_t  = torch.autograd.grad(u,   t, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]\n",
    "        u_x  = torch.autograd.grad(u,   x, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]\n",
    "        u_xx = torch.autograd.grad(u_x, x, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]\n",
    "        Eq  = u_t - self.epsilon * u_xx - self.gamma * (u - torch.pow(u, 3.0))       \n",
    "        return Eq\n",
    "    # compute the total loss for the second-order optimizer\n",
    "    def loss_func(self):\n",
    "        # reset the gradient\n",
    "        self.optimizer_LBFGS.zero_grad()\n",
    "        # compute IC loss\n",
    "        IC_pred = self.NN_eval(self.t_IC, self.x_IC)\n",
    "        loss_IC = torch.mean(torch.square(IC_pred - self.u_IC))\n",
    "        # compute PBC loss\n",
    "        BC_pred_left = self.NN_eval(self.t_BC, self.x_BC)\n",
    "        BC_pred_right = self.NN_eval(self.t_BC, self.x_BC + self.period)\n",
    "        loss_BC = torch.mean(torch.square(BC_pred_left - BC_pred_right))\n",
    "        # compute PDE loss\n",
    "        pde_pred = self.pde_eval(self.t_PDE, self.x_PDE)\n",
    "        loss_PDE = torch.mean(torch.square(pde_pred))    \n",
    "        # compute the total loss, it can be weighted\n",
    "        loss = loss_IC + loss_BC + loss_PDE\n",
    "        # backward propagation\n",
    "        loss.backward()\n",
    "        # increase the iteration counter\n",
    "        self.iter += 1\n",
    "        # output\n",
    "        # output the progress\n",
    "        if self.iter % 1000 == 0:\n",
    "            end_time = time.time()\n",
    "            print('Iter %5d, Total: %10.4e, Time: %.2f secs' % (self.iter, loss.item(), end_time - self.start_time))\n",
    "            print('IC: %10.4e, PBC: %10.4e, PDE: %10.4e' % (loss_IC.item(), loss_BC.item(), loss_PDE.item()))\n",
    "            self.start_time = end_time\n",
    "        return loss\n",
    "    #\n",
    "    def train(self, nIter):\n",
    "        # start the timer\n",
    "        start_time = time.time()        \n",
    "        # start the training with Adam first\n",
    "        self.dnn.train()\n",
    "        print('Starting with Adam')\n",
    "        for epoch in range(nIter):\n",
    "            # compute IC loss\n",
    "            IC_pred = self.NN_eval(self.t_IC, self.x_IC)\n",
    "            loss_IC = torch.mean(torch.square(IC_pred - self.u_IC))\n",
    "            # compute PBC loss\n",
    "            BC_pred_left = self.NN_eval(self.t_BC, self.x_BC)\n",
    "            BC_pred_right = self.NN_eval(self.t_BC, self.x_BC + self.period)\n",
    "            loss_BC = torch.mean(torch.square(BC_pred_left - BC_pred_right))            \n",
    "            # compute PDE loss\n",
    "            pde_pred = self.pde_eval(self.t_PDE, self.x_PDE)\n",
    "            loss_PDE = torch.mean(torch.square(pde_pred))  \n",
    "            # compute the total loss, it can be weighted\n",
    "            loss = loss_IC + loss_BC + loss_PDE\n",
    "            # Backward and optimize\n",
    "            self.optimizer_Adam.zero_grad()\n",
    "            loss.backward()\n",
    "            self.optimizer_Adam.step() \n",
    "            # output the progress\n",
    "            if (epoch + 1) % 1000 == 0:\n",
    "                end_time = time.time()\n",
    "                print('Iter %5d, Total: %10.4e, Time: %.2f secs' % (epoch + 1, loss.item(), end_time - start_time))\n",
    "                print('IC: %10.4e, PBC: %10.4e, PDE: %10.4e' % (loss_IC.item(), loss_BC.item(), loss_PDE.item()))\n",
    "                start_time = end_time\n",
    "                # change the learning rate\n",
    "                self.scheduler.step()\n",
    "        # Using the second-order L-BFGS optimizer\n",
    "        print('Starting with L-BFGS')\n",
    "        self.start_time = time.time()\n",
    "        self.optimizer_LBFGS.step(self.loss_func)                \n",
    "    #        \n",
    "    def predict(self, X):\n",
    "        t = torch.tensor(X[:, 0:1], requires_grad=True).float().to(device)\n",
    "        x = torch.tensor(X[:, 1:2], requires_grad=True).float().to(device)\n",
    "        self.dnn.eval()\n",
    "        u = self.NN_eval(t, x)\n",
    "        u = u.detach().cpu().numpy()\n",
    "        return u\n",
    "\n",
    "## Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# construct the points\n",
    "def get_points_for_training(tlo, thi, xlo, xhi, N_IC, N_BC, N_PDE):\n",
    "    # for IC, (t, x) \\in {tlo}x\\Omega\n",
    "    x_IC = np.expand_dims(np.linspace(xlo, xhi, N_IC), axis = 1)\n",
    "    t_IC = tlo * np.ones_like(x_IC)\n",
    "    ptsIC = np.hstack((t_IC, x_IC))\n",
    "    u_IC = u0(x_IC)\n",
    "    # for BC, (t, x) \\in {xlo}x(0, T]\n",
    "    t_BC = np.linspace(xlo, xhi, N_BC + 1)\n",
    "    t_BC = np.expand_dims(t_BC[1:], axis = 1)\n",
    "    x_BC = xlo * np.ones_like(t_BC)\n",
    "    ptsBC = np.hstack((t_BC, x_BC))\n",
    "    # for collocation pts, (t, x) \\in (0, T)x\\Omega\n",
    "    pts_rand = lhs(2, N_PDE)\n",
    "    t_PDE = tlo + (thi - tlo) * pts_rand[:, 0:1]\n",
    "    x_PDE = xlo + (xhi - xlo) * pts_rand[:, 1:2]\n",
    "    ptsPDE = np.hstack((t_PDE, x_PDE))\n",
    "    return ptsIC, u_IC, ptsBC, ptsPDE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting with Adam\n",
      "Iter  1000, Total: 1.4048e-01, Time: 52.39 secs\n",
      "IC: 1.2622e-01, PBC: 3.5302e-05, PDE: 1.4223e-02\n"
     ]
    }
   ],
   "source": [
    "layers = [2, 32, 32, 32, 32, 32, 32, 32, 1]\n",
    "N_IC = 64\n",
    "N_BC = 64\n",
    "N_PDE = int(N_IC * N_BC)\n",
    "ptsIC, u_IC, ptsBC, ptsPDE = get_points_for_training(tlo, thi, xlo, xhi, N_IC, N_BC, N_PDE)\n",
    "model = PhysicsInformedNN(period, epsilon, gamma, ptsIC, u_IC, ptsBC, ptsPDE, layers)\n",
    "model.train(10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# apply PINN to the same grid as the quadrature solution for comparison\n",
    "t = np.linspace(tlo, thi, 101)\n",
    "x = np.linspace(xlo, xhi, 101)\n",
    "T, X = np.meshgrid(t, x)\n",
    "pts_flat = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))\n",
    "u_pred = model.predict(pts_flat)\n",
    "# #\n",
    "# Exact = u_quad.T\n",
    "# Exact_vec = Exact.flatten()[:, None]\n",
    "# error_u = np.linalg.norm(Exact_vec-u_pred,2)/np.linalg.norm(Exact_vec,2)\n",
    "# print('Error u: %e' % (error_u))                     \n",
    "u_pred = griddata(pts_flat, u_pred.flatten(), (T, X), method='cubic')\n",
    "# Error = np.abs(Exact - U_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\" The aesthetic setting has changed. \"\"\"\n",
    "\n",
    "####### Row 0: u(t,x) ##################    \n",
    "\n",
    "fig = plt.figure(figsize=(11, 5))\n",
    "ax = fig.add_subplot(1, 1, 1)\n",
    "#\n",
    "ax.plot(ptsPDE[:, 0], ptsPDE[:, 1], \n",
    "    'rd', label = 'PDE Data (%d points)' % (ptsPDE.shape[0]), \n",
    "    markersize = 4,  # marker size doubled\n",
    "    clip_on = False,\n",
    "    alpha=1.0\n",
    ")\n",
    "#\n",
    "ax.set_xlabel('$t$', size=15)\n",
    "ax.set_ylabel('$x$', size=15)\n",
    "ax.legend(\n",
    "    loc='upper center', \n",
    "    bbox_to_anchor=(0.9, -0.05), \n",
    "    ncol=5, \n",
    "    frameon=False, \n",
    "    prop={'size': 15}\n",
    ")\n",
    "ax.legend()\n",
    "ax.set_title('Points', fontsize = 15) # font size doubled\n",
    "ax.tick_params(labelsize=12)\n",
    "#\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "####### Row 1: u(t,x) slices ################## \n",
    "\n",
    "\"\"\" The aesthetic setting has changed. \"\"\"\n",
    "\n",
    "fig = plt.figure(figsize=(14, 10))\n",
    "#\n",
    "ax = plt.subplot(1, 3, 1)\n",
    "#ax.plot(x, Exact[:, 25], 'bo-', linewidth = 2, label = 'Exact')       \n",
    "ax.plot(x, u_pred[:, 0], 'rx--', linewidth = 2, label = 'u')\n",
    "ax.set_xlabel('$x$')\n",
    "ax.set_ylabel('$u(t,x)$')    \n",
    "ax.set_title('$t = %.1f$' %(t[0]), fontsize = 15)\n",
    "#ax.axis('square')\n",
    "# ax.set_xlim([-0.1,1.1])\n",
    "# ax.set_ylim([-0.1,1.1]) \n",
    "plt.locator_params(axis = 'y', nbins = 5)\n",
    "plt.locator_params(axis = 'x', nbins = 5)\n",
    "for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +\n",
    "             ax.get_xticklabels() + ax.get_yticklabels()):\n",
    "    item.set_fontsize(15)\n",
    "\n",
    "ax = plt.subplot(1, 3, 2)\n",
    "#ax.plot(x,Exact[:,50], 'b-', linewidth = 2, label = 'Exact')       \n",
    "ax.plot(x, u_pred[:, 50], 'rx--', linewidth = 2, label = 'u')\n",
    "ax.set_xlabel('$x$')\n",
    "ax.set_ylabel('$u(t,x)$')\n",
    "#ax.axis('square')\n",
    "# ax.set_xlim([-L,L])\n",
    "# ax.set_ylim([-0.1,1.1]) \n",
    "ax.set_title('$t = %.1f$' %(t[50]), fontsize = 15)\n",
    "ax.legend(\n",
    "    loc='upper center', \n",
    "    bbox_to_anchor=(0.5, -0.15), \n",
    "    ncol=5, \n",
    "    frameon=False, \n",
    "    prop={'size': 15}\n",
    ")\n",
    "plt.locator_params(axis = 'y', nbins = 5)\n",
    "plt.locator_params(axis = 'x', nbins = 5)\n",
    "for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +\n",
    "             ax.get_xticklabels() + ax.get_yticklabels()):\n",
    "    item.set_fontsize(15)\n",
    "\n",
    "ax = plt.subplot(1, 3, 3)\n",
    "#ax.plot(x,Exact[:,75], 'b-', linewidth = 2, label = 'Exact')       \n",
    "ax.plot(x, u_pred[:, 100], 'rx--', linewidth = 2, label = 'u')\n",
    "ax.set_xlabel('$x$')\n",
    "ax.set_ylabel('$u(t,x)$')\n",
    "#ax.axis('square')\n",
    "# ax.set_xlim([-0.1,1.1])\n",
    "# ax.set_ylim([-0.1,1.1])    \n",
    "ax.set_title('$t = %.1f$' %(t[100]), fontsize = 15)\n",
    "plt.locator_params(axis = 'y', nbins = 5)\n",
    "plt.locator_params(axis = 'x', nbins = 5)\n",
    "for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +\n",
    "             ax.get_xticklabels() + ax.get_yticklabels()):\n",
    "    item.set_fontsize(15)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
