{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PINN Solution of the Allen-Cahn Equation\n",
    "\n",
    "This PyTorch notebook demonstrates how physics-informed neural networks (PINNs) can solve the well-known Allen-Cahn equation with periodic boundary conditions\n",
    "$$\n",
    "\\begin{aligned}\n",
    "  &u_t = \\gamma_1 u_{xx} - \\gamma_2 u^3+\\gamma_2 u, \\quad (t, x) \\in [0, T]\\times[-L, L]\\\\\n",
    "  &u(0, x) = u_0(x), \\quad \\forall x \\in [-L, L] \\\\\n",
    "  &u(t, -L) = u(t, L), \\quad u_x(t, -L) = u_x(t, L), \\quad \\forall t \\in [0, T]\n",
    "\\end{aligned}\n",
    "$$\n",
    "where $\\gamma_1, \\gamma_2 > 0$ are parameters (`epsilon` and `gamma` in the code) and $[-L, L]$ covers one full period. The training loss below penalizes only the periodicity of $u$; the condition on $u_x$ is not imposed explicitly.\n",
    "\n",
    "Here, we let $u_0(x) = x^2\\cos(\\pi x)$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Libraries and Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.optim\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from itertools import chain\n",
    "from collections import OrderedDict\n",
    "from pyDOE import lhs\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.io\n",
    "from scipy.interpolate import griddata\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import matplotlib.gridspec as gridspec\n",
    "import time\n",
    "# set the random seeds: NumPy for any NumPy sampling, torch for network initialisation\n",
    "np.random.seed(1234)\n",
    "torch.manual_seed(1234)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working on mps\n"
     ]
    }
   ],
   "source": [
    "# prefer Apple-silicon (MPS), then CUDA, and fall back to the CPU\n",
    "device = torch.device(\n",
    "    'mps' if torch.backends.mps.is_available()\n",
    "    else 'cuda' if torch.cuda.is_available()\n",
    "    else 'cpu'\n",
    ")\n",
    "print(f\"Working on {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PDE coefficients: u_t = epsilon * u_xx + gamma * (u - u^3)\n",
    "epsilon = 1e-4\n",
    "gamma = 5\n",
    "# spatial domain [xlo, xhi] = [-L, L] (one period) and time interval [tlo, thi]\n",
    "L = 1.0\n",
    "xlo, xhi = -L, L\n",
    "period = xhi - xlo\n",
    "tlo, thi = 0.0, 1.0\n",
    "# float32 copies of constants on the compute device\n",
    "pi_ten = torch.tensor(np.pi).float().to(device)\n",
    "L_ten = torch.tensor(L).float().to(device)\n",
    "\n",
    "\n",
    "def u0(x):\n",
    "    \"\"\"Initial condition u0(x) = x^2 cos(pi x) for NumPy inputs.\"\"\"\n",
    "    return np.power(x, 2.0) * np.cos(np.pi * x)\n",
    "\n",
    "\n",
    "def u0_ten(x):\n",
    "    \"\"\"Initial condition u0(x) = x^2 cos(pi x) for torch tensors (same device as pi_ten).\"\"\"\n",
    "    return torch.pow(x, 2.0) * torch.cos(pi_ten * x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Physics-informed Neural Networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the deep neural network\n",
    "class DNN(torch.nn.Module):\n",
    "    \"\"\"Fully connected tanh network: (Linear -> Tanh) * (len(layers) - 2) -> Linear.\n",
    "\n",
    "    Args:\n",
    "        layers: layer widths; layers[0] is the input size and layers[-1] the\n",
    "            output size.\n",
    "        initial_weight_path: file holding a tensor for the first layer's weight,\n",
    "            shape (layers[1], layers[0]). It is loaded by default, as before; pass\n",
    "            None to keep PyTorch's default random initialisation.\n",
    "    \"\"\"\n",
    "    def __init__(self, layers, initial_weight_path='initial_weight.pt'):\n",
    "        super().__init__()\n",
    "        # number of Linear layers\n",
    "        self.depth = len(layers) - 1\n",
    "        self.activation = torch.nn.Tanh\n",
    "        layer_list = []\n",
    "        for i in range(self.depth - 1):\n",
    "            layer_list.append(('layer_%d' % i, torch.nn.Linear(layers[i], layers[i + 1])))\n",
    "            layer_list.append(('activation_%d' % i, self.activation()))\n",
    "        # the final Linear layer has no activation\n",
    "        layer_list.append(('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1])))\n",
    "        self.layers = torch.nn.Sequential(OrderedDict(layer_list))\n",
    "        if initial_weight_path is not None:\n",
    "            self._load_first_layer_weight(initial_weight_path)\n",
    "\n",
    "    def _load_first_layer_weight(self, path):\n",
    "        \"\"\"Copy the tensor stored at `path` into the first Linear layer's weight.\"\"\"\n",
    "        # map_location='cpu' lets a file saved from a CUDA/MPS session load anywhere;\n",
    "        # the module is moved to its device afterwards (PhysicsInformedNN calls .to(device)).\n",
    "        # NOTE: torch.load unpickles the file -- only load trusted files.\n",
    "        saved = torch.load(path, map_location='cpu')\n",
    "        weight = self.layers[0].weight\n",
    "        if tuple(saved.shape) != tuple(weight.shape):\n",
    "            raise ValueError('%s has shape %s, expected %s'\n",
    "                             % (path, tuple(saved.shape), tuple(weight.shape)))\n",
    "        # copy in place so the registered Parameter (and its requires_grad) is kept\n",
    "        with torch.no_grad():\n",
    "            weight.copy_(saved)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the network to x of shape (N, layers[0]).\"\"\"\n",
    "        return self.layers(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PhysicsInformedNN():\n",
    "    \"\"\"Self-adaptive physics-informed neural network for the periodic Allen-Cahn problem.\n",
    "\n",
    "    The network u(t, x) = DNN(layers)([t, x]) is trained on three weighted\n",
    "    mean-squared losses: the initial data, the periodicity u(t, x) = u(t, x + period)\n",
    "    at the boundary points, and the PDE residual u_t - epsilon * u_xx - gamma * (u - u^3)\n",
    "    at the collocation points. Every training point has its own trainable loss\n",
    "    weight; in the Adam phase the network takes descent steps while the weights take\n",
    "    ascent steps. An L-BFGS phase then refines the network with the weights fixed.\n",
    "    Only the periodicity of u is penalised; the u_x periodicity is not enforced.\n",
    "\n",
    "    Args:\n",
    "        period: spatial period added to the boundary points' x coordinate.\n",
    "        epsilon: coefficient of u_xx in the residual.\n",
    "        gamma: coefficient of (u - u^3) in the residual.\n",
    "        X_IC: initial-condition points; column 0 is t, column 1 is x.\n",
    "        u_IC: target values at X_IC, broadcast against the (N_IC, 1) predictions.\n",
    "        X_BC: boundary points (t, x); each is paired with (t, x + period).\n",
    "        X_PDE: collocation points (t, x) for the PDE residual.\n",
    "        layers: layer widths forwarded to DNN.\n",
    "    \"\"\"\n",
    "    def __init__(self, period, epsilon, gamma, X_IC, u_IC, X_BC, X_PDE, layers):\n",
    "        # IC data points\n",
    "        self.t_IC = self._to_tensor(X_IC[:, 0:1])\n",
    "        self.x_IC = self._to_tensor(X_IC[:, 1:2])\n",
    "        self.u_IC = self._to_tensor(u_IC)\n",
    "        self.LW_IC = self._make_weights(X_IC.shape[0])\n",
    "        # BC data points\n",
    "        self.t_BC = self._to_tensor(X_BC[:, 0:1])\n",
    "        self.x_BC = self._to_tensor(X_BC[:, 1:2])\n",
    "        self.period = torch.tensor(period, dtype=torch.float32, device=device)\n",
    "        self.LW_BC = self._make_weights(X_BC.shape[0])\n",
    "        # PDE data: derivatives are taken w.r.t. these points, so they are created as\n",
    "        # leaf tensors directly on the device with requires_grad=True\n",
    "        self.t_PDE = self._to_tensor(X_PDE[:, 0:1], requires_grad=True)\n",
    "        self.x_PDE = self._to_tensor(X_PDE[:, 1:2], requires_grad=True)\n",
    "        self.LW_PDE = self._make_weights(X_PDE.shape[0])\n",
    "        # equation related parameters\n",
    "        self.epsilon = torch.tensor(epsilon, dtype=torch.float32, device=device)\n",
    "        self.gamma = torch.tensor(gamma, dtype=torch.float32, device=device)\n",
    "        # layers to build the neural net\n",
    "        self.layers = layers\n",
    "        self.dnn = DNN(layers).to(device)\n",
    "        # Adam for the network, with an exponential learning-rate decay\n",
    "        self.optimizer_Adam = torch.optim.Adam(self.dnn.parameters(), lr=1e-3)\n",
    "        self.scheduler = lr_scheduler.ExponentialLR(self.optimizer_Adam, gamma=0.999)\n",
    "        # maximize=True makes Adam perform gradient *ascent* on the loss weights\n",
    "        self.optimizer_LW_IC = torch.optim.Adam(self.LW_IC.parameters(), lr=5e-3, maximize=True)\n",
    "        self.optimizer_LW_BC = torch.optim.Adam(self.LW_BC.parameters(), lr=5e-3, maximize=True)\n",
    "        self.optimizer_LW_PDE = torch.optim.Adam(self.LW_PDE.parameters(), lr=5e-3, maximize=True)\n",
    "        # second-order refinement of the network parameters only\n",
    "        self.optimizer_LBFGS = torch.optim.LBFGS(\n",
    "            self.dnn.parameters(),\n",
    "            lr=1.0,\n",
    "            max_iter=50000,\n",
    "            max_eval=50000,\n",
    "            history_size=50,\n",
    "            tolerance_grad=1e-8,\n",
    "            tolerance_change=1.0 * np.finfo(float).eps,\n",
    "            line_search_fn=\"strong_wolfe\",\n",
    "        )\n",
    "        self.iter = 0\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_tensor(array, requires_grad=False):\n",
    "        \"\"\"Copy `array` into a float32 leaf tensor on `device`.\"\"\"\n",
    "        tensor = torch.tensor(array, dtype=torch.float32).to(device)\n",
    "        return tensor.requires_grad_(requires_grad)\n",
    "\n",
    "    @staticmethod\n",
    "    def _make_weights(n_points):\n",
    "        \"\"\"Return a ParameterList holding one (n_points, 1) tensor of ones on `device`.\"\"\"\n",
    "        return torch.nn.ParameterList([torch.nn.Parameter(torch.ones(n_points, 1, device=device))])\n",
    "\n",
    "    # evaluate the neural network, no transformation\n",
    "    def NN_eval(self, t, x):\n",
    "        \"\"\"Evaluate the network on the column-wise concatenation [t, x].\"\"\"\n",
    "        return self.dnn(torch.cat([t, x], dim=1))\n",
    "\n",
    "    # compute the PDE residual\n",
    "    def pde_eval(self, t, x):\n",
    "        \"\"\"Return the residual u_t - epsilon * u_xx - gamma * (u - u^3) via autograd.\n",
    "\n",
    "        t and x must require grad; create_graph=True keeps the derivatives\n",
    "        differentiable so the residual can be back-propagated.\n",
    "        \"\"\"\n",
    "        u = self.NN_eval(t, x)\n",
    "        ones = torch.ones_like(u)\n",
    "        u_t = torch.autograd.grad(u, t, grad_outputs=ones, retain_graph=True, create_graph=True)[0]\n",
    "        u_x = torch.autograd.grad(u, x, grad_outputs=ones, retain_graph=True, create_graph=True)[0]\n",
    "        u_xx = torch.autograd.grad(u_x, x, grad_outputs=ones, retain_graph=True, create_graph=True)[0]\n",
    "        return u_t - self.epsilon * u_xx - self.gamma * (u - torch.pow(u, 3.0))\n",
    "\n",
    "    def _losses(self, detach_weights=False):\n",
    "        \"\"\"Return the weighted mean-squared (IC, BC, PDE) losses.\n",
    "\n",
    "        With detach_weights=True the adaptive weights are treated as constants, so\n",
    "        backward() only produces gradients for the network.\n",
    "        \"\"\"\n",
    "        w_IC, w_BC, w_PDE = self.LW_IC[0], self.LW_BC[0], self.LW_PDE[0]\n",
    "        if detach_weights:\n",
    "            w_IC, w_BC, w_PDE = w_IC.detach(), w_BC.detach(), w_PDE.detach()\n",
    "        # IC loss\n",
    "        IC_pred = self.NN_eval(self.t_IC, self.x_IC)\n",
    "        loss_IC = torch.mean(torch.square(w_IC * (IC_pred - self.u_IC)))\n",
    "        # periodic BC loss: u must agree at x and x + period\n",
    "        BC_pred_left = self.NN_eval(self.t_BC, self.x_BC)\n",
    "        BC_pred_right = self.NN_eval(self.t_BC, self.x_BC + self.period)\n",
    "        loss_BC = torch.mean(torch.square(w_BC * (BC_pred_left - BC_pred_right)))\n",
    "        # PDE residual loss\n",
    "        pde_pred = self.pde_eval(self.t_PDE, self.x_PDE)\n",
    "        loss_PDE = torch.mean(torch.square(w_PDE * pde_pred))\n",
    "        return loss_IC, loss_BC, loss_PDE\n",
    "\n",
    "    def loss_func(self):\n",
    "        \"\"\"L-BFGS closure: recompute the total loss and back-propagate into the network.\"\"\"\n",
    "        self.optimizer_LBFGS.zero_grad()\n",
    "        # the weights are frozen during L-BFGS, so no gradient is accumulated on them\n",
    "        loss_IC, loss_BC, loss_PDE = self._losses(detach_weights=True)\n",
    "        loss = loss_IC + loss_BC + loss_PDE\n",
    "        loss.backward()\n",
    "        self.iter += 1\n",
    "        # output the progress\n",
    "        if self.iter % 1000 == 0:\n",
    "            end_time = time.time()\n",
    "            print('Iter %5d, Total: %10.4e, Time: %.2f secs' % (self.iter, loss.item(), end_time - self.start_time))\n",
    "            print('IC: %10.4e, PBC: %10.4e, PDE: %10.4e' % (loss_IC.item(), loss_BC.item(), loss_PDE.item()))\n",
    "            self.start_time = end_time\n",
    "        return loss\n",
    "\n",
    "    def train(self, nIter):\n",
    "        \"\"\"Run nIter Adam epochs (network descent, weight ascent), then L-BFGS on the network.\"\"\"\n",
    "        start_time = time.time()\n",
    "        self.dnn.train()\n",
    "        print('Starting with Adam')\n",
    "        weight_optimizers = (self.optimizer_LW_IC, self.optimizer_LW_BC, self.optimizer_LW_PDE)\n",
    "        for epoch in range(nIter):\n",
    "            loss_IC, loss_BC, loss_PDE = self._losses()\n",
    "            loss = loss_IC + loss_BC + loss_PDE\n",
    "            # backward and optimize\n",
    "            self.optimizer_Adam.zero_grad()\n",
    "            for optimizer in weight_optimizers:\n",
    "                optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            self.optimizer_Adam.step()\n",
    "            for optimizer in weight_optimizers:\n",
    "                optimizer.step()\n",
    "            # output the progress\n",
    "            if (epoch + 1) % 1000 == 0:\n",
    "                end_time = time.time()\n",
    "                print('Iter %5d, Total: %10.4e, Time: %.2f secs' % (epoch + 1, loss.item(), end_time - start_time))\n",
    "                print('IC: %10.4e, PBC: %10.4e, PDE: %10.4e' % (loss_IC.item(), loss_BC.item(), loss_PDE.item()))\n",
    "                print('For IC,  min LW: %10.4e, max LW: %10.4e' %(torch.min(self.LW_IC[0]).item(), torch.max(self.LW_IC[0]).item()))\n",
    "                print('For BC,  min LW: %10.4e, max LW: %10.4e' %(torch.min(self.LW_BC[0]).item(), torch.max(self.LW_BC[0]).item()))\n",
    "                print('For PDE, min LW: %10.4e, max LW: %10.4e' %(torch.min(self.LW_PDE[0]).item(), torch.max(self.LW_PDE[0]).item()))\n",
    "                start_time = end_time\n",
    "                # decay the network learning rate once every 1000 epochs\n",
    "                self.scheduler.step()\n",
    "        # refine with the second-order L-BFGS optimizer\n",
    "        print('Starting with L-BFGS')\n",
    "        self.start_time = time.time()\n",
    "        self.optimizer_LBFGS.step(self.loss_func)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Evaluate the network at the (t, x) rows of X; returns a NumPy array (N, layers[-1]).\"\"\"\n",
    "        t = self._to_tensor(X[:, 0:1])\n",
    "        x = self._to_tensor(X[:, 1:2])\n",
    "        self.dnn.eval()\n",
    "        with torch.no_grad():\n",
    "            u = self.NN_eval(t, x)\n",
    "        return u.cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting with Adam\n",
      "Iter  1000, Total: 1.7222e+01, Time: 44.56 secs\n",
      "IC: 1.5719e+01, PBC: 5.0127e-05, PDE: 1.5033e+00\n",
      "For IC,  min LW: 4.8393e+00, max LW: 1.1367e+01\n",
      "For BC,  min LW: 4.0518e+00, max LW: 6.4607e+00\n",
      "For PDE, min LW: 3.9992e+00, max LW: 1.1988e+01\n",
      "Iter  2000, Total: 6.1371e+01, Time: 37.38 secs\n",
      "IC: 5.5676e+01, PBC: 2.5468e-04, PDE: 5.6947e+00\n",
      "For IC,  min LW: 8.1654e+00, max LW: 2.1265e+01\n",
      "For BC,  min LW: 8.3039e+00, max LW: 1.2180e+01\n",
      "For PDE, min LW: 7.5252e+00, max LW: 2.2099e+01\n",
      "Iter  3000, Total: 1.2232e+02, Time: 37.41 secs\n",
      "IC: 1.1112e+02, PBC: 2.3188e-03, PDE: 1.1205e+01\n",
      "For IC,  min LW: 1.0031e+01, max LW: 2.9783e+01\n",
      "For BC,  min LW: 1.1048e+01, max LW: 1.6000e+01\n",
      "For PDE, min LW: 9.3318e+00, max LW: 3.0774e+01\n",
      "Iter  4000, Total: 1.9458e+02, Time: 37.32 secs\n",
      "IC: 1.7584e+02, PBC: 1.3098e-03, PDE: 1.8729e+01\n",
      "For IC,  min LW: 1.0710e+01, max LW: 3.7699e+01\n",
      "For BC,  min LW: 1.1826e+01, max LW: 1.7089e+01\n",
      "For PDE, min LW: 9.6079e+00, max LW: 3.9429e+01\n",
      "Iter  5000, Total: 5.9582e+00, Time: 37.56 secs\n",
      "IC: 3.9961e+00, PBC: 6.2653e-03, PDE: 1.9559e+00\n",
      "For IC,  min LW: 1.2519e+01, max LW: 4.2393e+01\n",
      "For BC,  min LW: 1.3018e+01, max LW: 2.0758e+01\n",
      "For PDE, min LW: 1.0710e+01, max LW: 4.3362e+01\n",
      "Iter  6000, Total: 3.7959e+00, Time: 37.48 secs\n",
      "IC: 2.8777e+00, PBC: 7.4758e-03, PDE: 9.1078e-01\n",
      "For IC,  min LW: 1.2519e+01, max LW: 4.2455e+01\n",
      "For BC,  min LW: 1.3020e+01, max LW: 2.0771e+01\n",
      "For PDE, min LW: 1.0732e+01, max LW: 4.3362e+01\n",
      "Iter  7000, Total: 3.1956e+00, Time: 37.04 secs\n",
      "IC: 2.6400e+00, PBC: 4.6410e-03, PDE: 5.5097e-01\n",
      "For IC,  min LW: 1.2519e+01, max LW: 4.2491e+01\n",
      "For BC,  min LW: 1.3022e+01, max LW: 2.0788e+01\n",
      "For PDE, min LW: 1.0740e+01, max LW: 4.3362e+01\n",
      "Iter  8000, Total: 2.8354e+00, Time: 37.11 secs\n",
      "IC: 2.4245e+00, PBC: 2.7289e-03, PDE: 4.0822e-01\n",
      "For IC,  min LW: 1.2519e+01, max LW: 4.2576e+01\n",
      "For BC,  min LW: 1.3023e+01, max LW: 2.0805e+01\n",
      "For PDE, min LW: 1.0743e+01, max LW: 4.3362e+01\n",
      "Iter  9000, Total: 2.4936e+00, Time: 36.96 secs\n",
      "IC: 2.1518e+00, PBC: 1.2511e-03, PDE: 3.4051e-01\n",
      "For IC,  min LW: 1.2519e+01, max LW: 4.2783e+01\n",
      "For BC,  min LW: 1.3023e+01, max LW: 2.0818e+01\n",
      "For PDE, min LW: 1.0749e+01, max LW: 4.3362e+01\n",
      "Iter 10000, Total: 2.1575e+00, Time: 37.00 secs\n",
      "IC: 1.8496e+00, PBC: 6.8184e-04, PDE: 3.0728e-01\n",
      "For IC,  min LW: 1.2520e+01, max LW: 4.3032e+01\n",
      "For BC,  min LW: 1.3023e+01, max LW: 2.0827e+01\n",
      "For PDE, min LW: 1.0758e+01, max LW: 4.3362e+01\n",
      "Iter 11000, Total: 1.8853e+00, Time: 36.51 secs\n",
      "IC: 1.5597e+00, PBC: 8.3960e-04, PDE: 3.2472e-01\n",
      "For IC,  min LW: 1.2522e+01, max LW: 4.3306e+01\n",
      "For BC,  min LW: 1.3023e+01, max LW: 2.0836e+01\n",
      "For PDE, min LW: 1.0759e+01, max LW: 4.3362e+01\n",
      "Iter 12000, Total: 1.7144e+00, Time: 35.40 secs\n",
      "IC: 1.3774e+00, PBC: 2.4062e-03, PDE: 3.3455e-01\n",
      "For IC,  min LW: 1.2528e+01, max LW: 4.3620e+01\n",
      "For BC,  min LW: 1.3024e+01, max LW: 2.0865e+01\n",
      "For PDE, min LW: 1.0766e+01, max LW: 4.3362e+01\n",
      "Iter 13000, Total: 1.5743e+00, Time: 35.38 secs\n",
      "IC: 1.2566e+00, PBC: 5.8019e-03, PDE: 3.1190e-01\n",
      "For IC,  min LW: 1.2537e+01, max LW: 4.4051e+01\n",
      "For BC,  min LW: 1.3025e+01, max LW: 2.0992e+01\n",
      "For PDE, min LW: 1.0769e+01, max LW: 4.3362e+01\n",
      "Iter 14000, Total: 1.4446e+00, Time: 35.41 secs\n",
      "IC: 1.1496e+00, PBC: 9.6509e-03, PDE: 2.8531e-01\n",
      "For IC,  min LW: 1.2549e+01, max LW: 4.4691e+01\n",
      "For BC,  min LW: 1.3026e+01, max LW: 2.1341e+01\n",
      "For PDE, min LW: 1.0776e+01, max LW: 4.3362e+01\n",
      "Iter 15000, Total: 1.3480e+00, Time: 35.86 secs\n",
      "IC: 1.0728e+00, PBC: 1.3951e-02, PDE: 2.6122e-01\n",
      "For IC,  min LW: 1.2565e+01, max LW: 4.5678e+01\n",
      "For BC,  min LW: 1.3028e+01, max LW: 2.2220e+01\n",
      "For PDE, min LW: 1.0851e+01, max LW: 4.3365e+01\n",
      "Iter 16000, Total: 1.3504e+00, Time: 36.96 secs\n",
      "IC: 1.0824e+00, PBC: 1.5440e-02, PDE: 2.5250e-01\n",
      "For IC,  min LW: 1.2588e+01, max LW: 4.7362e+01\n",
      "For BC,  min LW: 1.3028e+01, max LW: 2.4274e+01\n",
      "For PDE, min LW: 1.0989e+01, max LW: 4.3373e+01\n",
      "Iter 17000, Total: 3.3880e+02, Time: 37.34 secs\n",
      "IC: 3.0490e+02, PBC: 4.7535e-03, PDE: 3.3899e+01\n",
      "For IC,  min LW: 1.3521e+01, max LW: 5.4738e+01\n",
      "For BC,  min LW: 1.3609e+01, max LW: 2.6391e+01\n",
      "For PDE, min LW: 1.1832e+01, max LW: 5.1984e+01\n",
      "Iter 18000, Total: 4.4632e+02, Time: 35.88 secs\n",
      "IC: 4.0125e+02, PBC: 1.5454e-03, PDE: 4.5061e+01\n",
      "For IC,  min LW: 1.4195e+01, max LW: 6.1375e+01\n",
      "For BC,  min LW: 1.3752e+01, max LW: 2.7425e+01\n",
      "For PDE, min LW: 1.1965e+01, max LW: 5.9326e+01\n",
      "Iter 19000, Total: 5.5381e+02, Time: 36.78 secs\n",
      "IC: 4.9300e+02, PBC: 1.1485e-02, PDE: 6.0802e+01\n",
      "For IC,  min LW: 1.5520e+01, max LW: 6.7866e+01\n",
      "For BC,  min LW: 1.3837e+01, max LW: 2.7446e+01\n",
      "For PDE, min LW: 1.2443e+01, max LW: 6.6683e+01\n",
      "Iter 20000, Total: 9.5858e+00, Time: 37.52 secs\n",
      "IC: 7.2199e+00, PBC: 8.4051e-03, PDE: 2.3575e+00\n",
      "For IC,  min LW: 1.6268e+01, max LW: 6.8936e+01\n",
      "For BC,  min LW: 1.4051e+01, max LW: 2.7880e+01\n",
      "For PDE, min LW: 1.2986e+01, max LW: 6.7468e+01\n",
      "Iter 21000, Total: 6.4203e+00, Time: 37.59 secs\n",
      "IC: 4.8789e+00, PBC: 6.3393e-03, PDE: 1.5351e+00\n",
      "For IC,  min LW: 1.6277e+01, max LW: 6.9148e+01\n",
      "For BC,  min LW: 1.4051e+01, max LW: 2.7933e+01\n",
      "For PDE, min LW: 1.3068e+01, max LW: 6.7468e+01\n",
      "Iter 22000, Total: 4.6843e+00, Time: 36.85 secs\n",
      "IC: 3.5645e+00, PBC: 3.5252e-03, PDE: 1.1163e+00\n",
      "For IC,  min LW: 1.6277e+01, max LW: 6.9350e+01\n",
      "For BC,  min LW: 1.4051e+01, max LW: 2.7981e+01\n",
      "For PDE, min LW: 1.3141e+01, max LW: 6.7468e+01\n",
      "Iter 23000, Total: 3.9332e+00, Time: 37.12 secs\n",
      "IC: 3.0563e+00, PBC: 2.0927e-03, PDE: 8.7483e-01\n",
      "For IC,  min LW: 1.6279e+01, max LW: 6.9592e+01\n",
      "For BC,  min LW: 1.4052e+01, max LW: 2.8015e+01\n",
      "For PDE, min LW: 1.3196e+01, max LW: 6.7468e+01\n",
      "Iter 24000, Total: 3.5650e+00, Time: 36.96 secs\n",
      "IC: 2.8376e+00, PBC: 2.2395e-03, PDE: 7.2525e-01\n",
      "For IC,  min LW: 1.6284e+01, max LW: 6.9933e+01\n",
      "For BC,  min LW: 1.4053e+01, max LW: 2.8049e+01\n",
      "For PDE, min LW: 1.3233e+01, max LW: 6.7468e+01\n",
      "Iter 25000, Total: 3.2995e+00, Time: 36.15 secs\n",
      "IC: 2.6879e+00, PBC: 2.1984e-03, PDE: 6.0937e-01\n",
      "For IC,  min LW: 1.6291e+01, max LW: 7.0444e+01\n",
      "For BC,  min LW: 1.4054e+01, max LW: 2.8101e+01\n",
      "For PDE, min LW: 1.3284e+01, max LW: 6.7468e+01\n",
      "Iter 26000, Total: 3.0955e+00, Time: 36.12 secs\n",
      "IC: 2.5645e+00, PBC: 1.4858e-03, PDE: 5.2958e-01\n",
      "For IC,  min LW: 1.6300e+01, max LW: 7.1220e+01\n",
      "For BC,  min LW: 1.4055e+01, max LW: 2.8168e+01\n",
      "For PDE, min LW: 1.3358e+01, max LW: 6.7468e+01\n",
      "Iter 27000, Total: 2.9652e+00, Time: 36.16 secs\n",
      "IC: 2.4709e+00, PBC: 6.4576e-04, PDE: 4.9363e-01\n",
      "For IC,  min LW: 1.6308e+01, max LW: 7.2412e+01\n",
      "For BC,  min LW: 1.4056e+01, max LW: 2.8228e+01\n",
      "For PDE, min LW: 1.3423e+01, max LW: 6.7468e+01\n",
      "Iter 28000, Total: 2.9173e+00, Time: 36.31 secs\n",
      "IC: 2.4182e+00, PBC: 1.9888e-04, PDE: 4.9888e-01\n",
      "For IC,  min LW: 1.6311e+01, max LW: 7.4248e+01\n",
      "For BC,  min LW: 1.4057e+01, max LW: 2.8262e+01\n",
      "For PDE, min LW: 1.3443e+01, max LW: 6.7468e+01\n",
      "Iter 29000, Total: 2.9667e+00, Time: 36.19 secs\n",
      "IC: 2.4252e+00, PBC: 5.9111e-05, PDE: 5.4141e-01\n",
      "For IC,  min LW: 1.6312e+01, max LW: 7.6961e+01\n",
      "For BC,  min LW: 1.4058e+01, max LW: 2.8272e+01\n",
      "For PDE, min LW: 1.3488e+01, max LW: 6.7468e+01\n",
      "Iter 30000, Total: 3.5155e+00, Time: 35.14 secs\n",
      "IC: 2.8353e+00, PBC: 8.2877e-04, PDE: 6.7936e-01\n",
      "For IC,  min LW: 1.6754e+01, max LW: 7.8806e+01\n",
      "For BC,  min LW: 1.4470e+01, max LW: 2.8680e+01\n",
      "For PDE, min LW: 1.4015e+01, max LW: 6.7875e+01\n",
      "Iter 31000, Total: 3.4188e+00, Time: 34.03 secs\n",
      "IC: 2.8117e+00, PBC: 1.9140e-04, PDE: 6.0685e-01\n",
      "For IC,  min LW: 1.6757e+01, max LW: 8.0013e+01\n",
      "For BC,  min LW: 1.4475e+01, max LW: 2.8681e+01\n",
      "For PDE, min LW: 1.4035e+01, max LW: 6.7875e+01\n",
      "Iter 32000, Total: 3.4099e+00, Time: 32.69 secs\n",
      "IC: 2.8208e+00, PBC: 1.3104e-04, PDE: 5.8893e-01\n",
      "For IC,  min LW: 1.6761e+01, max LW: 8.1937e+01\n",
      "For BC,  min LW: 1.4478e+01, max LW: 2.8682e+01\n",
      "For PDE, min LW: 1.4038e+01, max LW: 6.7875e+01\n",
      "Iter 33000, Total: 3.4584e+00, Time: 32.91 secs\n",
      "IC: 2.8323e+00, PBC: 2.0303e-04, PDE: 6.2585e-01\n",
      "For IC,  min LW: 1.6763e+01, max LW: 8.4811e+01\n",
      "For BC,  min LW: 1.4486e+01, max LW: 2.8696e+01\n",
      "For PDE, min LW: 1.4039e+01, max LW: 6.7875e+01\n",
      "Iter 34000, Total: 3.5923e+00, Time: 33.89 secs\n",
      "IC: 2.8613e+00, PBC: 2.6693e-04, PDE: 7.3072e-01\n",
      "For IC,  min LW: 1.6763e+01, max LW: 8.8603e+01\n",
      "For BC,  min LW: 1.4500e+01, max LW: 2.8753e+01\n",
      "For PDE, min LW: 1.4043e+01, max LW: 6.7875e+01\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 35000, Total: 3.8209e+00, Time: 34.03 secs\n",
      "IC: 2.9531e+00, PBC: 5.3030e-04, PDE: 8.6723e-01\n",
      "For IC,  min LW: 1.6768e+01, max LW: 9.3036e+01\n",
      "For BC,  min LW: 1.4504e+01, max LW: 2.8978e+01\n",
      "For PDE, min LW: 1.4045e+01, max LW: 6.7875e+01\n",
      "Iter 36000, Total: 4.1410e+00, Time: 34.03 secs\n",
      "IC: 3.1359e+00, PBC: 9.9094e-04, PDE: 1.0042e+00\n",
      "For IC,  min LW: 1.6789e+01, max LW: 9.7915e+01\n",
      "For BC,  min LW: 1.4510e+01, max LW: 2.9814e+01\n",
      "For PDE, min LW: 1.4052e+01, max LW: 6.7875e+01\n",
      "Iter 37000, Total: 4.5711e+00, Time: 34.08 secs\n",
      "IC: 3.4360e+00, PBC: 8.0645e-04, PDE: 1.1343e+00\n",
      "For IC,  min LW: 1.6896e+01, max LW: 1.0314e+02\n",
      "For BC,  min LW: 1.4578e+01, max LW: 3.1605e+01\n",
      "For PDE, min LW: 1.4114e+01, max LW: 6.7875e+01\n",
      "Iter 38000, Total: 5.4523e+00, Time: 34.10 secs\n",
      "IC: 4.3574e+00, PBC: 2.6103e-03, PDE: 1.0923e+00\n",
      "For IC,  min LW: 1.7409e+01, max LW: 1.0654e+02\n",
      "For BC,  min LW: 1.5514e+01, max LW: 3.3079e+01\n",
      "For PDE, min LW: 1.4621e+01, max LW: 6.8353e+01\n",
      "Iter 39000, Total: 5.3637e+00, Time: 34.06 secs\n",
      "IC: 4.3517e+00, PBC: 1.3078e-03, PDE: 1.0107e+00\n",
      "For IC,  min LW: 1.7422e+01, max LW: 1.0737e+02\n",
      "For BC,  min LW: 1.5739e+01, max LW: 3.3301e+01\n",
      "For PDE, min LW: 1.4741e+01, max LW: 6.8353e+01\n",
      "Iter 40000, Total: 5.3118e+00, Time: 34.21 secs\n",
      "IC: 4.2879e+00, PBC: 6.7411e-04, PDE: 1.0232e+00\n",
      "For IC,  min LW: 1.7437e+01, max LW: 1.0871e+02\n",
      "For BC,  min LW: 1.5947e+01, max LW: 3.3500e+01\n",
      "For PDE, min LW: 1.4808e+01, max LW: 6.8372e+01\n",
      "Iter 41000, Total: 5.2280e+00, Time: 34.10 secs\n",
      "IC: 4.0826e+00, PBC: 2.9404e-04, PDE: 1.1451e+00\n",
      "For IC,  min LW: 1.7445e+01, max LW: 1.1070e+02\n",
      "For BC,  min LW: 1.6086e+01, max LW: 3.3673e+01\n",
      "For PDE, min LW: 1.4832e+01, max LW: 6.8801e+01\n",
      "Iter 42000, Total: 5.3340e+00, Time: 33.44 secs\n",
      "IC: 4.0029e+00, PBC: 1.4830e-04, PDE: 1.3309e+00\n",
      "For IC,  min LW: 1.7450e+01, max LW: 1.1336e+02\n",
      "For BC,  min LW: 1.6130e+01, max LW: 3.3822e+01\n",
      "For PDE, min LW: 1.4837e+01, max LW: 6.9839e+01\n",
      "Iter 43000, Total: 5.5765e+00, Time: 57.96 secs\n",
      "IC: 4.1331e+00, PBC: 5.8815e-05, PDE: 1.4433e+00\n",
      "For IC,  min LW: 1.7474e+01, max LW: 1.1692e+02\n",
      "For BC,  min LW: 1.6693e+01, max LW: 3.4230e+01\n",
      "For PDE, min LW: 1.4843e+01, max LW: 7.1698e+01\n",
      "Iter 44000, Total: 5.9382e+00, Time: 59.17 secs\n",
      "IC: 4.3601e+00, PBC: 4.0132e-04, PDE: 1.5778e+00\n",
      "For IC,  min LW: 1.7528e+01, max LW: 1.2137e+02\n",
      "For BC,  min LW: 1.7396e+01, max LW: 3.4856e+01\n",
      "For PDE, min LW: 1.4863e+01, max LW: 7.4706e+01\n",
      "Iter 45000, Total: 9.8888e+00, Time: 60.35 secs\n",
      "IC: 7.6994e+00, PBC: 7.6035e-02, PDE: 2.1133e+00\n",
      "For IC,  min LW: 1.8524e+01, max LW: 1.2396e+02\n",
      "For BC,  min LW: 1.8317e+01, max LW: 3.5845e+01\n",
      "For PDE, min LW: 1.5442e+01, max LW: 7.6616e+01\n",
      "Iter 46000, Total: 8.2341e+00, Time: 60.02 secs\n",
      "IC: 6.4839e+00, PBC: 3.5808e-02, PDE: 1.7144e+00\n",
      "For IC,  min LW: 1.8554e+01, max LW: 1.2472e+02\n",
      "For BC,  min LW: 1.8320e+01, max LW: 3.5877e+01\n",
      "For PDE, min LW: 1.5492e+01, max LW: 7.6767e+01\n",
      "Iter 47000, Total: 7.6494e+00, Time: 59.52 secs\n",
      "IC: 5.9910e+00, PBC: 2.0868e-02, PDE: 1.6375e+00\n",
      "For IC,  min LW: 1.8588e+01, max LW: 1.2582e+02\n",
      "For BC,  min LW: 1.8323e+01, max LW: 3.5905e+01\n",
      "For PDE, min LW: 1.5527e+01, max LW: 7.7071e+01\n",
      "Iter 48000, Total: 7.3130e+00, Time: 60.02 secs\n",
      "IC: 5.6160e+00, PBC: 1.4537e-02, PDE: 1.6825e+00\n",
      "For IC,  min LW: 1.8607e+01, max LW: 1.2742e+02\n",
      "For BC,  min LW: 1.8323e+01, max LW: 3.5917e+01\n",
      "For PDE, min LW: 1.5565e+01, max LW: 7.7744e+01\n",
      "Iter 49000, Total: 7.2482e+00, Time: 60.30 secs\n",
      "IC: 5.4156e+00, PBC: 9.4607e-03, PDE: 1.8231e+00\n",
      "For IC,  min LW: 1.8609e+01, max LW: 1.2968e+02\n",
      "For BC,  min LW: 1.8323e+01, max LW: 3.5925e+01\n",
      "For PDE, min LW: 1.5624e+01, max LW: 7.9209e+01\n",
      "Iter 50000, Total: 7.4013e+00, Time: 60.43 secs\n",
      "IC: 5.4739e+00, PBC: 5.6071e-03, PDE: 1.9218e+00\n",
      "For IC,  min LW: 1.8632e+01, max LW: 1.3283e+02\n",
      "For BC,  min LW: 1.8323e+01, max LW: 3.5937e+01\n",
      "For PDE, min LW: 1.5659e+01, max LW: 8.1689e+01\n",
      "Starting with L-BFGS\n",
      "Iter  1000, Total: 6.9120e+00, Time: 126.62 secs\n",
      "IC: 4.9788e+00, PBC: 7.3803e-05, PDE: 1.9331e+00\n"
     ]
    }
   ],
   "source": [
    "layers = [2, 32, 32, 32, 32, 32, 32, 32, 1]\n",
    "# training points (columns: t, x) loaded from disk\n",
    "ptsIC = np.load('ptsIC.npy')\n",
    "ptsBC = np.load('ptsBC.npy')\n",
    "ptsPDE = np.load('random_data.npy')\n",
    "# IC points must lie on the initial time slice\n",
    "assert np.allclose(ptsIC[:, 0], tlo), 'ptsIC must have t = tlo'\n",
    "# evaluate u0 at the x-coordinates actually fed to the IC loss (not a separate linspace)\n",
    "x_IC = ptsIC[:, 1:2]\n",
    "N_IC = x_IC.shape[0]\n",
    "u_IC = u0(x_IC)\n",
    "model = PhysicsInformedNN(period, epsilon, gamma, ptsIC, u_IC, ptsBC, ptsPDE, layers)\n",
    "model.train(50000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "t_grid = np.linspace(tlo, thi, 101)\n",
    "x_grid = np.linspace(xlo, xhi, 201)\n",
    "# T and X have shape (len(x_grid), len(t_grid))\n",
    "T, X = np.meshgrid(t_grid, x_grid)\n",
    "pts_flat = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))\n",
    "# the network is evaluated exactly at the grid nodes, so reshaping (same C order as\n",
    "# flatten above) recovers the field; scattered-data interpolation is unnecessary\n",
    "u_pred = model.predict(pts_flat).reshape(T.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation Against the Reference Solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reference solution and its (t, x) grid\n",
    "data = scipy.io.loadmat('Data/AC_case1.mat')\n",
    "t, x2 = (data[key].flatten()[:, None] for key in ('t', 'x'))\n",
    "exact_sol = np.real(data['exact_sol']).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def relative_error_l2(pred, exact):\n",
    "    \"\"\"Relative discrete L2 error ||pred - exact||_2 / ||exact||_2.\"\"\"\n",
    "    residual = pred - exact\n",
    "    numerator = np.sqrt(np.sum(np.power(residual, 2)))\n",
    "    denominator = np.sqrt(np.sum(np.power(exact, 2)))\n",
    "    return numerator / denominator\n",
    "\n",
    "\n",
    "def relative_error_l1(pred, exact):\n",
    "    \"\"\"Relative discrete L1 error sum|pred - exact| / sum|exact|.\"\"\"\n",
    "    numerator = np.sum(np.abs(pred - exact))\n",
    "    denominator = np.sum(np.abs(exact))\n",
    "    return numerator / denominator\n",
    "\n",
    "\n",
    "def relative_error_linf(pred, exact):\n",
    "    \"\"\"Relative max-norm error max|pred - exact| / max|exact|.\"\"\"\n",
    "    numerator = np.max(np.abs(pred - exact))\n",
    "    denominator = np.max(np.abs(exact))\n",
    "    return numerator / denominator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "l2: 0.51208919302061\n",
      "l1: 0.3183052953987406\n",
      "linf: 1.1833447637025254\n"
     ]
    }
   ],
   "source": [
    "# u_pred is laid out (x, t); its transpose must match the reference array exactly,\n",
    "# otherwise NumPy broadcasting could silently compare mismatched grids\n",
    "assert u_pred.T.shape == exact_sol.shape, (u_pred.T.shape, exact_sol.shape)\n",
    "metrics = (('l2', relative_error_l2), ('l1', relative_error_l1), ('linf', relative_error_linf))\n",
    "for name, metric in metrics:\n",
    "    print(f'{name}: {metric(u_pred.T, exact_sol)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
