{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PINN Solution of the Cahn-Hilliard PDE\n",
    "\n",
    "This PyTorch notebook demonstrates how physics-informed neural networks (PINNs) can solve the well-known Cahn-Hilliard PDE with a periodic boundary condition:\n",
    "\\begin{aligned}\n",
    "  &u_t = \\epsilon_1(-u_{xx}  - \\epsilon_2u_{xxxx} + (u^3)_{xx}), \\quad (t, x) \\in [0, T]\\times[-L, L]\\\\\n",
    "  &u(0, x) = u_0(x), \\quad \\forall x \\in [-L, L] \\\\\n",
    "  &u(t, -L) = u(t, L), \\quad \\forall t \\in [0, T]\n",
    "\\end{aligned}\n",
    "where $\\epsilon_1, \\epsilon_2 > 0$ are given constants and $[-L, L]$ covers one full spatial period, i.e. the spatial period is $2L$.\n",
    "\n",
    "Because computing the fourth derivative $u_{xxxx}$ by repeated back-propagation is slow, we introduce the auxiliary variable $v = -(u - u^3) - \\epsilon_2u_{xx}$ and use\n",
    "$$\n",
    "(u^3)_{xx} = (3u^2u_x)_x = 6uu_x^2 + 3u^2u_{xx},\n",
    "$$\n",
    "then\n",
    "$$\n",
    "u_t = \\epsilon_1(-u_{xx}  - \\epsilon_2u_{xxxx} + (u^3)_{xx}) = \\epsilon_1(-u - \\epsilon_2u_{xx} + u^3)_{xx}\n",
    "$$\n",
    "becomes\n",
    "\\begin{aligned}\n",
    "u_t &= \\epsilon_1v_{xx}, \\\\\n",
    "v &= -(u - u^3) - \\epsilon_2u_{xx}\n",
    "\\end{aligned}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Libraries and Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import chain\n",
    "from collections import OrderedDict\n",
    "import time\n",
    "import numpy as np\n",
    "import scipy as sp\n",
    "import scipy.io\n",
    "from scipy.interpolate import griddata\n",
    "from pyDOE import lhs\n",
    "import torch\n",
    "import torch.optim\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import matplotlib.gridspec as gridspec\n",
    "# Seed every RNG the notebook relies on: numpy, and torch, which initialises\n",
    "# the network weights that are not loaded from disk. Without the torch seed,\n",
    "# training runs are not reproducible.\n",
    "np.random.seed(1234)\n",
    "torch.manual_seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working on mps\n"
     ]
    }
   ],
   "source": [
    "# Prefer Apple's Metal backend (MPS), then CUDA, and fall back to the CPU.\n",
    "device = torch.device(\n",
    "    'mps' if torch.backends.mps.is_available()\n",
    "    else 'cuda' if torch.cuda.is_available()\n",
    "    else 'cpu'\n",
    ")\n",
    "print(f\"Working on {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PDE coefficients\n",
    "epsilon1 = 1e-2\n",
    "epsilon2 = 1e-4\n",
    "# spatial domain [-L, L], one full period of length 2L\n",
    "L = 1.0\n",
    "xlo = -L\n",
    "xhi = L\n",
    "period = xhi - xlo\n",
    "# time window [tlo, thi]\n",
    "tlo = 0.0\n",
    "thi = 1.0\n",
    "pi_ten = torch.tensor(np.pi).float().to(device)\n",
    "# initial condition u(0, x) = -cos(2*pi*x) (NumPy and torch versions)\n",
    "u0 = lambda x: -np.cos(2.0 * np.pi * x)\n",
    "u0_ten = lambda x: -torch.cos(2.0 * pi_ten * x)\n",
    "# initial value of the auxiliary variable, from the same relation the residual\n",
    "# Eq2 enforces: v = -(u - u^3) - epsilon2 * u_xx, with u0_xx = -4*pi^2 * u0\n",
    "v0 = lambda x: u0(x)**3 - (1.0 - 4.0 * np.pi**2 * epsilon2) * u0(x)\n",
    "v0_ten = lambda x: u0_ten(x)**3 - (1.0 - 4.0 * pi_ten**2 * epsilon2) * u0_ten(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Physics-informed Neural Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the deep neural network\n",
    "class DNN(torch.nn.Module):\n",
    "    \"\"\"Fully connected network: (Linear -> Tanh) * (len(layers) - 2) -> Linear.\n",
    "\n",
    "    Args:\n",
    "        layers: widths of all layers, input width first and output width last.\n",
    "        init_weight_path: file holding the initial weight of the first linear\n",
    "            layer (default 'initial_weight.pt'). Pass None to keep PyTorch's\n",
    "            default initialisation.\n",
    "    \"\"\"\n",
    "    def __init__(self, layers, init_weight_path='initial_weight.pt'):\n",
    "        super().__init__()\n",
    "        # number of linear layers\n",
    "        self.depth = len(layers) - 1\n",
    "        self.activation = torch.nn.Tanh\n",
    "        layer_list = []\n",
    "        # hidden layers: Linear followed by the activation\n",
    "        for i in range(self.depth - 1):\n",
    "            layer_list.append(('layer_%d' % i, torch.nn.Linear(layers[i], layers[i + 1])))\n",
    "            layer_list.append(('activation_%d' % i, self.activation()))\n",
    "        # output layer: Linear only, no activation\n",
    "        layer_list.append(('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1])))\n",
    "        self.layers = torch.nn.Sequential(OrderedDict(layer_list))\n",
    "        if init_weight_path is not None:\n",
    "            self._load_first_layer_weight(init_weight_path)\n",
    "\n",
    "    def _load_first_layer_weight(self, path):\n",
    "        \"\"\"Copy the tensor stored at `path` into the first linear layer's weight.\"\"\"\n",
    "        # torch.load unpickles the file -- only load trusted checkpoints.\n",
    "        # map_location='cpu' lets a checkpoint saved on a GPU load on any machine.\n",
    "        weight = torch.load(path, map_location='cpu')\n",
    "        target = self.layers[0].weight\n",
    "        if tuple(weight.shape) != tuple(target.shape):\n",
    "            raise ValueError(\n",
    "                'initial weight in %s has shape %s, expected %s'\n",
    "                % (path, tuple(weight.shape), tuple(target.shape))\n",
    "            )\n",
    "        # copy into the registered Parameter instead of replacing it, so a plain\n",
    "        # tensor checkpoint also works and dtype/device/requires_grad are kept\n",
    "        with torch.no_grad():\n",
    "            target.copy_(weight)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the network to x; the last dimension of x must equal layers[0].\"\"\"\n",
    "        return self.layers(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PhysicsInformedNN():\n",
    "    \"\"\"PINN with trainable per-point loss weights for the split Cahn-Hilliard system\n",
    "\n",
    "        u_t = epsilon1 * v_xx,    v = -(u - u^3) - epsilon2 * u_xx\n",
    "\n",
    "    with periodic boundary conditions in x.\n",
    "\n",
    "    The initial condition is imposed exactly through u = u0(x) * exp(-t) + t * u_hat\n",
    "    (see NN_eval). Every boundary and collocation point carries a trainable loss\n",
    "    weight; during the Adam phase the network minimises the weighted loss while\n",
    "    the weights maximise it (gradient ascent).\n",
    "\n",
    "    Args:\n",
    "        period: spatial period; boundary points x are matched with x + period.\n",
    "        epsilon1, epsilon2: PDE coefficients.\n",
    "        X_BC: 2-D array of boundary points, column 0 = t, column 1 = x.\n",
    "        X_PDE: 2-D array of collocation points, column 0 = t, column 1 = x.\n",
    "        layers: layer widths for DNN; outputs 0 and 1 are used as (u_hat, v).\n",
    "    \"\"\"\n",
    "    def __init__(self, period, epsilon1, epsilon2, X_BC, X_PDE, layers):\n",
    "        # boundary points (only the network output is compared here)\n",
    "        self.t_BC = self._to_tensor(X_BC[:, 0:1])\n",
    "        self.x_BC = self._to_tensor(X_BC[:, 1:2])\n",
    "        N_BC = X_BC.shape[0]\n",
    "        # per-point loss weights, created directly on the device so that they are\n",
    "        # genuine leaf Parameters (Parameter(...).to(device) returns a non-leaf copy)\n",
    "        self.LW_uBC = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(N_BC, 1, device=device))])\n",
    "        self.LW_vBC = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(N_BC, 1, device=device))])\n",
    "\n",
    "        self.period = torch.tensor(period, dtype=torch.float32, device=device)\n",
    "        # PDE points: derivatives are taken w.r.t. t and x, so they require grad\n",
    "        self.t_PDE = self._to_tensor(X_PDE[:, 0:1], requires_grad=True)\n",
    "        self.x_PDE = self._to_tensor(X_PDE[:, 1:2], requires_grad=True)\n",
    "        N_PDE = X_PDE.shape[0]\n",
    "        self.LW_uPDE = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(N_PDE, 1, device=device))])\n",
    "        self.LW_vPDE = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(N_PDE, 1, device=device))])\n",
    "\n",
    "        # equation related parameters\n",
    "        self.epsilon1 = torch.tensor(epsilon1, dtype=torch.float32, device=device)\n",
    "        self.epsilon2 = torch.tensor(epsilon2, dtype=torch.float32, device=device)\n",
    "\n",
    "        # layers to build the neural net, and the network itself\n",
    "        self.layers = layers\n",
    "        self.dnn = DNN(layers).to(device)\n",
    "\n",
    "        # Adam minimises the loss w.r.t. the network weights ...\n",
    "        self.optimizer_Adam = torch.optim.Adam(self.dnn.parameters(), lr=1e-3)\n",
    "        # ... while maximize=True makes these Adam instances take gradient-ascent\n",
    "        # steps on the loss weights (same as negating their gradients first)\n",
    "        self.optimizer_LW_uBC = torch.optim.Adam(self.LW_uBC.parameters(), lr=5e-3, maximize=True)\n",
    "        self.optimizer_LW_uPDE = torch.optim.Adam(self.LW_uPDE.parameters(), lr=5e-3, maximize=True)\n",
    "        self.optimizer_LW_vBC = torch.optim.Adam(self.LW_vBC.parameters(), lr=5e-3, maximize=True)\n",
    "        self.optimizer_LW_vPDE = torch.optim.Adam(self.LW_vPDE.parameters(), lr=5e-3, maximize=True)\n",
    "\n",
    "        # second-order refinement of the network weights only (loss weights frozen)\n",
    "        # NOTE(review): max_eval=5000 < max_iter=10000, so L-BFGS stops after at\n",
    "        # most 5000 loss evaluations -- confirm this cap is intended\n",
    "        self.optimizer_LBFGS = torch.optim.LBFGS(\n",
    "            self.dnn.parameters(),\n",
    "            lr=1.0,\n",
    "            max_iter=10000,\n",
    "            max_eval=5000,\n",
    "            history_size=50,\n",
    "            tolerance_grad=1e-7,\n",
    "            tolerance_change=1.0 * np.finfo(float).eps,\n",
    "            line_search_fn='strong_wolfe'\n",
    "        )\n",
    "        # Adam learning-rate decay, stepped once every 1000 epochs in train()\n",
    "        self.scheduler = lr_scheduler.ExponentialLR(self.optimizer_Adam, gamma=0.99)\n",
    "        # counts L-BFGS closure calls (used for logging in loss_func)\n",
    "        self.iter = 0\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_tensor(array, requires_grad=False):\n",
    "        \"\"\"Convert an array to a float32 leaf tensor on `device`.\"\"\"\n",
    "        return torch.tensor(array, dtype=torch.float32, device=device, requires_grad=requires_grad)\n",
    "\n",
    "    # evaluate the neural network and apply the initial-condition transform\n",
    "    def NN_eval(self, t, x):\n",
    "        \"\"\"Return (u, v) at the points (t, x), each as a column tensor.\n",
    "\n",
    "        u = u0(x) * exp(-t) + t * u_hat satisfies u(0, x) = u0(x) exactly;\n",
    "        v is the second network output, used as is.\n",
    "        \"\"\"\n",
    "        NN = self.dnn(torch.cat([t, x], dim=1))\n",
    "        u_hat = NN[:, 0:1]\n",
    "        v = NN[:, 1:2]\n",
    "        u = u0_ten(x) * torch.exp(-t) + t * u_hat\n",
    "        return u, v\n",
    "\n",
    "    # compute the PDE residuals\n",
    "    def pde_eval(self, t, x):\n",
    "        \"\"\"Residuals (Eq1, Eq2) of u_t = epsilon1 * v_xx and v = -(u - u^3) - epsilon2 * u_xx.\"\"\"\n",
    "        u, v = self.NN_eval(t, x)\n",
    "\n",
    "        def deriv(out, inp):\n",
    "            # d(out)/d(inp), keeping the graph so higher derivatives and backward work\n",
    "            return torch.autograd.grad(out, inp, grad_outputs=torch.ones_like(out), create_graph=True)[0]\n",
    "\n",
    "        u_t = deriv(u, t)\n",
    "        u_x = deriv(u, x)\n",
    "        u_xx = deriv(u_x, x)\n",
    "        v_x = deriv(v, x)\n",
    "        v_xx = deriv(v_x, x)\n",
    "        Eq1 = u_t - self.epsilon1 * v_xx\n",
    "        Eq2 = v + (u + self.epsilon2 * u_xx - torch.pow(u, 3.0))\n",
    "        return Eq1, Eq2\n",
    "\n",
    "    def _losses(self):\n",
    "        \"\"\"Return (loss_BC, loss_PDE): the weighted periodic-BC and PDE losses.\"\"\"\n",
    "        # periodic BC: u and v must agree at x and x + period\n",
    "        uBC_left, vBC_left = self.NN_eval(self.t_BC, self.x_BC)\n",
    "        uBC_right, vBC_right = self.NN_eval(self.t_BC, self.x_BC + self.period)\n",
    "        loss_BC = (torch.mean((self.LW_uBC[0] * (uBC_left - uBC_right)) ** 2.0)\n",
    "                   + torch.mean((self.LW_vBC[0] * (vBC_left - vBC_right)) ** 2.0))\n",
    "        pde1_pred, pde2_pred = self.pde_eval(self.t_PDE, self.x_PDE)\n",
    "        loss_PDE = (torch.mean((self.LW_uPDE[0] * pde1_pred) ** 2)\n",
    "                    + torch.mean((self.LW_vPDE[0] * pde2_pred) ** 2))\n",
    "        return loss_BC, loss_PDE\n",
    "\n",
    "    # compute the total loss for the second-order optimizer\n",
    "    def loss_func(self):\n",
    "        \"\"\"L-BFGS closure: recompute the total loss, back-propagate it and return it.\n",
    "\n",
    "        Prints progress every 1000 calls; reads self.start_time, which train() sets.\n",
    "        \"\"\"\n",
    "        self.optimizer_LBFGS.zero_grad()\n",
    "        loss_BC, loss_PDE = self._losses()\n",
    "        loss = loss_BC + loss_PDE\n",
    "        loss.backward()\n",
    "        self.iter += 1\n",
    "        if self.iter % 1000 == 0:\n",
    "            end_time = time.time()\n",
    "            print('Iter %5d, Total: %10.4e, Time: %.2f secs' % (self.iter, loss.item(), end_time - self.start_time))\n",
    "            print('PBC: %10.4e, PDE: %10.4e' % (loss_BC.item(), loss_PDE.item()))\n",
    "            self.start_time = end_time\n",
    "        return loss\n",
    "\n",
    "    def train(self, nIter):\n",
    "        \"\"\"Run nIter Adam epochs (descent on the network, ascent on the loss weights), then L-BFGS.\n",
    "\n",
    "        Every 1000 epochs the losses and loss-weight ranges are printed and the\n",
    "        Adam learning rate is decayed by the scheduler.\n",
    "        \"\"\"\n",
    "        lw_optimizers = (self.optimizer_LW_uBC, self.optimizer_LW_uPDE,\n",
    "                         self.optimizer_LW_vBC, self.optimizer_LW_vPDE)\n",
    "        start_time = time.time()\n",
    "        self.dnn.train()\n",
    "        print('Starting with Adam')\n",
    "        for epoch in range(nIter):\n",
    "            loss_BC, loss_PDE = self._losses()\n",
    "            loss = loss_BC + loss_PDE\n",
    "            self.optimizer_Adam.zero_grad()\n",
    "            for opt in lw_optimizers:\n",
    "                opt.zero_grad()\n",
    "            loss.backward()\n",
    "            # descent on the network weights, ascent on the loss weights (maximize=True)\n",
    "            self.optimizer_Adam.step()\n",
    "            for opt in lw_optimizers:\n",
    "                opt.step()\n",
    "            if (epoch + 1) % 1000 == 0:\n",
    "                end_time = time.time()\n",
    "                print('Iter %5d, Total: %10.4e, Time: %.2f secs' % (epoch + 1, loss.item(), end_time - start_time))\n",
    "                print('PBC: %10.4e, PDE: %10.4e' % (loss_BC.item(), loss_PDE.item()))\n",
    "                print('For uBC,  min uLW: %10.4e, max uLW: %10.4e' % (self.LW_uBC[0].min().item(), self.LW_uBC[0].max().item()))\n",
    "                print('For uPDE, min uLW: %10.4e, max uLW: %10.4e' % (self.LW_uPDE[0].min().item(), self.LW_uPDE[0].max().item()))\n",
    "                print('For vBC,  min vLW: %10.4e, max vLW: %10.4e' % (self.LW_vBC[0].min().item(), self.LW_vBC[0].max().item()))\n",
    "                print('For vPDE, min vLW: %10.4e, max vLW: %10.4e' % (self.LW_vPDE[0].min().item(), self.LW_vPDE[0].max().item()))\n",
    "                start_time = end_time\n",
    "                # decay the Adam learning rate\n",
    "                self.scheduler.step()\n",
    "\n",
    "        print('Starting with L-BFGS')\n",
    "        self.start_time = time.time()\n",
    "        self.optimizer_LBFGS.step(self.loss_func)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Evaluate (u, v) at the rows (t, x) of X; returns two NumPy column arrays.\"\"\"\n",
    "        t = self._to_tensor(X[:, 0:1])\n",
    "        x = self._to_tensor(X[:, 1:2])\n",
    "        self.dnn.eval()\n",
    "        # NN_eval takes no derivatives, so no autograd graph is needed\n",
    "        with torch.no_grad():\n",
    "            u, v = self.NN_eval(t, x)\n",
    "        return u.cpu().numpy(), v.cpu().numpy()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting with Adam\n",
      "Iter  1000, Total: 6.1610e-01, Time: 164.39 secs\n",
      "PBC: 1.9136e-05, PDE: 6.1608e-01\n",
      "For uBC,  min uLW: 1.1407e+00, max uLW: 2.3475e+00\n",
      "For uPDE, min uLW: 1.4318e+00, max uLW: 6.6447e+00\n",
      "For vBC,  min vLW: 1.5675e+00, max vLW: 2.1865e+00\n",
      "For vPDE, min vLW: 1.6240e+00, max vLW: 9.3066e+00\n",
      "Iter  2000, Total: 4.3404e-01, Time: 154.32 secs\n",
      "PBC: 5.5828e-05, PDE: 4.3399e-01\n",
      "For uBC,  min uLW: 1.1848e+00, max uLW: 3.0377e+00\n",
      "For uPDE, min uLW: 1.6744e+00, max uLW: 1.2202e+01\n",
      "For vBC,  min vLW: 2.0069e+00, max vLW: 3.6445e+00\n",
      "For vPDE, min vLW: 1.8661e+00, max vLW: 1.8458e+01\n",
      "Iter  3000, Total: 2.4083e-01, Time: 150.36 secs\n",
      "PBC: 7.3262e-05, PDE: 2.4076e-01\n",
      "For uBC,  min uLW: 1.3055e+00, max uLW: 8.6316e+00\n",
      "For uPDE, min uLW: 1.9053e+00, max uLW: 1.6120e+01\n",
      "For vBC,  min vLW: 2.1764e+00, max vLW: 5.4744e+00\n",
      "For vPDE, min vLW: 1.8797e+00, max vLW: 2.3927e+01\n",
      "Iter  4000, Total: 3.0907e-01, Time: 155.62 secs\n",
      "PBC: 1.8520e-04, PDE: 3.0888e-01\n",
      "For uBC,  min uLW: 1.4130e+00, max uLW: 1.3831e+01\n",
      "For uPDE, min uLW: 1.9551e+00, max uLW: 1.8570e+01\n",
      "For vBC,  min vLW: 2.4829e+00, max vLW: 9.0131e+00\n",
      "For vPDE, min vLW: 1.9114e+00, max vLW: 2.9166e+01\n",
      "Iter  5000, Total: 3.3899e-01, Time: 155.27 secs\n",
      "PBC: 9.2976e-05, PDE: 3.3889e-01\n",
      "For uBC,  min uLW: 1.4734e+00, max uLW: 1.7627e+01\n",
      "For uPDE, min uLW: 1.9786e+00, max uLW: 2.4959e+01\n",
      "For vBC,  min vLW: 2.7510e+00, max vLW: 1.5037e+01\n",
      "For vPDE, min vLW: 2.0133e+00, max vLW: 3.4821e+01\n",
      "Iter  6000, Total: 4.8314e-01, Time: 158.16 secs\n",
      "PBC: 1.5313e-02, PDE: 4.6783e-01\n",
      "For uBC,  min uLW: 1.5040e+00, max uLW: 2.0247e+01\n",
      "For uPDE, min uLW: 2.0814e+00, max uLW: 2.9938e+01\n",
      "For vBC,  min vLW: 4.5279e+00, max vLW: 1.9683e+01\n",
      "For vPDE, min vLW: 2.0494e+00, max vLW: 3.9754e+01\n",
      "Iter  7000, Total: 2.2798e-01, Time: 156.87 secs\n",
      "PBC: 1.1732e-02, PDE: 2.1625e-01\n",
      "For uBC,  min uLW: 1.5222e+00, max uLW: 2.2325e+01\n",
      "For uPDE, min uLW: 2.2005e+00, max uLW: 3.4310e+01\n",
      "For vBC,  min vLW: 8.7535e+00, max vLW: 2.3417e+01\n",
      "For vPDE, min vLW: 2.2529e+00, max vLW: 4.4987e+01\n",
      "Iter  8000, Total: 2.5230e-01, Time: 158.29 secs\n",
      "PBC: 4.3529e-04, PDE: 2.5186e-01\n",
      "For uBC,  min uLW: 1.5357e+00, max uLW: 2.4694e+01\n",
      "For uPDE, min uLW: 2.3355e+00, max uLW: 3.8234e+01\n",
      "For vBC,  min vLW: 1.3346e+01, max vLW: 2.7178e+01\n",
      "For vPDE, min vLW: 2.4017e+00, max vLW: 5.0044e+01\n",
      "Iter  9000, Total: 4.2347e-01, Time: 155.26 secs\n",
      "PBC: 1.2790e-02, PDE: 4.1068e-01\n",
      "For uBC,  min uLW: 1.5502e+00, max uLW: 2.8051e+01\n",
      "For uPDE, min uLW: 2.3747e+00, max uLW: 4.2376e+01\n",
      "For vBC,  min vLW: 1.7334e+01, max vLW: 3.0822e+01\n",
      "For vPDE, min vLW: 2.6765e+00, max vLW: 5.5173e+01\n",
      "Iter 10000, Total: 4.1503e-01, Time: 157.92 secs\n",
      "PBC: 4.5046e-04, PDE: 4.1458e-01\n",
      "For uBC,  min uLW: 1.5653e+00, max uLW: 3.1879e+01\n",
      "For uPDE, min uLW: 2.6408e+00, max uLW: 4.7088e+01\n",
      "For vBC,  min vLW: 2.0910e+01, max vLW: 3.4382e+01\n",
      "For vPDE, min vLW: 3.0151e+00, max vLW: 6.0731e+01\n",
      "Iter 11000, Total: 3.8862e-01, Time: 155.66 secs\n",
      "PBC: 7.3420e-03, PDE: 3.8128e-01\n",
      "For uBC,  min uLW: 1.5793e+00, max uLW: 3.5543e+01\n",
      "For uPDE, min uLW: 3.0082e+00, max uLW: 5.2040e+01\n",
      "For vBC,  min vLW: 2.4130e+01, max vLW: 3.7797e+01\n",
      "For vPDE, min vLW: 3.5615e+00, max vLW: 6.6144e+01\n",
      "Iter 12000, Total: 6.2530e-01, Time: 157.95 secs\n",
      "PBC: 2.0005e-03, PDE: 6.2330e-01\n",
      "For uBC,  min uLW: 1.5921e+00, max uLW: 3.9203e+01\n",
      "For uPDE, min uLW: 3.1761e+00, max uLW: 5.6458e+01\n",
      "For vBC,  min vLW: 2.7333e+01, max vLW: 4.1325e+01\n",
      "For vPDE, min vLW: 4.8014e+00, max vLW: 7.1472e+01\n",
      "Iter 13000, Total: 4.9336e-01, Time: 157.16 secs\n",
      "PBC: 1.1019e-03, PDE: 4.9226e-01\n",
      "For uBC,  min uLW: 1.6035e+00, max uLW: 4.2592e+01\n",
      "For uPDE, min uLW: 3.2985e+00, max uLW: 5.9731e+01\n",
      "For vBC,  min vLW: 3.0461e+01, max vLW: 4.4736e+01\n",
      "For vPDE, min vLW: 5.4812e+00, max vLW: 7.6741e+01\n",
      "Iter 14000, Total: 1.0401e+00, Time: 157.38 secs\n",
      "PBC: 5.0161e-02, PDE: 9.8995e-01\n",
      "For uBC,  min uLW: 1.6133e+00, max uLW: 4.5723e+01\n",
      "For uPDE, min uLW: 3.4819e+00, max uLW: 6.2375e+01\n",
      "For vBC,  min vLW: 3.3528e+01, max vLW: 4.7998e+01\n",
      "For vPDE, min vLW: 7.0253e+00, max vLW: 8.1914e+01\n",
      "Iter 15000, Total: 5.5228e-01, Time: 156.64 secs\n",
      "PBC: 9.9396e-04, PDE: 5.5128e-01\n",
      "For uBC,  min uLW: 1.6213e+00, max uLW: 4.8531e+01\n",
      "For uPDE, min uLW: 3.9973e+00, max uLW: 6.6276e+01\n",
      "For vBC,  min vLW: 3.6407e+01, max vLW: 5.0982e+01\n",
      "For vPDE, min vLW: 8.7293e+00, max vLW: 8.7155e+01\n",
      "Iter 16000, Total: 7.7740e-01, Time: 155.72 secs\n",
      "PBC: 5.4261e-03, PDE: 7.7197e-01\n",
      "For uBC,  min uLW: 1.6275e+00, max uLW: 5.1289e+01\n",
      "For uPDE, min uLW: 4.4984e+00, max uLW: 7.0955e+01\n",
      "For vBC,  min vLW: 3.9332e+01, max vLW: 5.3997e+01\n",
      "For vPDE, min vLW: 1.0036e+01, max vLW: 9.2340e+01\n",
      "Iter 17000, Total: 7.7281e-01, Time: 155.62 secs\n",
      "PBC: 3.7185e-03, PDE: 7.6909e-01\n",
      "For uBC,  min uLW: 1.6322e+00, max uLW: 5.3960e+01\n",
      "For uPDE, min uLW: 4.9415e+00, max uLW: 7.5585e+01\n",
      "For vBC,  min vLW: 4.2237e+01, max vLW: 5.6939e+01\n",
      "For vPDE, min vLW: 1.1324e+01, max vLW: 9.7473e+01\n",
      "Iter 18000, Total: 7.4460e-01, Time: 156.02 secs\n",
      "PBC: 3.1370e-02, PDE: 7.1323e-01\n",
      "For uBC,  min uLW: 1.6356e+00, max uLW: 5.6576e+01\n",
      "For uPDE, min uLW: 5.4343e+00, max uLW: 8.0159e+01\n",
      "For vBC,  min vLW: 4.5120e+01, max vLW: 5.9820e+01\n",
      "For vPDE, min vLW: 1.2147e+01, max vLW: 1.0256e+02\n",
      "Iter 19000, Total: 1.7852e+00, Time: 152.48 secs\n",
      "PBC: 7.0662e-02, PDE: 1.7145e+00\n",
      "For uBC,  min uLW: 1.6382e+00, max uLW: 5.9196e+01\n",
      "For uPDE, min uLW: 6.2195e+00, max uLW: 8.4742e+01\n",
      "For vBC,  min vLW: 4.8050e+01, max vLW: 6.2699e+01\n",
      "For vPDE, min vLW: 1.3500e+01, max vLW: 1.0761e+02\n",
      "Iter 20000, Total: 8.7579e-01, Time: 154.78 secs\n",
      "PBC: 5.8770e-03, PDE: 8.6992e-01\n",
      "For uBC,  min uLW: 1.6402e+00, max uLW: 6.1742e+01\n",
      "For uPDE, min uLW: 7.8222e+00, max uLW: 8.9299e+01\n",
      "For vBC,  min vLW: 5.0945e+01, max vLW: 6.5468e+01\n",
      "For vPDE, min vLW: 1.5701e+01, max vLW: 1.1264e+02\n",
      "Iter 21000, Total: 1.2766e+00, Time: 156.63 secs\n",
      "PBC: 9.9553e-02, PDE: 1.1771e+00\n",
      "For uBC,  min uLW: 1.6418e+00, max uLW: 6.4283e+01\n",
      "For uPDE, min uLW: 1.0489e+01, max uLW: 9.3760e+01\n",
      "For vBC,  min vLW: 5.3813e+01, max vLW: 6.8178e+01\n",
      "For vPDE, min vLW: 1.7740e+01, max vLW: 1.1763e+02\n",
      "Iter 22000, Total: 8.3895e-01, Time: 157.27 secs\n",
      "PBC: 7.9839e-03, PDE: 8.3096e-01\n",
      "For uBC,  min uLW: 1.6433e+00, max uLW: 6.6811e+01\n",
      "For uPDE, min uLW: 1.3406e+01, max uLW: 9.8045e+01\n",
      "For vBC,  min vLW: 5.6628e+01, max vLW: 7.0847e+01\n",
      "For vPDE, min vLW: 1.8570e+01, max vLW: 1.2261e+02\n",
      "Iter 23000, Total: 1.0647e+00, Time: 150.82 secs\n",
      "PBC: 6.0695e-03, PDE: 1.0587e+00\n",
      "For uBC,  min uLW: 1.6446e+00, max uLW: 6.9379e+01\n",
      "For uPDE, min uLW: 1.5773e+01, max uLW: 1.0215e+02\n",
      "For vBC,  min vLW: 5.9439e+01, max vLW: 7.3537e+01\n",
      "For vPDE, min vLW: 1.8942e+01, max vLW: 1.2755e+02\n",
      "Iter 24000, Total: 1.5770e+00, Time: 157.46 secs\n",
      "PBC: 5.3334e-02, PDE: 1.5236e+00\n",
      "For uBC,  min uLW: 1.6458e+00, max uLW: 7.2092e+01\n",
      "For uPDE, min uLW: 1.8615e+01, max uLW: 1.0600e+02\n",
      "For vBC,  min vLW: 6.2373e+01, max vLW: 7.6438e+01\n",
      "For vPDE, min vLW: 1.9349e+01, max vLW: 1.3246e+02\n",
      "Iter 25000, Total: 1.0856e+00, Time: 154.61 secs\n",
      "PBC: 8.8623e-03, PDE: 1.0767e+00\n",
      "For uBC,  min uLW: 1.6469e+00, max uLW: 7.4549e+01\n",
      "For uPDE, min uLW: 2.1720e+01, max uLW: 1.0971e+02\n",
      "For vBC,  min vLW: 6.4995e+01, max vLW: 7.9010e+01\n",
      "For vPDE, min vLW: 1.9919e+01, max vLW: 1.3734e+02\n",
      "Iter 26000, Total: 1.0472e+00, Time: 159.41 secs\n",
      "PBC: 2.0414e-03, PDE: 1.0452e+00\n",
      "For uBC,  min uLW: 1.6478e+00, max uLW: 7.8191e+01\n",
      "For uPDE, min uLW: 2.4807e+01, max uLW: 1.1345e+02\n",
      "For vBC,  min vLW: 6.7553e+01, max vLW: 8.1553e+01\n",
      "For vPDE, min vLW: 2.0529e+01, max vLW: 1.4219e+02\n",
      "Iter 27000, Total: 9.6633e-01, Time: 156.50 secs\n",
      "PBC: 1.5260e-03, PDE: 9.6480e-01\n",
      "For uBC,  min uLW: 1.6486e+00, max uLW: 8.2209e+01\n",
      "For uPDE, min uLW: 2.7489e+01, max uLW: 1.1859e+02\n",
      "For vBC,  min vLW: 7.0120e+01, max vLW: 8.4174e+01\n",
      "For vPDE, min vLW: 2.1235e+01, max vLW: 1.4700e+02\n",
      "Iter 28000, Total: 1.0869e+00, Time: 159.41 secs\n",
      "PBC: 7.5122e-02, PDE: 1.0118e+00\n",
      "For uBC,  min uLW: 1.6493e+00, max uLW: 8.6331e+01\n",
      "For uPDE, min uLW: 2.9660e+01, max uLW: 1.2363e+02\n",
      "For vBC,  min vLW: 7.2498e+01, max vLW: 8.6599e+01\n",
      "For vPDE, min vLW: 2.1999e+01, max vLW: 1.5181e+02\n",
      "Iter 29000, Total: 1.3517e+00, Time: 158.76 secs\n",
      "PBC: 8.5597e-02, PDE: 1.2661e+00\n",
      "For uBC,  min uLW: 1.6500e+00, max uLW: 9.0541e+01\n",
      "For uPDE, min uLW: 3.1486e+01, max uLW: 1.2856e+02\n",
      "For vBC,  min vLW: 7.4862e+01, max vLW: 8.9052e+01\n",
      "For vPDE, min vLW: 2.2835e+01, max vLW: 1.5661e+02\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 30000, Total: 9.9649e-01, Time: 154.55 secs\n",
      "PBC: 3.4590e-03, PDE: 9.9303e-01\n",
      "For uBC,  min uLW: 1.6506e+00, max uLW: 9.4812e+01\n",
      "For uPDE, min uLW: 3.3390e+01, max uLW: 1.3339e+02\n",
      "For vBC,  min vLW: 7.7353e+01, max vLW: 9.1675e+01\n",
      "For vPDE, min vLW: 2.3770e+01, max vLW: 1.6161e+02\n",
      "Iter 31000, Total: 2.0576e+00, Time: 159.53 secs\n",
      "PBC: 9.7590e-02, PDE: 1.9600e+00\n",
      "For uBC,  min uLW: 1.6512e+00, max uLW: 9.9198e+01\n",
      "For uPDE, min uLW: 3.5369e+01, max uLW: 1.3816e+02\n",
      "For vBC,  min vLW: 7.9867e+01, max vLW: 9.4343e+01\n",
      "For vPDE, min vLW: 2.4764e+01, max vLW: 1.6674e+02\n",
      "Iter 32000, Total: 1.1159e+00, Time: 157.39 secs\n",
      "PBC: 1.7907e-02, PDE: 1.0980e+00\n",
      "For uBC,  min uLW: 1.6517e+00, max uLW: 1.0368e+02\n",
      "For uPDE, min uLW: 3.7165e+01, max uLW: 1.4289e+02\n",
      "For vBC,  min vLW: 8.2393e+01, max vLW: 9.7041e+01\n",
      "For vPDE, min vLW: 2.5837e+01, max vLW: 1.7187e+02\n",
      "Iter 33000, Total: 1.1637e+00, Time: 154.86 secs\n",
      "PBC: 4.1070e-03, PDE: 1.1596e+00\n",
      "For uBC,  min uLW: 1.6523e+00, max uLW: 1.0816e+02\n",
      "For uPDE, min uLW: 3.9156e+01, max uLW: 1.4759e+02\n",
      "For vBC,  min vLW: 8.4836e+01, max vLW: 9.9605e+01\n",
      "For vPDE, min vLW: 2.6945e+01, max vLW: 1.7700e+02\n",
      "Iter 34000, Total: 3.3073e+00, Time: 156.95 secs\n",
      "PBC: 1.3003e-02, PDE: 3.2943e+00\n",
      "For uBC,  min uLW: 1.6528e+00, max uLW: 1.1270e+02\n",
      "For uPDE, min uLW: 4.1204e+01, max uLW: 1.5228e+02\n",
      "For vBC,  min vLW: 8.7249e+01, max vLW: 1.0217e+02\n",
      "For vPDE, min vLW: 2.8108e+01, max vLW: 1.8212e+02\n",
      "Iter 35000, Total: 1.0537e+00, Time: 154.36 secs\n",
      "PBC: 1.6860e-03, PDE: 1.0520e+00\n",
      "For uBC,  min uLW: 1.6532e+00, max uLW: 1.1724e+02\n",
      "For uPDE, min uLW: 4.2989e+01, max uLW: 1.5699e+02\n",
      "For vBC,  min vLW: 8.9760e+01, max vLW: 1.0481e+02\n",
      "For vPDE, min vLW: 2.9260e+01, max vLW: 1.8722e+02\n",
      "Iter 36000, Total: 1.1540e+00, Time: 155.85 secs\n",
      "PBC: 2.2952e-02, PDE: 1.1311e+00\n",
      "For uBC,  min uLW: 1.6536e+00, max uLW: 1.2179e+02\n",
      "For uPDE, min uLW: 4.5060e+01, max uLW: 1.6172e+02\n",
      "For vBC,  min vLW: 9.2289e+01, max vLW: 1.0749e+02\n",
      "For vPDE, min vLW: 3.0435e+01, max vLW: 1.9230e+02\n",
      "Iter 37000, Total: 1.1756e+00, Time: 157.04 secs\n",
      "PBC: 1.0842e-02, PDE: 1.1647e+00\n",
      "For uBC,  min uLW: 1.6540e+00, max uLW: 1.2636e+02\n",
      "For uPDE, min uLW: 4.7398e+01, max uLW: 1.6648e+02\n",
      "For vBC,  min vLW: 9.4802e+01, max vLW: 1.1015e+02\n",
      "For vPDE, min vLW: 3.1587e+01, max vLW: 1.9734e+02\n",
      "Iter 38000, Total: 1.8079e+00, Time: 155.86 secs\n",
      "PBC: 5.2197e-02, PDE: 1.7557e+00\n",
      "For uBC,  min uLW: 1.6543e+00, max uLW: 1.3085e+02\n",
      "For uPDE, min uLW: 4.9906e+01, max uLW: 1.7124e+02\n",
      "For vBC,  min vLW: 9.7484e+01, max vLW: 1.1297e+02\n",
      "For vPDE, min vLW: 3.2720e+01, max vLW: 2.0235e+02\n",
      "Iter 39000, Total: 1.3518e+00, Time: 158.39 secs\n",
      "PBC: 1.4943e-01, PDE: 1.2024e+00\n",
      "For uBC,  min uLW: 1.6546e+00, max uLW: 1.3531e+02\n",
      "For uPDE, min uLW: 5.2838e+01, max uLW: 1.7602e+02\n",
      "For vBC,  min vLW: 1.0023e+02, max vLW: 1.1581e+02\n",
      "For vPDE, min vLW: 3.3801e+01, max vLW: 2.0732e+02\n",
      "Iter 40000, Total: 1.3617e+00, Time: 158.42 secs\n",
      "PBC: 3.1767e-03, PDE: 1.3585e+00\n",
      "For uBC,  min uLW: 1.6548e+00, max uLW: 1.3974e+02\n",
      "For uPDE, min uLW: 5.5888e+01, max uLW: 1.8081e+02\n",
      "For vBC,  min vLW: 1.0287e+02, max vLW: 1.1856e+02\n",
      "For vPDE, min vLW: 3.4845e+01, max vLW: 2.1225e+02\n",
      "Iter 41000, Total: 1.5469e+00, Time: 157.18 secs\n",
      "PBC: 1.1517e-01, PDE: 1.4317e+00\n",
      "For uBC,  min uLW: 1.6550e+00, max uLW: 1.4413e+02\n",
      "For uPDE, min uLW: 5.9095e+01, max uLW: 1.8562e+02\n",
      "For vBC,  min vLW: 1.0550e+02, max vLW: 1.2129e+02\n",
      "For vPDE, min vLW: 3.5812e+01, max vLW: 2.1714e+02\n",
      "Iter 42000, Total: 1.2003e+00, Time: 154.37 secs\n",
      "PBC: 3.6794e-02, PDE: 1.1635e+00\n",
      "For uBC,  min uLW: 1.6552e+00, max uLW: 1.4849e+02\n",
      "For uPDE, min uLW: 6.2061e+01, max uLW: 1.9042e+02\n",
      "For vBC,  min vLW: 1.0817e+02, max vLW: 1.2403e+02\n",
      "For vPDE, min vLW: 3.6726e+01, max vLW: 2.2199e+02\n",
      "Iter 43000, Total: 1.3048e+00, Time: 157.06 secs\n",
      "PBC: 1.7840e-02, PDE: 1.2870e+00\n",
      "For uBC,  min uLW: 1.6553e+00, max uLW: 1.5285e+02\n",
      "For uPDE, min uLW: 6.4933e+01, max uLW: 1.9523e+02\n",
      "For vBC,  min vLW: 1.1046e+02, max vLW: 1.2637e+02\n",
      "For vPDE, min vLW: 3.7595e+01, max vLW: 2.2680e+02\n",
      "Iter 44000, Total: 2.1468e+00, Time: 160.01 secs\n",
      "PBC: 1.6842e-01, PDE: 1.9784e+00\n",
      "For uBC,  min uLW: 1.6554e+00, max uLW: 1.5720e+02\n",
      "For uPDE, min uLW: 6.7522e+01, max uLW: 2.0003e+02\n",
      "For vBC,  min vLW: 1.1274e+02, max vLW: 1.2871e+02\n",
      "For vPDE, min vLW: 3.8378e+01, max vLW: 2.3158e+02\n",
      "Iter 45000, Total: 1.3102e+00, Time: 154.37 secs\n",
      "PBC: 5.2837e-02, PDE: 1.2574e+00\n",
      "For uBC,  min uLW: 1.6556e+00, max uLW: 1.6157e+02\n",
      "For uPDE, min uLW: 7.0471e+01, max uLW: 2.0483e+02\n",
      "For vBC,  min vLW: 1.1520e+02, max vLW: 1.3120e+02\n",
      "For vPDE, min vLW: 3.9118e+01, max vLW: 2.3632e+02\n",
      "Iter 46000, Total: 1.6586e+00, Time: 157.72 secs\n",
      "PBC: 2.4125e-01, PDE: 1.4174e+00\n",
      "For uBC,  min uLW: 1.6557e+00, max uLW: 1.6598e+02\n",
      "For uPDE, min uLW: 7.3291e+01, max uLW: 2.0963e+02\n",
      "For vBC,  min vLW: 1.1759e+02, max vLW: 1.3362e+02\n",
      "For vPDE, min vLW: 3.9864e+01, max vLW: 2.4104e+02\n",
      "Iter 47000, Total: 2.5659e+00, Time: 155.15 secs\n",
      "PBC: 5.8143e-02, PDE: 2.5078e+00\n",
      "For uBC,  min uLW: 1.6558e+00, max uLW: 1.7039e+02\n",
      "For uPDE, min uLW: 7.6178e+01, max uLW: 2.1442e+02\n",
      "For vBC,  min vLW: 1.2005e+02, max vLW: 1.3605e+02\n",
      "For vPDE, min vLW: 4.0569e+01, max vLW: 2.4574e+02\n",
      "Iter 48000, Total: 1.5031e+00, Time: 155.37 secs\n",
      "PBC: 1.6158e-01, PDE: 1.3416e+00\n",
      "For uBC,  min uLW: 1.6559e+00, max uLW: 1.7484e+02\n",
      "For uPDE, min uLW: 7.9176e+01, max uLW: 2.1921e+02\n",
      "For vBC,  min vLW: 1.2261e+02, max vLW: 1.3864e+02\n",
      "For vPDE, min vLW: 4.1298e+01, max vLW: 2.5041e+02\n",
      "Iter 49000, Total: 1.4544e+00, Time: 157.13 secs\n",
      "PBC: 2.1847e-03, PDE: 1.4522e+00\n",
      "For uBC,  min uLW: 1.6559e+00, max uLW: 1.7930e+02\n",
      "For uPDE, min uLW: 8.2217e+01, max uLW: 2.2399e+02\n",
      "For vBC,  min vLW: 1.2551e+02, max vLW: 1.4157e+02\n",
      "For vPDE, min vLW: 4.2044e+01, max vLW: 2.5507e+02\n",
      "Iter 50000, Total: 1.2084e+00, Time: 158.11 secs\n",
      "PBC: 1.9009e-02, PDE: 1.1894e+00\n",
      "For uBC,  min uLW: 1.6559e+00, max uLW: 1.8377e+02\n",
      "For uPDE, min uLW: 8.5127e+01, max uLW: 2.2876e+02\n",
      "For vBC,  min vLW: 1.2821e+02, max vLW: 1.4423e+02\n",
      "For vPDE, min vLW: 4.2869e+01, max vLW: 2.5971e+02\n",
      "Starting with L-BFGS\n",
      "Iter  1000, Total: 5.4497e-01, Time: 281.56 secs\n",
      "PBC: 3.9526e-04, PDE: 5.4458e-01\n",
      "Iter  2000, Total: 3.0925e-01, Time: 288.39 secs\n",
      "PBC: 4.5934e-04, PDE: 3.0879e-01\n",
      "Iter  3000, Total: 1.9407e-01, Time: 283.89 secs\n",
      "PBC: 1.8968e-04, PDE: 1.9388e-01\n",
      "Iter  4000, Total: 1.4487e-01, Time: 277.69 secs\n",
      "PBC: 6.1713e-05, PDE: 1.4481e-01\n",
      "Iter  5000, Total: 1.1766e-01, Time: 248.73 secs\n",
      "PBC: 1.5377e-04, PDE: 1.1751e-01\n"
     ]
    }
   ],
   "source": [
    "layers = [2, 32, 32, 32, 32, 32, 32, 32, 2]\n",
    "N_IC = 128\n",
    "ptsIC=np.load('ptsIC.npy')\n",
    "ptsBC=np.load('ptsBC.npy')\n",
    "ptsPDE=np.load('random_data.npy')\n",
    "uBC = np.load('uBC.npy')\n",
    "vBC = np.load('vBC.npy')\n",
    "x_IC = np.expand_dims(np.linspace(xlo, xhi, N_IC), axis = 1)\n",
    "uIC = u0(x_IC)\n",
    "vIC = v0(x_IC)\n",
    "model = PhysicsInformedNN(period, epsilon1, epsilon2,  ptsBC,  ptsPDE, layers)\n",
    "model.train(50000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# apply PINN to the same grid as the quadrature solution for comparison\n",
    "t = np.linspace(tlo, thi, 101)\n",
    "x = np.linspace(xlo, xhi, 201)\n",
    "T, X = np.meshgrid(t, x)\n",
    "pts_flat = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))\n",
    "u_pred, v_pred = model.predict(pts_flat)\n",
    "            \n",
    "u_pred = griddata(pts_flat, u_pred.flatten(), (T, X), method='cubic')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = scipy.io.loadmat('Data/CH.mat')\n",
    "t = data['t'].flatten()[:,None]\n",
    "x2 = data['x'].flatten()[:,None]\n",
    "u_sol = np.real(data['Exact']).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def relative_error_l2(pred,exact):\n",
    "    error_l2 = np.sqrt(np.sum(np.power(pred - exact,2)))\n",
    "    relative = error_l2/np.sqrt(np.sum(np.power(exact,2)))\n",
    "    return relative\n",
    "def relative_error_l1(pred,exact):\n",
    "    error_l1 = np.sum(np.abs(pred-exact))\n",
    "    relative = error_l1/np.sum(np.abs(exact))\n",
    "    return relative\n",
    "def relative_error_linf(pred,exact):\n",
    "    error_linf = np.max(np.abs(pred-exact))\n",
    "    relative = error_linf/np.max(np.abs(exact))\n",
    "    return relative"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "l2: 0.023662233401978608\n",
      "l1: 0.010855790978752541\n",
      "linf: 0.0888843749630409\n"
     ]
    }
   ],
   "source": [
    "print(f'l2: {relative_error_l2(u_pred.T,u_sol)}')\n",
    "print(f'l1: {relative_error_l1(u_pred.T,u_sol)}')\n",
    "print(f'linf: {relative_error_linf(u_pred.T,u_sol)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
