{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "from torch.optim import LBFGS\n",
    "from tqdm import tqdm\n",
    "import scipy.io\n",
    "\n",
    "from DCGD_BFGS import DualCenter_BFGS\n",
    "\n",
    "from util import *\n",
    "from model.pinn import PINNs\n",
    "from model.pinnsformer import PINNsformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CUDA support \n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "\n",
    "#device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train PINNsformer\n",
    "# Collocation points on a 51x51 grid for training and a 101x101 grid for evaluation.\n",
    "res, b_left, b_right, b_upper, b_lower = get_data([0,2*np.pi], [0,1], 51, 51)\n",
    "res_test, _, _, _, _ = get_data([0,2*np.pi], [0,1], 101, 101)\n",
    "\n",
    "# First expand every point set with make_time_sequence, then turn each result into a\n",
    "# float32 tensor that tracks gradients and lives on the selected device.\n",
    "train_sequences = [\n",
    "    make_time_sequence(points, num_step=5, step=1e-4)\n",
    "    for points in (res, b_left, b_right, b_upper, b_lower)\n",
    "]\n",
    "res, b_left, b_right, b_upper, b_lower = [\n",
    "    torch.tensor(sequence, dtype=torch.float32, requires_grad=True).to(device)\n",
    "    for sequence in train_sequences\n",
    "]\n",
    "\n",
    "def split_xt(points):\n",
    "    \"\"\"Return the (x, t) slices of a point tensor: last-axis channel 0 and channel 1.\"\"\"\n",
    "    return points[:,:,0:1], points[:,:,1:2]\n",
    "\n",
    "x_res, t_res = split_xt(res)\n",
    "x_left, t_left = split_xt(b_left)\n",
    "x_right, t_right = split_xt(b_right)\n",
    "x_upper, t_upper = split_xt(b_upper)\n",
    "x_lower, t_lower = split_xt(b_lower)\n",
    "\n",
    "def init_weights(m):\n",
    "    if isinstance(m, nn.Linear):\n",
    "        torch.nn.init.xavier_uniform(m.weight)\n",
    "        m.bias.data.fill_(0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the PINNsformer, move it to the selected device, then re-initialise its Linear layers.\n",
    "model = PINNsformer(d_out=1, d_hidden=512, d_model=32, N=1, heads=2)\n",
    "model = model.to(device)\n",
    "model.apply(init_weights)\n",
    "\n",
    "# L-BFGS with a strong-Wolfe line search drives the parameter updates.\n",
    "optim = LBFGS(model.parameters(), line_search_fn='strong_wolfe')\n",
    "\n",
    "# Show the architecture followed by the value reported by get_n_params.\n",
    "n_params = get_n_params(model)\n",
    "print(model)\n",
    "print(n_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the L-BFGS optimizer in DualCenter_BFGS. The training loop calls its step() with the\n",
    "# [residual, boundary] loss pair and uses the two returned values as the loss weights.\n",
    "# NOTE(review): the meaning of num_pde=1 is defined in DCGD_BFGS, not shown here - confirm.\n",
    "weight_optimizer = DualCenter_BFGS(optim, num_pde=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coefficient on u_x in the PDE residual u_t + BETA * u_x.\n",
    "BETA = 50\n",
    "# Number of outer L-BFGS steps, and how often (in steps) the loss weights are refreshed.\n",
    "N_STEPS = 500\n",
    "REWEIGHT_EVERY = 10\n",
    "\n",
    "# One [loss_res, loss_bc, loss_ic] row is appended per loss evaluation.\n",
    "loss_track = []\n",
    "# Weights of the residual term and of the combined boundary/initial term.\n",
    "wr, wb = 1, 1\n",
    "\n",
    "\n",
    "def compute_losses():\n",
    "    \"\"\"Evaluate the three loss terms at the current model parameters.\n",
    "\n",
    "    Returns:\n",
    "        (loss_res, loss_bc, loss_ic): scalar tensors still attached to the autograd graph.\n",
    "\n",
    "    Side effect: appends the three scalar values to ``loss_track``.\n",
    "    \"\"\"\n",
    "    pred_res = model(x_res, t_res)\n",
    "    pred_left = model(x_left, t_left)\n",
    "    pred_upper = model(x_upper, t_upper)\n",
    "    pred_lower = model(x_lower, t_lower)\n",
    "    # The right-boundary points are not evaluated: no loss term below uses them.\n",
    "\n",
    "    # create_graph=True keeps the derivatives differentiable so the residual can be back-propagated.\n",
    "    ones = torch.ones_like(pred_res)\n",
    "    u_x = torch.autograd.grad(pred_res, x_res, grad_outputs=ones, retain_graph=True, create_graph=True)[0]\n",
    "    u_t = torch.autograd.grad(pred_res, t_res, grad_outputs=ones, retain_graph=True, create_graph=True)[0]\n",
    "\n",
    "    loss_res = torch.mean((u_t + BETA * u_x) ** 2)\n",
    "    # Predictions on the upper and lower boundaries are pushed to agree.\n",
    "    loss_bc = torch.mean((pred_upper - pred_lower) ** 2)\n",
    "    # First sequence entry on the left boundary is matched to sin(x).\n",
    "    loss_ic = torch.mean((pred_left[:,0] - torch.sin(x_left[:,0])) ** 2)\n",
    "\n",
    "    loss_track.append([loss_res.item(), loss_bc.item(), loss_ic.item()])\n",
    "    return loss_res, loss_bc, loss_ic\n",
    "\n",
    "\n",
    "def closure():\n",
    "    \"\"\"Closure for ``optim.step``: weighted total loss with gradients populated.\n",
    "\n",
    "    Reads the current ``wr`` / ``wb`` at call time, so weight refreshes take effect immediately.\n",
    "    \"\"\"\n",
    "    loss_res, loss_bc, loss_ic = compute_losses()\n",
    "    loss = wr*loss_res + wb*(loss_bc + loss_ic)\n",
    "    optim.zero_grad()\n",
    "    loss.backward()\n",
    "    return loss\n",
    "\n",
    "\n",
    "def dcgd_closure():\n",
    "    \"\"\"Return ``[loss_res, loss_bc + loss_ic]`` (no backward pass) for ``weight_optimizer.step``.\"\"\"\n",
    "    loss_res, loss_bc, loss_ic = compute_losses()\n",
    "    loss_bd = loss_bc + loss_ic\n",
    "    return [loss_res, loss_bd]\n",
    "\n",
    "\n",
    "for i in tqdm(range(N_STEPS)):\n",
    "    # Periodically let the DualCenter_BFGS wrapper produce new loss weights.\n",
    "    if i % REWEIGHT_EVERY == 0:\n",
    "        losses = dcgd_closure()\n",
    "        weights = weight_optimizer.step(losses)\n",
    "        wr = weights[0]\n",
    "        wb = weights[1]\n",
    "\n",
    "    optim.step(closure)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the most recent loss evaluation: the individual terms first, then their plain sum.\n",
    "final_res, final_bc, final_ic = loss_track[-1]\n",
    "print('Loss Res: {:4f}, Loss_BC: {:4f}, Loss_IC: {:4f}'.format(final_res, final_bc, final_ic))\n",
    "print('Train Loss: {:4f}'.format(np.sum(loss_track[-1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert once; each row of loss_track is [loss_res, loss_bc, loss_ic].\n",
    "loss_history = np.array(loss_track)\n",
    "\n",
    "plt.plot(loss_history[:,0], label='res')\n",
    "# Boundary curve = boundary-condition loss (column 1) + initial-condition loss (column 2),\n",
    "# i.e. the same combination that is weighted by wb during training.\n",
    "plt.plot(loss_history[:,1] + loss_history[:,2], label='bc+ic')\n",
    "plt.legend()\n",
    "plt.yscale('log', base=10)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize PINNsformer\n",
    "res_test = make_time_sequence(res_test, num_step=5, step=1e-4) \n",
    "res_test = torch.tensor(res_test, dtype=torch.float32, requires_grad=True).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize PINNsformer\n",
    "\n",
    "x_test = res_test[:,:,0:1]\n",
    "t_test = res_test[:,:,1:2]\n",
    "\n",
    "# Inference only: keep the first entry along the sequence axis of the model output.\n",
    "with torch.no_grad():\n",
    "    pred = model(x_test, t_test)[:,0:1]\n",
    "\n",
    "# Bring the prediction back to NumPy on the 101x101 evaluation grid.\n",
    "pred = pred.cpu().detach().numpy().reshape(101,101)\n",
    "\n",
    "# Reference solution stored alongside the PINNsformer demo.\n",
    "mat = scipy.io.loadmat('./pinnsformer/demo/convection/convection.mat')\n",
    "u = mat['u'].reshape(101,101)\n",
    "\n",
    "# Relative L1 / L2 errors of the prediction against the reference.\n",
    "pointwise_abs_err = np.abs(u-pred)\n",
    "rl1 = np.sum(pointwise_abs_err) / np.sum(np.abs(u))\n",
    "rl2 = np.sqrt(np.sum((u-pred)**2) / np.sum(u**2))\n",
    "\n",
    "print('relative L1 error: {:4f}'.format(rl1))\n",
    "print('relative L2 error: {:4f}'.format(rl2))\n",
    "\n",
    "# Heat map of the prediction: x on the horizontal axis, t on the vertical axis.\n",
    "image_extent = [0,np.pi*2,1,0]\n",
    "plt.figure(figsize=(6,5))\n",
    "plt.imshow(pred, extent=image_extent, aspect='auto')\n",
    "\n",
    "cbar = plt.colorbar()\n",
    "cbar.ax.tick_params(labelsize=15)\n",
    "\n",
    "plt.tick_params(axis='both', labelsize=20)\n",
    "plt.xlabel(r'$x$', fontsize=30)\n",
    "plt.ylabel(r'$t$', fontsize=30)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./convection_pinnsformer_pred.pdf', format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heat map of the reference solution u, drawn with the same extent and fonts as the prediction plot.\n",
    "image_extent = [0,np.pi*2,1,0]\n",
    "plt.figure(figsize=(6,5))\n",
    "plt.imshow(u, extent=image_extent, aspect='auto')\n",
    "\n",
    "# Axis labels and tick sizes for the image axes.\n",
    "plt.xlabel(r'$x$', fontsize=30)\n",
    "plt.ylabel(r'$t$', fontsize=30)\n",
    "plt.tick_params(axis='both', labelsize=20)\n",
    "\n",
    "# Colour bar with slightly smaller tick labels.\n",
    "cbar = plt.colorbar()\n",
    "cbar.ax.tick_params(labelsize=15)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./convection_exact.pdf', format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heat map of the point-wise absolute error between the prediction and the reference solution.\n",
    "abs_error = np.abs(pred - u)\n",
    "image_extent = [0,np.pi*2,1,0]\n",
    "\n",
    "plt.figure(figsize=(6,5))\n",
    "plt.imshow(abs_error, extent=image_extent, aspect='auto')\n",
    "\n",
    "# Axis labels and tick sizes for the image axes.\n",
    "plt.xlabel(r'$x$', fontsize=30)\n",
    "plt.ylabel(r'$t$', fontsize=30)\n",
    "plt.tick_params(axis='both', labelsize=20)\n",
    "\n",
    "# Colour bar with slightly smaller tick labels.\n",
    "cbar = plt.colorbar()\n",
    "cbar.ax.tick_params(labelsize=15)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('./convection_dcgd_error.pdf', format='pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "leo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
