{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PINN Solution of the Gray-Scott PDEs\n",
    "\n",
    "This PyTorch code demonstrates the application of physics-informed neural networks (PINNs) to the solution of the well-known Gray-Scott system of PDEs with periodic boundary conditions\n",
    "\\begin{aligned}\n",
    "  &u_t = \\epsilon_1\\Delta u + b(1 - u) - uv^2, \\quad (t, x) \\in [0, T]\\times[-L, L]\\\\\n",
    "  &v_t = \\epsilon_2\\Delta v -dv + uv^2, \\\\\n",
    "  &u(0, x) = u_0(x), \\quad v(0, x) = v_0(x), \\quad \\forall x \\in [-L, L] \\\\\n",
    "  &u(t, -L) = u(t, L), \\quad v(t, -L) = v(t, L), \\quad \\forall t \\in [0, T]\n",
    "\\end{aligned}\n",
    "where $\\Delta$ denotes the 1D Laplacian $\\partial_{xx}$, $\\epsilon_1, \\epsilon_2, b, d > 0$ are parameters, and the spatial domain $[-L, L]$ covers one full period.\n",
    "\n",
    "See [this Chebfun example](https://www.chebfun.org/examples/pde/GrayScott.html) for a description of the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Libraries and Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "from collections import OrderedDict\n",
    "from itertools import chain\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.gridspec as gridspec\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import scipy.io\n",
    "import torch\n",
    "import torch.optim\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "from pyDOE import lhs\n",
    "from scipy.interpolate import griddata\n",
    "\n",
    "# set the random seed\n",
    "np.random.seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def figsize(scale, nplots = 1):\n",
    "#     fig_width_pt = 390.0                          # Get this from LaTeX using \\the\\textwidth\n",
    "#     inches_per_pt = 1.0/72.27                       # Convert pt to inch\n",
    "#     golden_mean = (np.sqrt(5.0)-1.0)/2.0            # Aesthetic ratio (you could change this)\n",
    "#     fig_width = fig_width_pt*inches_per_pt*scale    # width in inches\n",
    "#     fig_height = nplots*fig_width*golden_mean       # height in inches\n",
    "#     fig_size = [fig_width,fig_height]\n",
    "#     return fig_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional LaTeX (pgf) styling for publication-quality figures; uncomment it\n",
    "# together with the figsize helper in the previous cell\n",
    "# pgf_with_latex = {                      # setup matplotlib to use latex for output\n",
    "#     \"pgf.texsystem\": \"pdflatex\",        # change this if using xetex or luatex\n",
    "#     \"text.usetex\": True,                # use LaTeX to write all text\n",
    "#     \"font.family\": \"serif\",\n",
    "#     \"font.serif\": [],                   # blank entries should cause plots to inherit fonts from the document\n",
    "#     \"font.sans-serif\": [],\n",
    "#     \"font.monospace\": [],\n",
    "#     \"axes.labelsize\": 10,               # LaTeX default is 10pt font.\n",
    "#     \"font.size\": 10,\n",
    "#     \"legend.fontsize\": 8,               # Make the legend/label fonts a little smaller\n",
    "#     \"xtick.labelsize\": 8,\n",
    "#     \"ytick.labelsize\": 8,\n",
    "#     \"figure.figsize\": figsize(1.0),     # default fig size of 1.0 textwidth\n",
    "#     \"pgf.preamble\": r\"\\usepackage[utf8x]{inputenc} \\usepackage[T1]{fontenc}\"\n",
    "#     # use utf8 fonts because your computer can handle it :)\n",
    "#     # plots will be generated using this preamble\n",
    "#     }\n",
    "# mpl.rcParams.update(pgf_with_latex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working on mps\n"
     ]
    }
   ],
   "source": [
    "# MPS or CUDA or CPU\n",
    "if torch.backends.mps.is_available():\n",
    "    device = torch.device('mps')\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "#\n",
    "print(f\"Working on {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Problem definition ----\n",
    "# parameters, initial condition and periodic domain follow the spin('GS')\n",
    "# example from chebfun\n",
    "\n",
    "# diffusion coefficients and reaction rates\n",
    "epsilon1 = 1\n",
    "epsilon2 = 1e-2\n",
    "b = 2e-2\n",
    "d = 8.62e-2\n",
    "# spatial domain [-L, L] (one full period) and time window [tlo, thi]\n",
    "L = 50.0\n",
    "xlo, xhi = -L, L\n",
    "period = xhi - xlo\n",
    "tlo, thi = 0.0, 20.0\n",
    "# float32 copies of the constants on the compute device, for the torch ICs\n",
    "pi_ten = torch.tensor(np.pi).float().to(device)\n",
    "L_ten = torch.tensor(L).float().to(device)\n",
    "\n",
    "def _sin4_np(x):\n",
    "    # sin^4(pi (x - L) / (2 L)): vanishes at x = -L and x = L, period 2L in x\n",
    "    return np.power(np.sin(np.pi * (x - L)/(2.0 * L)), 4.0)\n",
    "\n",
    "def _sin4_ten(x):\n",
    "    # torch twin of _sin4_np, using the device-resident constants\n",
    "    return torch.pow(torch.sin(pi_ten * (x - L_ten)/(2.0 * L_ten)), 4.0)\n",
    "\n",
    "# initial conditions: numpy versions and torch versions (the torch ones are\n",
    "# used by the PINN ansatz)\n",
    "def u0(x):\n",
    "    return 1.0 - 0.5 * _sin4_np(x)\n",
    "\n",
    "def u0_ten(x):\n",
    "    return 1.0 - 0.5 * _sin4_ten(x)\n",
    "\n",
    "def v0(x):\n",
    "    return 0.25 * _sin4_np(x)\n",
    "\n",
    "def v0_ten(x):\n",
    "    return 0.25 * _sin4_ten(x)\n",
    "\n",
    "# TODO: prepare the reference (true) solution with chebfun"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Physics-informed Neural Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the deep neural network\n",
    "class DNN(torch.nn.Module):\n",
    "    def __init__(self, layers):\n",
    "        super(DNN, self).__init__()\n",
    "        # parameters\n",
    "        self.depth = len(layers) - 1\n",
    "        # set up layer order dict\n",
    "        self.activation = torch.nn.Tanh\n",
    "        layer_list = list()\n",
    "        for i in range(self.depth - 1): \n",
    "            layer_list.append(\n",
    "                ('layer_%d' % i, torch.nn.Linear(layers[i], layers[i+1]))\n",
    "            )\n",
    "            layer_list.append(('activation_%d' % i, self.activation()))\n",
    "            \n",
    "        layer_list.append(\n",
    "            ('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1]))\n",
    "        )\n",
    "        layerDict = OrderedDict(layer_list)\n",
    "        # deploy layers\n",
    "        self.layers = torch.nn.Sequential(layerDict)\n",
    "    def forward(self, x):\n",
    "        out = self.layers(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the physics-informed neural network\n",
    "# x enters the network only through a Fourier feature map, so the network\n",
    "# output is periodic in x by construction; the initial condition is imposed\n",
    "# exactly through the ansatz u = u0(x) + t * N_u(t, x) (likewise for v)\n",
    "class PhysicsInformedNN():\n",
    "    \"\"\"PINN for the 1D Gray-Scott system with built-in periodicity in x.\n",
    "\n",
    "    Args:\n",
    "        period: spatial period; Fourier frequencies are 2*pi*k/period, k = 1..m.\n",
    "        m: number of Fourier modes in the periodic input embedding.\n",
    "        X_PDE: collocation points as an (N, 2) array, column 0 = t, column 1 = x.\n",
    "        layers: layer widths for DNN. layers[0] is ignored and replaced by\n",
    "            2*m + 2 (m cosines, m sines, 1 constant, 1 t); the caller's list\n",
    "            is not modified.\n",
    "        epsilon1, epsilon2, b, d: Gray-Scott coefficients.\n",
    "    \"\"\"\n",
    "    def __init__(self, period, m, X_PDE, layers, epsilon1, epsilon2, b, d):\n",
    "        # Fourier frequencies 2*pi*k/period for k = 1..m, shape (1, m)\n",
    "        m_vec = np.expand_dims(np.arange(1, m + 1), axis = 0)\n",
    "        self.ms = self._to_tensor(2.0 * np.pi/period * m_vec)\n",
    "        # PDE data: derivatives are taken w.r.t. these points, so they are created\n",
    "        # as float32 leaf tensors on the device. Setting requires_grad before\n",
    "        # .float().to(device) would instead give non-leaf copies, and every\n",
    "        # backward pass would also accumulate gradients into hidden float64 leaves.\n",
    "        self.t_PDE = self._to_tensor(X_PDE[:, 0:1], requires_grad=True)\n",
    "        self.x_PDE = self._to_tensor(X_PDE[:, 1:2], requires_grad=True)\n",
    "        # input width for the periodic layer: m cosines, m sines, 1 constant, 1 t;\n",
    "        # work on a copy so the caller's list is left untouched\n",
    "        layers = list(layers)\n",
    "        layers[0] = int(2 * m + 2)\n",
    "        self.layers = layers\n",
    "        # equation related parameters\n",
    "        self.epsilon1 = epsilon1\n",
    "        self.epsilon2 = epsilon2\n",
    "        self.b = b\n",
    "        self.d = d\n",
    "        # deep neural network\n",
    "        self.dnn = DNN(layers).to(device)\n",
    "        # optimizers: Adam first, then L-BFGS to refine (see train)\n",
    "        self.optimizer_Adam = torch.optim.Adam(self.dnn.parameters(), lr = 1e-3)\n",
    "        self.optimizer_LBFGS = torch.optim.LBFGS(\n",
    "            self.dnn.parameters(),\n",
    "            lr=1.0,\n",
    "            max_iter=10000,\n",
    "            max_eval=5000,      # cap on loss evaluations per step; lower than max_iter\n",
    "            history_size=50,\n",
    "            tolerance_grad=1e-7,\n",
    "            tolerance_change=1.0 * np.finfo(float).eps,\n",
    "            line_search_fn=\"strong_wolfe\"\n",
    "        )\n",
    "        # number of L-BFGS closure evaluations so far (incremented in loss_func)\n",
    "        self.iter = 0\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_tensor(arr, requires_grad=False):\n",
    "        \"\"\"Convert a numpy array to a float32 leaf tensor on `device`.\n",
    "\n",
    "        The float32 cast is done before the device transfer, and requires_grad\n",
    "        is switched on last so that the returned tensor is a leaf.\n",
    "        \"\"\"\n",
    "        return torch.tensor(arr).float().to(device).requires_grad_(requires_grad)\n",
    "\n",
    "    # evaluate the neural network\n",
    "    def NN_eval(self, t, x):\n",
    "        \"\"\"Evaluate (u, v) at column tensors t, x of shape (N, 1).\n",
    "\n",
    "        The network sees x only through cos/sin(2*pi*k*x/period), and the\n",
    "        outputs are combined as u = u0(x) + t * N_u, v = v0(x) + t * N_v so\n",
    "        that u = u0 and v = v0 hold exactly at t = 0.\n",
    "        \"\"\"\n",
    "        x_trans = torch.matmul(x, self.ms)\n",
    "        NN = self.dnn(torch.cat([t, torch.ones_like(x), torch.cos(x_trans), torch.sin(x_trans)], dim = 1))\n",
    "        uNN = NN[:, 0][:, None]\n",
    "        vNN = NN[:, 1][:, None]\n",
    "        u0_torch = u0_ten(x)\n",
    "        v0_torch = v0_ten(x)\n",
    "        u = u0_torch + t * uNN\n",
    "        v = v0_torch + t * vNN\n",
    "        return u, v\n",
    "\n",
    "    # compute the PDE residuals\n",
    "    def pde_eval(self, t, x):\n",
    "        \"\"\"Residuals of both Gray-Scott equations at (t, x), via autograd.\n",
    "\n",
    "        Eq1 = u_t - epsilon1 * u_xx - b * (1 - u) + u * v^2\n",
    "        Eq2 = v_t - epsilon2 * v_xx + d * v - u * v^2\n",
    "        t and x must require grad; create_graph=True keeps the derivative graph\n",
    "        so the residuals can be back-propagated into the network weights.\n",
    "        \"\"\"\n",
    "        u, v = self.NN_eval(t, x)\n",
    "        # compute the derivatives for u\n",
    "        u_t  = torch.autograd.grad(u,   t, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]\n",
    "        u_x  = torch.autograd.grad(u,   x, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]\n",
    "        u_xx = torch.autograd.grad(u_x, x, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]\n",
    "        # compute the derivatives for v\n",
    "        v_t  = torch.autograd.grad(v,   t, grad_outputs = torch.ones_like(v), retain_graph = True, create_graph=True)[0]\n",
    "        v_x  = torch.autograd.grad(v,   x, grad_outputs = torch.ones_like(v), retain_graph = True, create_graph=True)[0]\n",
    "        v_xx = torch.autograd.grad(v_x, x, grad_outputs = torch.ones_like(v), retain_graph = True, create_graph=True)[0]\n",
    "        Eq1  = u_t - self.epsilon1 * u_xx - self.b * (1.0 - u) + u * torch.pow(v, 2.0)\n",
    "        Eq2  = v_t - self.epsilon2 * v_xx + self.d * v - u * torch.pow(v, 2.0)\n",
    "        return Eq1, Eq2\n",
    "\n",
    "    def _loss_terms(self):\n",
    "        \"\"\"Mean-squared residuals at the collocation points: (total, PDE(u), PDE(v)).\"\"\"\n",
    "        pde1_pred, pde2_pred = self.pde_eval(self.t_PDE, self.x_PDE)\n",
    "        loss_PDE1 = torch.mean(torch.square(pde1_pred))\n",
    "        loss_PDE2 = torch.mean(torch.square(pde2_pred))\n",
    "        # the total loss is an unweighted sum; weights could be introduced here\n",
    "        return loss_PDE1 + loss_PDE2, loss_PDE1, loss_PDE2\n",
    "\n",
    "    def loss_func(self):\n",
    "        \"\"\"L-BFGS closure: zero the gradients, evaluate and backpropagate the loss.\n",
    "\n",
    "        Returns the total loss. L-BFGS may call the closure several times per\n",
    "        iteration (line search), so self.iter counts evaluations.\n",
    "        \"\"\"\n",
    "        self.optimizer_LBFGS.zero_grad()\n",
    "        loss, loss_PDE1, loss_PDE2 = self._loss_terms()\n",
    "        loss.backward()\n",
    "        self.iter += 1\n",
    "        if self.iter % 1000 == 0:\n",
    "            print('Iter %5d, Total: %10.4e' % (self.iter, loss.item()))\n",
    "            print('PDE(u): %10.4e, PDE(v): %10.4e' % (loss_PDE1.item(), loss_PDE2.item()))\n",
    "        return loss\n",
    "\n",
    "    def train(self, nIter):\n",
    "        \"\"\"Run nIter Adam steps, then a single L-BFGS step (which iterates internally).\"\"\"\n",
    "        self.dnn.train()\n",
    "        for epoch in range(nIter):\n",
    "            loss, loss_PDE1, loss_PDE2 = self._loss_terms()\n",
    "            self.optimizer_Adam.zero_grad()\n",
    "            loss.backward()\n",
    "            self.optimizer_Adam.step()\n",
    "            if (epoch + 1) % 1000 == 0:\n",
    "                print('Iter %5d, Total: %10.4e' % (epoch + 1, loss.item()))\n",
    "                print('PDE(u): %10.4e, PDE(v): %10.4e' % (loss_PDE1.item(), loss_PDE2.item()))\n",
    "        self.optimizer_LBFGS.step(self.loss_func)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Evaluate u, v at points X (column 0 = t, column 1 = x) as (N, 1) numpy arrays.\"\"\"\n",
    "        t = self._to_tensor(X[:, 0:1])\n",
    "        x = self._to_tensor(X[:, 1:2])\n",
    "        self.dnn.eval()\n",
    "        # no derivatives are needed for inference, so skip building the autograd graph\n",
    "        with torch.no_grad():\n",
    "            u, v = self.NN_eval(t, x)\n",
    "        return u.cpu().numpy(), v.cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# collocation points for the PDE residual\n",
    "def get_points_for_training(tlo, thi, xlo, xhi, N_PDE):\n",
    "    \"\"\"Latin-hypercube sample of N_PDE points in [tlo, thi] x [xlo, xhi].\n",
    "\n",
    "    Returns an (N_PDE, 2) array whose columns are (t, x).\n",
    "    \"\"\"\n",
    "    unit_samples = lhs(2, N_PDE)  # LHS design on the unit square, one row per point\n",
    "    lower = np.array([tlo, xlo])\n",
    "    width = np.array([thi - tlo, xhi - xlo])\n",
    "    # affine map of each column onto its interval, broadcast across rows\n",
    "    return lower + width * unit_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter  1000, Total: 8.1161e-05\n",
      "PDE(u): 5.6932e-05, PDE(v): 2.4229e-05\n",
      "Iter  2000, Total: 3.2124e-05\n",
      "PDE(u): 1.5062e-05, PDE(v): 1.7063e-05\n"
     ]
    }
   ],
   "source": [
    "# network widths; the first entry is a placeholder that PhysicsInformedNN\n",
    "# replaces with the periodic-embedding width 2*m + 2\n",
    "layers = [2, 32, 32, 32, 32, 32, 32, 2]\n",
    "m = 10  # number of Fourier modes in the periodic embedding\n",
    "N_IC = 64  # unused: the initial condition is built into the network ansatz\n",
    "N_PDE = int(64 * 64)  # number of collocation points\n",
    "# reseed so that re-running this cell reproduces the same collocation points\n",
    "# (numpy's global RNG, presumably what pyDOE's lhs draws from) and the same\n",
    "# network initialization (torch RNG; DNN weights are created on the CPU\n",
    "# before being moved to the device)\n",
    "np.random.seed(1234)\n",
    "torch.manual_seed(1234)\n",
    "ptsPDE = get_points_for_training(tlo, thi, xlo, xhi, N_PDE)\n",
    "model = PhysicsInformedNN(period, m, ptsPDE, layers, epsilon1, epsilon2, b, d)\n",
    "# 50k Adam steps followed by one L-BFGS run (see PhysicsInformedNN.train)\n",
    "model.train(50000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate the PINN on a regular (t, x) grid for comparison with the reference solution\n",
    "t = np.linspace(tlo, thi, 201)\n",
    "x = np.linspace(xlo, xhi, 1001)\n",
    "T, X = np.meshgrid(t, x)  # both of shape (len(x), len(t))\n",
    "pts_flat = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))\n",
    "u_pred, v_pred = model.predict(pts_flat)\n",
    "# pts_flat lists the grid nodes in the row-major order of T, so a reshape maps\n",
    "# the predictions back onto the grid exactly; interpolating with griddata onto\n",
    "# the very same nodes only adds a costly Delaunay triangulation\n",
    "u_pred = u_pred.reshape(T.shape)\n",
    "v_pred = v_pred.reshape(T.shape)\n",
    "assert u_pred.shape == (x.size, t.size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reference solution exported from chebfun; set the GS_DATA_PATH environment\n",
    "# variable to use a copy of the file stored elsewhere\n",
    "data_path = os.environ.get('GS_DATA_PATH', '/Users/abraham/Desktop/PINN/New_Data/GS_1D_t20.mat')\n",
    "data = scipy.io.loadmat(data_path)\n",
    "# NOTE: this rebinds `t` (the PINN evaluation times above) to the reference time samples\n",
    "t = data['t'].flatten()[:,None]\n",
    "x2 = data['x'].flatten()[:,None]\n",
    "u_sol = np.real(data['u_sol']).T\n",
    "v_sol = np.real(data['v_sol']).T\n",
    "# the plots index u_sol[k, :] against x2, i.e. rows are time samples, columns are x\n",
    "assert u_sol.shape == v_sol.shape, (u_sol.shape, v_sol.shape)\n",
    "assert u_sol.shape[1] == x2.shape[0], (u_sol.shape, x2.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reference solution (on x2) vs PINN prediction (on x) at three snapshots;\n",
    "# columns 0/100/200 of u_pred are t = 0, 10, 20 on the 201-point PINN grid\n",
    "# NOTE(review): assumes rows 0/100/200 of u_sol are the same times -- confirm against the .mat file\n",
    "snapshot_idx = [0, 100, 200]\n",
    "t_ref = np.ravel(t)  # scalar times for the panel titles\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
    "for panel, (ax, k) in enumerate(zip(axes, snapshot_idx)):\n",
    "    # the exact solution lives on its own grid x2, not on the PINN grid x\n",
    "    ax.plot(x2, u_sol[k, :], 'bo-', linewidth = 2, label = 'Exact')\n",
    "    ax.plot(x, u_pred[:, k], 'rx--', linewidth = 2, label = 'u')\n",
    "    ax.set_xlabel('$x$')\n",
    "    if panel == 0:\n",
    "        ax.set_ylabel('$u(t,x)$')\n",
    "    ax.set_title('$t = %.1f$' % t_ref[k], fontsize = 24)\n",
    "    if panel == 1:\n",
    "        # a single shared legend, placed below the middle panel\n",
    "        ax.legend(\n",
    "            loc='upper center',\n",
    "            bbox_to_anchor=(0.5, -0.15),\n",
    "            ncol=5,\n",
    "            frameon=False,\n",
    "            prop={'size': 24}\n",
    "        )\n",
    "    ax.locator_params(axis = 'y', nbins = 5)\n",
    "    ax.locator_params(axis = 'x', nbins = 5)\n",
    "    for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +\n",
    "                 ax.get_xticklabels() + ax.get_yticklabels()):\n",
    "        item.set_fontsize(24)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
