{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PINN Solution of the Cahn-Hilliard PDE\n",
    "\n",
    "This PyTorch notebook demonstrates how physics-informed neural networks (PINNs) can solve the well-known Cahn-Hilliard PDE with periodic boundary conditions\n",
    "\\begin{aligned}\n",
    "  &u_t = \\epsilon_1(-u_{xx}  - \\epsilon_2u_{xxxx} + (u^3)_{xx}), \\quad (t, x) \\in [0, T]\\times[-L, L]\\\\\n",
    "  &u(0, x) = u_0(x), \\quad \\forall x \\in [-L, L] \\\\\n",
    "  &u(t, -L) = u(t, L), \\quad \\forall t \\in [0, T]\n",
    "\\end{aligned}\n",
    "where $\\epsilon_1, \\epsilon_2 > 0$ are given constants and $[-L, L]$ covers one full spatial period, so the period is $2L$ ($T$ denotes the final time).\n",
    "\n",
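    "Because computing the fourth derivative $u_{xxxx}$ through repeated backpropagation is slow, we rewrite the equation as a second-order system. For reference, the nonlinear term expands as\n",
    "$$\n",
    "(u^3)_{xx} = (3u^2u_x)_x = 6uu_x^2 + 3u^2u_{xx}.\n",
    "$$\n",
    "Since every term on the right-hand side is a second derivative in $x$, the equation can be written as\n",
    "$$\n",
    "u_t = \\epsilon_1(-u_{xx}  - \\epsilon_2u_{xxxx} + (u^3)_{xx}) = \\epsilon_1(-u - \\epsilon_2u_{xx} + u^3)_{xx}\n",
    "$$\n",
    "and, introducing the auxiliary variable (chemical potential) $v = -u - \\epsilon_2u_{xx} + u^3$, it becomes\n",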
    "\\begin{aligned}\n",
    "u_t &= \\epsilon_1v_{xx}, \\\\\n",
    "v &= -(u - u^3) - \\epsilon_2u_{xx}\n",
    "\\end{aligned}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Libraries and Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import chain\n",
    "from collections import OrderedDict\n",
    "import time\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import scipy.io\n",
    "from scipy.interpolate import griddata\n",
    "from pyDOE import lhs\n",
    "import torch\n",
    "import torch.optim\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import matplotlib.gridspec as gridspec\n",
    "np.random.seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working on mps\n"
     ]
    }
   ],
   "source": [
    "# MPS or CUDA or CPU\n",
    "if torch.backends.mps.is_available():\n",
    "    device = torch.device('mps')\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "#\n",
    "print(f\"Working on {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "epsilon1 = 1e-2  # prefactor in u_t = epsilon1 * v_xx\n",
    "epsilon2 = 1e-4  # coefficient of u_xx in v = -(u - u^3) - epsilon2 * u_xx\n",
    "L = 1.0\n",
    "xlo = -L\n",
    "xhi = L\n",
    "period = xhi - xlo  # spatial period 2L\n",
    "tlo = 0.0\n",
    "thi = 1.0\n",
    "pi_ten = torch.tensor(np.pi).float().to(device)\n",
    "# initial condition u(0, x) = -cos(2 pi x), in NumPy and torch form\n",
    "u0 = lambda x: -np.cos(2.0 * np.pi * x)\n",
    "u0_ten = lambda x: -torch.cos(2.0 * pi_ten * x)\n",
    "# matching initial auxiliary variable v(0, x) = -u0 - epsilon2 * u0_xx + u0^3.\n",
    "# With c = cos(2 pi x): u0 = -c and u0_xx = 4 pi^2 c, so v0 = (1 - 4 pi^2 epsilon2) c - c^3.\n",
    "# (4 pi^2 c alone is u0_xx, which is not the v used in the PDE residual.)\n",
    "v0 = lambda x: (1.0 - 4.0 * np.pi**2 * epsilon2) * np.cos(2.0 * np.pi * x) - np.cos(2.0 * np.pi * x)**3\n",
    "v0_ten = lambda x: (1.0 - 4.0 * pi_ten**2 * epsilon2) * torch.cos(2.0 * pi_ten * x) - torch.cos(2.0 * pi_ten * x)**3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Physics-informed Neural Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the deep neural network\n",
    "class DNN(torch.nn.Module):\n",
    "    def __init__(self, layers):\n",
    "        super(DNN, self).__init__()\n",
    "        # parameters\n",
    "        self.depth = len(layers) - 1\n",
    "        # set up layer order dict\n",
    "        self.activation = torch.nn.Tanh\n",
    "        layer_list = list()\n",
    "        for i in range(self.depth - 1): \n",
    "            layer_list.append(\n",
    "                ('layer_%d' % i, torch.nn.Linear(layers[i], layers[i+1]))\n",
    "            )\n",
    "            layer_list.append(('activation_%d' % i, self.activation()))\n",
    "            \n",
    "        layer_list.append(\n",
    "            ('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1]))\n",
    "        )\n",
    "        layerDict = OrderedDict(layer_list)\n",
    "        # deploy layers\n",
    "        self.layers = torch.nn.Sequential(layerDict)\n",
    "        self.layers[0].weight = torch.load('initial_weights_PBC.pt')\n",
    "    def forward(self, x):\n",
    "        out = self.layers(x)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PhysicsInformedNN():\n",
    "    def __init__(self, period, m, X_PDE, layers, epsilon1, epsilon2):\n",
    "        # Prepare the periodic layer\n",
    "        m_vec = np.expand_dims(np.arange(1, m + 1), axis = 0)\n",
    "        self.ms = torch.tensor(2.0 * np.pi/period * m_vec).float().to(device)\n",
    "        # PDE data, gradients will be computed on these points so requires_grad = True\n",
    "        self.t_PDE = torch.tensor(X_PDE[:, 0:1], requires_grad=True).float().to(device)\n",
    "        self.x_PDE = torch.tensor(X_PDE[:, 1:2], requires_grad=True).float().to(device)\n",
    "        # layers to build Neural Net\n",
    "        # changing the input variables for periodic layer\n",
    "        # m cosines, m sines, and 1 costant, and 1 t\n",
    "        layers[0] = int(2 * m + 2)\n",
    "        self.layers = layers\n",
    "        # equation related parameters\n",
    "        self.epsilon1 = epsilon1\n",
    "        self.epsilon2 = epsilon2\n",
    "        \n",
    "        # deep neural networks\n",
    "        self.dnn = DNN(layers).to(device)    \n",
    "        # prepare the optimizer\n",
    "        self.optimizer_Adam = torch.optim.Adam(self.dnn.parameters(), lr = 1e-3)\n",
    "        self.scheduler = lr_scheduler.ExponentialLR(self.optimizer_Adam, gamma=0.99)\n",
    "        self.optimizer_LBFGS = torch.optim.LBFGS(\n",
    "            self.dnn.parameters(), \n",
    "            lr=1.0, \n",
    "            max_iter=10000, \n",
    "            max_eval=5000, \n",
    "            history_size=50,\n",
    "            tolerance_grad=1e-7, \n",
    "            tolerance_change=1.0 * np.finfo(float).eps,\n",
    "            line_search_fn=\"strong_wolfe\"       # can be \"strong_wolfe\"\n",
    "        )        \n",
    "        self.iter = 0\n",
    "    # evaluater neural network\n",
    "    def NN_eval(self, t, x):  \n",
    "        x_trans = torch.matmul(x, self.ms)\n",
    "        NN = self.dnn(torch.cat([t, torch.ones_like(x), torch.cos(x_trans), torch.sin(x_trans)], dim = 1))\n",
    "        uNN = NN[:, 0][:, None]\n",
    "        vNN = NN[:, 1][:, None]\n",
    "        u0_torch = u0_ten(x)\n",
    "        \n",
    "        u = u0_torch *torch.exp(-0.1*t)+ t * uNN\n",
    "        \n",
    "        return u, vNN\n",
    "    # compute the PDE\n",
    "    def pde_eval(self, t, x):\n",
    "        \"\"\" The pytorch autograd version of calculating residual \"\"\"\n",
    "        u, v = self.NN_eval(t, x)\n",
    "        # compute the derivatives for u\n",
    "        u_t  = torch.autograd.grad(u,   t, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]\n",
    "        u_x  = torch.autograd.grad(u,   x, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]\n",
    "        u_xx = torch.autograd.grad(u_x, x, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]\n",
    "        v_x  = torch.autograd.grad(v,   x, grad_outputs = torch.ones_like(v), retain_graph = True, create_graph=True)[0]\n",
    "        v_xx = torch.autograd.grad(v_x, x, grad_outputs = torch.ones_like(v), retain_graph = True, create_graph=True)[0]\n",
    "        Eq1  = u_t -self.epsilon1 * v_xx\n",
    "        Eq2  = v +  (u + self.epsilon2 * u_xx - torch.pow(u, 3.0))\n",
    "        return Eq1, Eq2\n",
    "    # compute the total loss for the second-order optimizer\n",
    "    def loss_func(self):\n",
    "        # reset the gradient\n",
    "        self.optimizer_LBFGS.zero_grad()\n",
    "        # compute PDE loss\n",
    "        pde1_pred, pde2_pred = self.pde_eval(self.t_PDE, self.x_PDE)\n",
    "        loss_PDE1 = torch.mean(torch.square(pde1_pred))\n",
    "        loss_PDE2 = torch.mean(torch.square(pde2_pred))\n",
    "        # compute the total loss, it can be weighted\n",
    "        loss = loss_PDE1 + loss_PDE2\n",
    "        # backward propagation\n",
    "        loss.backward()\n",
    "        # increase the iteration counter\n",
    "        self.iter += 1\n",
    "        # output\n",
    "        # output the progress\n",
    "        if self.iter % 1000 == 0:\n",
    "            print('Iter %5d, Total: %10.4e' % (self.iter, loss.item()))\n",
    "            print('PDE(u): %10.4e, PDE(v): %10.4e' % (loss_PDE1.item(), loss_PDE2.item()))         \n",
    "        return loss\n",
    "    #\n",
    "    def train(self, nIter):\n",
    "        # start the training with Adam first\n",
    "        self.dnn.train()\n",
    "        for epoch in range(nIter):\n",
    "            # compute PDE loss\n",
    "            pde1_pred, pde2_pred = self.pde_eval(self.t_PDE, self.x_PDE)\n",
    "            loss_PDE1 = torch.mean(torch.square(pde1_pred))\n",
    "            loss_PDE2 = torch.mean(torch.square(pde2_pred))\n",
    "            # compute the total loss, it can be weighted\n",
    "            loss = loss_PDE1 + loss_PDE2\n",
    "            # Backward and optimize\n",
    "            self.optimizer_Adam.zero_grad()\n",
    "            loss.backward()\n",
    "            self.optimizer_Adam.step() \n",
    "            # output the progress\n",
    "            if (epoch + 1) % 1000 == 0:\n",
    "                print('Iter %5d, Total: %10.4e' % (epoch + 1, loss.item()))\n",
    "                print('PDE(u): %10.4e, PDE(v): %10.4e' % (loss_PDE1.item(), loss_PDE2.item()))\n",
    "        # Backward and optimize\n",
    "        self.optimizer_LBFGS.step(self.loss_func)                \n",
    "    #        \n",
    "    def predict(self, X):\n",
    "        t = torch.tensor(X[:, 0:1], requires_grad=True).float().to(device)\n",
    "        x = torch.tensor(X[:, 1:2], requires_grad=True).float().to(device)\n",
    "        self.dnn.eval()\n",
    "        u, v = self.NN_eval(t, x)\n",
    "        u = u.detach().cpu().numpy()\n",
    "        v = v.detach().cpu().numpy()\n",
    "        return u, v"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_points_for_training(tlo, thi, xlo, xhi, N_IC, N_BC, N_PDE):\n",
    "    # for IC, (t, x) \\in {tlo}x\\Omega\n",
    "    x_IC = np.expand_dims(np.linspace(xlo, xhi, N_IC), axis = 1)\n",
    "    t_IC = tlo * np.ones_like(x_IC)\n",
    "    ptsIC = np.hstack((t_IC, x_IC))\n",
    "    # for BC, (t, x) \\in {xlo}x(0, T]\n",
    "    t_BC = np.linspace(tlo, thi, N_BC + 1)\n",
    "    t_BC = np.expand_dims(t_BC[1:], axis = 1)\n",
    "    x_BC = xlo * np.ones_like(t_BC)\n",
    "    ptsBC = np.hstack((t_BC, x_BC))\n",
    "    # for collocation pts, (t, x) \\in (0, T)x\\Omega\n",
    "    pts_rand = lhs(2, N_PDE)\n",
    "    t_PDE = tlo + (thi - tlo) * pts_rand[:, 0:1]\n",
    "    x_PDE = xlo + (xhi - xlo) * pts_rand[:, 1:2]\n",
    "    ptsPDE = np.hstack((t_PDE, x_PDE))\n",
    "    return ptsIC, ptsBC, ptsPDE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter  1000, Total: 6.2900e-03\n",
      "PDE(u): 2.3026e-04, PDE(v): 6.0597e-03\n",
      "Iter  2000, Total: 5.3654e-03\n",
      "PDE(u): 3.6619e-04, PDE(v): 4.9992e-03\n",
      "Iter  3000, Total: 1.5154e-03\n",
      "PDE(u): 1.9226e-04, PDE(v): 1.3232e-03\n",
      "Iter  4000, Total: 8.3274e-04\n",
      "PDE(u): 1.2023e-04, PDE(v): 7.1251e-04\n",
      "Iter  5000, Total: 6.7452e-04\n",
      "PDE(u): 8.1744e-05, PDE(v): 5.9278e-04\n",
      "Iter  6000, Total: 5.4883e-04\n",
      "PDE(u): 6.0386e-05, PDE(v): 4.8844e-04\n",
      "Iter  7000, Total: 4.7206e-04\n",
      "PDE(u): 6.0852e-05, PDE(v): 4.1121e-04\n",
      "Iter  8000, Total: 4.1820e-04\n",
      "PDE(u): 6.0707e-05, PDE(v): 3.5750e-04\n",
      "Iter  9000, Total: 3.6063e-04\n",
      "PDE(u): 4.7963e-05, PDE(v): 3.1267e-04\n",
      "Iter 10000, Total: 3.1055e-04\n",
      "PDE(u): 3.7818e-05, PDE(v): 2.7273e-04\n",
      "Iter 11000, Total: 2.7058e-04\n",
      "PDE(u): 3.3229e-05, PDE(v): 2.3735e-04\n",
      "Iter 12000, Total: 1.1099e-03\n",
      "PDE(u): 8.9481e-04, PDE(v): 2.1507e-04\n",
      "Iter 13000, Total: 2.0234e-04\n",
      "PDE(u): 2.2674e-05, PDE(v): 1.7966e-04\n",
      "Iter 14000, Total: 2.8272e-04\n",
      "PDE(u): 1.2428e-04, PDE(v): 1.5843e-04\n",
      "Iter 15000, Total: 1.5687e-04\n",
      "PDE(u): 1.7541e-05, PDE(v): 1.3933e-04\n",
      "Iter 16000, Total: 1.3618e-03\n",
      "PDE(u): 1.1002e-03, PDE(v): 2.6156e-04\n",
      "Iter 17000, Total: 1.3174e-04\n",
      "PDE(u): 1.6268e-05, PDE(v): 1.1547e-04\n",
      "Iter 18000, Total: 1.2089e-04\n",
      "PDE(u): 1.4115e-05, PDE(v): 1.0677e-04\n",
      "Iter 19000, Total: 1.1736e-04\n",
      "PDE(u): 1.7126e-05, PDE(v): 1.0024e-04\n",
      "Iter 20000, Total: 1.6710e-04\n",
      "PDE(u): 7.2700e-05, PDE(v): 9.4395e-05\n",
      "Iter 21000, Total: 1.0112e-04\n",
      "PDE(u): 1.2873e-05, PDE(v): 8.8248e-05\n",
      "Iter 22000, Total: 9.4568e-05\n",
      "PDE(u): 1.1532e-05, PDE(v): 8.3036e-05\n",
      "Iter 23000, Total: 3.4270e-04\n",
      "PDE(u): 2.4591e-04, PDE(v): 9.6786e-05\n",
      "Iter 24000, Total: 1.3666e-04\n",
      "PDE(u): 4.4730e-05, PDE(v): 9.1928e-05\n",
      "Iter 25000, Total: 2.2537e-04\n",
      "PDE(u): 1.5313e-04, PDE(v): 7.2235e-05\n",
      "Iter 26000, Total: 7.0358e-04\n",
      "PDE(u): 5.5182e-04, PDE(v): 1.5176e-04\n",
      "Iter  1000, Total: 3.0063e-05\n",
      "PDE(u): 7.7562e-06, PDE(v): 2.2307e-05\n",
      "Iter  2000, Total: 1.4874e-05\n",
      "PDE(u): 4.1969e-06, PDE(v): 1.0677e-05\n",
      "Iter  3000, Total: 7.2256e-06\n",
      "PDE(u): 2.6739e-06, PDE(v): 4.5517e-06\n",
      "Iter  4000, Total: 3.9511e-06\n",
      "PDE(u): 1.5456e-06, PDE(v): 2.4055e-06\n",
      "Iter  5000, Total: 2.4077e-06\n",
      "PDE(u): 1.0100e-06, PDE(v): 1.3977e-06\n"
     ]
    }
   ],
   "source": [
    "# network widths; PhysicsInformedNN overrides the input width layers[0] with 2*m + 2\n",
    "layers = [2, 32, 32, 32, 32, 32, 32, 32, 2]\n",
    "m = 16  # number of Fourier modes in the periodic input layer\n",
    "ptsIC = np.load('ptsIC.npy')\n",
    "ptsBC = np.load('ptsBC.npy')\n",
    "# collocation set: boundary points, then initial points, then the points in random_data.npy\n",
    "ptsPDE = np.vstack((ptsBC, ptsIC, np.load('random_data.npy')))\n",
    "\n",
    "model = PhysicsInformedNN(period, m, ptsPDE, layers, epsilon1, epsilon2)\n",
    "model.train(26000)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# apply PINN to the same grid as the quadrature solution for comparison\n",
    "t = np.linspace(tlo, thi, 101)\n",
    "x = np.linspace(xlo, xhi, 201)\n",
    "T, X = np.meshgrid(t, x)\n",
    "pts_flat = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))\n",
    "u_pred, v_pred = model.predict(pts_flat)\n",
    "\n",
    "# pts_flat lists the grid nodes in T.flatten() order, so a reshape puts the\n",
    "# predictions back on the (len(x), len(t)) meshgrid. Interpolating with\n",
    "# griddata onto those same nodes only added a Delaunay triangulation of all\n",
    "# points and cubic-interpolation round-off.\n",
    "u_pred = u_pred.reshape(T.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = scipy.io.loadmat('Data/CH.mat')\n",
    "# reference grids as column vectors, and the real part of 'Exact' transposed\n",
    "t, x2 = (data[key].flatten()[:, None] for key in ('t', 'x'))\n",
    "u_sol = np.real(data['Exact']).T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Error Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def relative_error_l2(pred,exact):\n",
    "    error_l2 = np.sqrt(np.sum(np.power(pred - exact,2)))\n",
    "    relative = error_l2/np.sqrt(np.sum(np.power(exact,2)))\n",
    "    return relative\n",
    "def relative_error_l1(pred,exact):\n",
    "    error_l1 = np.sum(np.abs(pred-exact))\n",
    "    relative = error_l1/np.sum(np.abs(exact))\n",
    "    return relative\n",
    "def relative_error_linf(pred,exact):\n",
    "    error_linf = np.max(np.abs(pred-exact))\n",
    "    relative = error_linf/np.max(np.abs(exact))\n",
    "    return relative"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "l2: 0.003865331184273986\n",
      "l1: 0.0019053532703783888\n",
      "linf: 0.02733061835169584\n"
     ]
    }
   ],
   "source": [
    "# u_pred is laid out on the meshgrid(t, x) grid; transpose it to compare with u_sol\n",
    "u_cmp = u_pred.T\n",
    "# guard against silent NumPy broadcasting between mismatched grids\n",
    "assert u_cmp.shape == u_sol.shape, f'prediction grid {u_cmp.shape} != reference grid {u_sol.shape}'\n",
    "print(f'l2: {relative_error_l2(u_cmp,u_sol)}')\n",
    "print(f'l1: {relative_error_l1(u_cmp,u_sol)}')\n",
    "print(f'linf: {relative_error_linf(u_cmp,u_sol)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
