{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T05:49:56.886747Z",
     "start_time": "2020-05-05T05:49:55.552862Z"
    }
   },
   "outputs": [],
   "source": [
    "# import modules\n",
    "import pickle as pkl\n",
    "import time\n",
    "\n",
    "from tqdm import tqdm_notebook\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# deep learning modules\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Plot modules\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make grid points\n",
    "# Training collocation grid on [tmin, tmax] x [xmin, xmax] with Ns time samples\n",
    "# and Nx space samples. Rows are time-major: each consecutive block of Nx rows\n",
    "# shares one t value.\n",
    "\n",
    "tmin, tmax = 0.0001, 0.001\n",
    "xmin, xmax = -1, 1\n",
    "Ns, Nx = 51, 51\n",
    "sx = np.mgrid[tmin:tmax:Ns*1j, xmin:xmax:Nx*1j].reshape(2, -1).T\n",
    "df_real = pd.DataFrame(sx, columns=['t', 'x'])\n",
    "X_train = df_real.values\n",
    "assert X_train.shape == (Ns*Nx, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:55.267254Z",
     "start_time": "2020-05-05T06:17:55.260453Z"
    }
   },
   "outputs": [],
   "source": [
    "# Compute device: the first CUDA GPU when one is visible, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exact solution of u_t + u*u_x = nu*u_xx, u(0, x) = -sin(pi*x), via Cole-Hopf.\n",
    "nu = 1/5  # viscosity (the common benchmark value is 0.01/np.pi)\n",
    "\n",
    "# Quadrature nodes in the scaled variable z = eta/sqrt(4*nu*t): the heat kernel\n",
    "# exp(-eta**2/(4*nu*t)) becomes exp(-z**2), which a fixed z-grid resolves for every t.\n",
    "# A fixed eta-grid cannot: here sqrt(2*nu*t) <= 0.02 while linspace(-100, 100, 10000)\n",
    "# has spacing ~0.02. exp(-z**2) < 1e-43 for |z| > 10.\n",
    "z_nodes = np.linspace(-10, 10, 2001)\n",
    "\n",
    "def f(x) :\n",
    "    \"\"\"Cole-Hopf transform of the initial condition: exp(-cos(pi*x)/(2*pi*nu)).\"\"\"\n",
    "    return np.exp(-np.cos(np.pi*x)/(2*np.pi*nu))\n",
    "\n",
    "def analytic(row) :\n",
    "    \"\"\"Exact u(t, x) for a (t, x) pair with t >= 0 (e.g. a row from DataFrame.apply).\n",
    "\n",
    "    Ratio of two heat-kernel convolutions evaluated by trapezoid quadrature in z;\n",
    "    the Jacobian of eta -> z is common to both integrals and cancels.\n",
    "    \"\"\"\n",
    "    t, x = row\n",
    "    eta = np.sqrt(4*t*nu)*z_nodes\n",
    "    weight = f(x - eta)*np.exp(-z_nodes**2)\n",
    "    numer = np.trapz(-np.sin(np.pi*(x - eta))*weight, z_nodes)\n",
    "    denom = np.trapz(weight, z_nodes)\n",
    "    return numer/denom\n",
    "\n",
    "# Sanity checks: t = 0 reproduces the initial condition, and at x = 0.5 the\n",
    "# small-t expansion u ~ -1 + nu*pi**2*t holds.\n",
    "assert abs(analytic((0.0, 0.3)) + np.sin(0.3*np.pi)) < 1e-12\n",
    "assert abs(analytic((1e-4, 0.5)) - (-1 + nu*np.pi**2*1e-4)) < 1e-6\n",
    "\n",
    "# Evaluation grid (101 x 101, time-major) with exact values, plus the same points as a tensor.\n",
    "X_test_df = pd.DataFrame(np.mgrid[tmin:tmax:101j, xmin:xmax:101j].reshape(2, -1).T)\n",
    "y_test = X_test_df.apply(analytic, axis=1).values\n",
    "X_test = torch.FloatTensor(X_test_df.values).to(device)\n",
    "\n",
    "# Exact values on the training grid.\n",
    "u_analytic = df_real.apply(analytic, axis=1).values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Surface plot of the exact solution on the 101 x 101 evaluation grid.\n",
    "n_eval = 101  # must match the 101j resolution used to build y_test\n",
    "fig = plt.figure(figsize=(10,5))\n",
    "s = np.linspace(tmin, tmax, n_eval) # discretization of time\n",
    "x = np.linspace(xmin, xmax, n_eval) # discretization of space\n",
    "S, X = np.meshgrid(s, x)\n",
    "\n",
    "ax = fig.add_subplot(1,1,1, projection='3d')\n",
    "# y_test is time-major (t outer, x inner); transpose to meshgrid's (x, t) layout.\n",
    "surf = ax.plot_surface(S, X, y_test.reshape(n_eval, n_eval).T, alpha=0.7)\n",
    "ax.set_title('Analytic Solution')\n",
    "ax.set_xlabel('t')\n",
    "ax.set_ylabel('x')\n",
    "ax.set_zlabel('u');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.230485Z",
     "start_time": "2020-05-05T06:17:55.589878Z"
    }
   },
   "outputs": [],
   "source": [
    "# Initial-condition points at t = 0 (the training grid itself starts at tmin > 0)\n",
    "# and targets u(0, x) = -sin(pi*x) together with its first and second x-derivatives.\n",
    "n_ini = 100\n",
    "x_ini = np.linspace(xmin, xmax, n_ini)\n",
    "X_ini_np = np.column_stack([np.zeros(n_ini), x_ini])\n",
    "u_ini_np = -np.sin(np.pi*x_ini)\n",
    "u_ini_der_np = -np.pi*np.cos(np.pi*x_ini)\n",
    "u_ini_sec_der_np = (np.pi**2)*np.sin(np.pi*x_ini)\n",
    "\n",
    "# X_ini requires grad because train() differentiates the network output w.r.t. it.\n",
    "# (requires_grad_ replaces the deprecated torch.autograd.Variable wrapper.)\n",
    "X_ini = torch.FloatTensor(X_ini_np).to(device).requires_grad_(True)\n",
    "u_ini = torch.FloatTensor(u_ini_np).to(device).view(-1,1)\n",
    "u_ini_der = torch.FloatTensor(u_ini_der_np).to(device).view(-1,1)\n",
    "u_ini_sec_der = torch.FloatTensor(u_ini_sec_der_np).to(device).view(-1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the initial-condition targets, plotted against x.\n",
    "fig, ax = plt.subplots()\n",
    "x_plot = X_ini[:, 1].cpu().detach().numpy()\n",
    "ax.plot(x_plot, u_ini.cpu().detach().numpy(), label='u(0, x)')\n",
    "ax.plot(x_plot, u_ini_der.cpu().detach().numpy(), label='u_x(0, x)')\n",
    "ax.set(xlabel='x', title='Initial condition targets')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boundary points (x == xmin or x == xmax) of the training grid, with Dirichlet target u = 0.\n",
    "on_boundary = df_real['x'].isin([xmin, xmax])\n",
    "X_bdry = torch.FloatTensor(df_real.loc[on_boundary].values).to(device)\n",
    "u_bdry = torch.zeros(X_bdry.shape[0], 1).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.245996Z",
     "start_time": "2020-05-05T06:17:58.241527Z"
    }
   },
   "outputs": [],
   "source": [
    "# Neural network: fully connected tanh MLP with a built-in Dirichlet boundary factor\n",
    "\n",
    "class u_Net(nn.Module):\n",
    "    \"\"\"MLP approximating u(t, x): 2 -> 256 -> 256 -> 256 -> 256 -> 1, tanh activations.\n",
    "\n",
    "    forward takes an (N, 2) tensor with columns (t, x) and returns an (N, 1) tensor.\n",
    "    The last layer's output is multiplied by (x - x_lo)*(x_hi - x), so the network\n",
    "    is exactly zero at x = x_lo and x = x_hi (homogeneous Dirichlet condition).\n",
    "    \"\"\"\n",
    "    # Class-level defaults (the original hard-coded interval [-1, 1]); they also serve\n",
    "    # instances unpickled from checkpoints saved before these attributes existed.\n",
    "    x_lo = -1.0\n",
    "    x_hi = 1.0\n",
    "\n",
    "    def __init__(self, x_lo=-1.0, x_hi=1.0):\n",
    "        super(u_Net, self).__init__()\n",
    "        self.x_lo, self.x_hi = x_lo, x_hi\n",
    "        # Attribute names fc1..fc5 define the state_dict keys; keep them stable.\n",
    "        self.fc1 = nn.Linear(2, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 256)\n",
    "        self.fc4 = nn.Linear(256, 256)\n",
    "        self.fc5 = nn.Linear(256, 1)\n",
    "        self.act1 = nn.Tanh()\n",
    "\n",
    "    def forward(self, x):\n",
    "        space = x[:, 1].view(-1,1)\n",
    "        x = self.act1(self.fc1(x))\n",
    "        x = self.act1(self.fc2(x))\n",
    "        x = self.act1(self.fc3(x))\n",
    "        x = self.act1(self.fc4(x))\n",
    "        return self.fc5(x)*(space - self.x_lo)*(self.x_hi - space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.248577Z",
     "start_time": "2020-05-05T06:17:58.246944Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hyperparameters.\n",
    "# NOTE: the training cells below reassign EPOCH before they use it.\n",
    "# BATCH_SIZE exceeds the Ns*Nx = 2601 grid points, so every epoch is a single batch.\n",
    "EPOCH, BATCH_SIZE = 1_000_000, 5_000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.251546Z",
     "start_time": "2020-05-05T06:17:58.249443Z"
    }
   },
   "outputs": [],
   "source": [
    "# DataLoader over the (t, x) training grid. shuffle=False keeps the time-major row\n",
    "# order, which loss_GE relies on when it groups rows into time slices.\n",
    "data_train = TensorDataset(torch.FloatTensor(X_train))\n",
    "train_loader = DataLoader(\n",
    "    data_train,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.255550Z",
     "start_time": "2020-05-05T06:17:58.252428Z"
    }
   },
   "outputs": [],
   "source": [
    "def calculate_derivative(y, x) :\n",
    "    \"\"\"Gradient of sum(y) w.r.t. x, built with create_graph=True so it can be differentiated again.\n",
    "\n",
    "    When y is computed row-wise from x (no coupling between rows), row i of the\n",
    "    result is the gradient of y[i] w.r.t. x[i]. x must require grad.\n",
    "    \"\"\"\n",
    "    # ones_like matches y's shape, dtype and device, so no global `device` is needed.\n",
    "    return torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]\n",
    "\n",
    "def calculate_all_partial(u, x) :\n",
    "    \"\"\"Return 1-D tensors (u_t, u_x, u_xx) for u evaluated at points x with columns (t, x).\"\"\"\n",
    "    grad_u = calculate_derivative(u, x)\n",
    "    u_t, u_x = grad_u[:,0], grad_u[:,1]\n",
    "    u_xx = calculate_derivative(u_x, x)[:,1]\n",
    "    return u_t, u_x, u_xx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_GE(output, u_t, u_x, u_xx, f_t) :\n",
    "    \"\"\"Governing-equation loss for Burgers' equation u_t + u*u_x - nu*u_xx = 0.\n",
    "\n",
    "    output: network values at the batch points; u_t, u_x, u_xx: their partial\n",
    "    derivatives (1-D); f_t: time derivative of the residual (1-D).\n",
    "    Returns an area-weighted mean squared residual plus an H1-type term built from\n",
    "    per-time-slice spatial integrals of f**2 and (f*f_t)**2.\n",
    "\n",
    "    Assumes the batch is the full time-major grid, i.e. consecutive blocks of Nx\n",
    "    rows share one t value; otherwise the per-slice grouping is meaningless.\n",
    "    \"\"\"\n",
    "    # Residual of the PDE: nu must scale u_xx, as it does in the f_t passed in by train.\n",
    "    f = u_t + output.view(-1)*u_x - nu*u_xx\n",
    "    L2_term = (tmax-tmin)*(xmax-xmin)*2*(f**2).mean()\n",
    "    # Spatial integrals per time slice: one row per t value, Nx points per row.\n",
    "    f_sq_x = (xmax-xmin)*(f**2).reshape(-1, Nx).mean(dim=1)\n",
    "    ff_t_sq_x = (xmax-xmin)*((f*f_t)**2).reshape(-1, Nx).mean(dim=1)\n",
    "    # NOTE(review): f_sq_x**(-1/2) is inf if a slice's residual is exactly zero.\n",
    "    H1_term = (tmax-tmin)*(f_sq_x**(-1/2)*ff_t_sq_x).mean()\n",
    "    return L2_term + H1_term"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_L2(output) :\n",
    "    \"\"\"Largest per-time-slice mean squared error of `output` against the global y_test.\n",
    "\n",
    "    The prediction is flattened and split into 101 consecutive, equal slices (one per\n",
    "    t value of the 101 x 101 evaluation grid).\n",
    "    \"\"\"\n",
    "    predicted = output.view(-1).cpu().detach().numpy()\n",
    "    squared_error = (predicted - y_test)**2\n",
    "    per_slice_mse = np.mean(np.split(squared_error, 101), axis=1)\n",
    "    return max(per_slice_mse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.262409Z",
     "start_time": "2020-05-05T06:17:58.256729Z"
    }
   },
   "outputs": [],
   "source": [
    "# Training function\n",
    "\n",
    "def train(uv_model, trainloader, optimizer, loss_f) :\n",
    "    \"\"\"Run one optimisation epoch over `trainloader` and return batch-averaged metrics.\n",
    "\n",
    "    loss = (1/3)*(loss1 + (xmax - xmin)*loss2 + 2*(tmax - tmin)*loss3), where\n",
    "    loss1 is the governing-equation loss, loss2 the initial-condition loss (value,\n",
    "    u_x and u_xx at t = 0) and loss3 the boundary loss; loss2 and loss3 are weighted\n",
    "    by the measure of the set they are enforced on.\n",
    "    Returns (loss, loss1, loss2, loss3, max_L2 test error), each averaged over batches.\n",
    "    Reads notebook globals X_ini, u_ini*, X_bdry, u_bdry, X_test, nu and device.\n",
    "    \"\"\"\n",
    "    uv_model.train()\n",
    "    loss_list, loss_list1, loss_list2, loss_list3, err_list = [], [], [], [], []\n",
    "\n",
    "    for (data,) in trainloader :\n",
    "        optimizer.zero_grad()\n",
    "        # Collocation points must require grad so PDE derivatives can be taken w.r.t. them.\n",
    "        X_v = data.to(device).requires_grad_(True)\n",
    "        output = uv_model(X_v)\n",
    "        output_ini = uv_model(X_ini)\n",
    "        output_ini_der = calculate_derivative(output_ini, X_ini)[:,1]\n",
    "        output_ini_sec_der = calculate_derivative(output_ini_der, X_ini)[:,1]\n",
    "        output_bdry = uv_model(X_bdry)\n",
    "\n",
    "        u_t, u_x, u_xx = calculate_all_partial(output, X_v)\n",
    "        f_t = calculate_derivative(u_t + output.view(-1)*u_x - nu*u_xx, X_v)[:,0]\n",
    "\n",
    "        loss1 = loss_GE(output, u_t, u_x, u_xx, f_t)\n",
    "        loss2 = (loss_f(output_ini, u_ini)\n",
    "                 + loss_f(output_ini_der.view(-1,1), u_ini_der)\n",
    "                 + loss_f(output_ini_sec_der.view(-1,1), u_ini_sec_der))\n",
    "        loss3 = loss_f(output_bdry, u_bdry)\n",
    "\n",
    "        # Boundary weight is the total length of the two x-edges, 2*(tmax - tmin).\n",
    "        loss = (1/3)*(loss1 + (xmax-xmin)*loss2 + 2*(tmax-tmin)*loss3)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Evaluation only: avoid building an autograd graph over the test grid.\n",
    "        with torch.no_grad():\n",
    "            err = max_L2(uv_model(X_test))\n",
    "        loss_list.append(loss.item())\n",
    "        loss_list1.append(loss1.item())\n",
    "        loss_list2.append(loss2.item())\n",
    "        loss_list3.append(loss3.item())\n",
    "        err_list.append(err)\n",
    "    return np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2), np.mean(loss_list3), np.mean(err_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.271852Z",
     "start_time": "2020-05-05T06:17:58.263364Z"
    }
   },
   "outputs": [],
   "source": [
    "# Training grid as a tensor, used to plot the network prediction.\n",
    "X_plot = torch.FloatTensor(df_real.values).to(device)\n",
    "\n",
    "def plot_results(u_model, show) :\n",
    "    \"\"\"3-D surface of u_model's prediction on the Ns x Nx training grid.\n",
    "\n",
    "    show=True displays the figure with plt.show() and returns None;\n",
    "    show=False returns the Figure without displaying it.\n",
    "    \"\"\"\n",
    "    # Inference only: no autograd graph needed.\n",
    "    with torch.no_grad():\n",
    "        prediction = u_model(X_plot).cpu().numpy()\n",
    "    # Rows are time-major, so this is indexed [t index, x index].\n",
    "    prediction_u = prediction[:,0].reshape(Ns, Nx)\n",
    "    fig = plt.figure(figsize=(10,5))\n",
    "\n",
    "    s = np.linspace(tmin, tmax, Ns) # discretization of time\n",
    "    x = np.linspace(xmin, xmax, Nx) # discretization of space\n",
    "    S, X = np.meshgrid(s, x)\n",
    "\n",
    "    ax = fig.add_subplot(1,1,1, projection='3d')\n",
    "    # meshgrid(s, x) arrays have shape (Nx, Ns), hence the transpose.\n",
    "    # No legend is drawn, so the private _facecolors2d/_edgecolors2d workaround is\n",
    "    # unnecessary (those attributes are not stable across matplotlib releases).\n",
    "    ax.plot_surface(S, X, prediction_u.T, alpha=0.7)\n",
    "\n",
    "    # NOTE(review): _axinfo is private matplotlib API and may change between releases.\n",
    "    ax.zaxis._axinfo['juggled'] = (1,2,2)\n",
    "    ax.set_title('Neural Network Solution')\n",
    "    ax.set_xlabel('t')\n",
    "    ax.set_ylabel('x')\n",
    "    ax.set_zlabel('u')\n",
    "    ax.locator_params(axis='z', nbins=6)\n",
    "    ax.locator_params(axis='x', nbins=5)\n",
    "    ax.locator_params(axis='y', nbins=5)\n",
    "\n",
    "    if show :\n",
    "        plt.show()\n",
    "    else :\n",
    "        return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T08:22:34.873952Z",
     "start_time": "2020-05-05T08:22:34.849481Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fresh model on the compute device and empty per-epoch metric histories.\n",
    "u_model = u_Net().to(device)\n",
    "total_loss, loss1s, loss2s, loss3s, errs = [], [], [], [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T08:22:57.203176Z",
     "start_time": "2020-05-05T08:22:35.054371Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam([{'params': u_model.parameters()}], lr=1e-4)\n",
    "\n",
    "EPOCH = 100000  # NOTE: overrides EPOCH from the hyperparameter cell\n",
    "ERR_TOL = 1e-5  # stop once the max per-slice test MSE (max_L2) drops below this\n",
    "# NOTE(review): '<' is not allowed in Windows file names, and 't<0.00001' does not\n",
    "# match tmax = 0.001; consider renaming this checkpoint path.\n",
    "CKPT_PATH = '../../models/1D_Burgers_Dirichlet_t<0.00001_H2.pt'\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "for t in tqdm_notebook(range(1, EPOCH+1)) :\n",
    "    loss, loss1, loss2, loss3, err = train(u_model, trainloader=train_loader,\n",
    "                                           optimizer=optimizer, loss_f=loss_fn)\n",
    "\n",
    "    # Print Log\n",
    "    if t%100 == 0 :\n",
    "        print(\"%s/%s | loss: %06.6f | loss1: %06.6f | loss2: %06.6f | loss3 : %06.6f | l2 error : %06.6f \" %\n",
    "              (t, EPOCH, loss, loss1, loss2, loss3, err))\n",
    "\n",
    "    if t%10000 == 0 :\n",
    "        plot_results(u_model, show=True)\n",
    "\n",
    "    total_loss.append(loss)\n",
    "    loss1s.append(loss1)\n",
    "    loss2s.append(loss2)\n",
    "    loss3s.append(loss3)\n",
    "    errs.append(err)\n",
    "    # Checkpoint the whole (pickled) model together with the metric histories.\n",
    "    if t % 1000 == 0:\n",
    "        torch.save([u_model, total_loss, loss1s, loss2s, loss3s, errs], CKPT_PATH)\n",
    "\n",
    "    if err < ERR_TOL :\n",
    "        print('In EPOCH {}, max L2 error is under 1e-5, training finished'.format(t))\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### For Error Plot\n",
    "# Train N_RUNS freshly initialised models for EPOCH epochs each. Per run, record the\n",
    "# first epoch at which max_L2 drops below ERR_TOL and the wall time to reach it, plus\n",
    "# the full error curve. Results are re-pickled after every run.\n",
    "\n",
    "N_RUNS = 100\n",
    "EPOCH = 1000  # NOTE: overrides EPOCH used by the earlier cells\n",
    "ERR_TOL = 1e-5\n",
    "\n",
    "epochs = []\n",
    "times = []\n",
    "errs_list = []\n",
    "\n",
    "for i in tqdm_notebook(range(N_RUNS)) :\n",
    "    u_model = u_Net().to(device)\n",
    "    total_loss = []\n",
    "    loss1s, loss2s, loss3s, errs = [], [], [], []\n",
    "    optimizer = torch.optim.Adam([{'params': u_model.parameters()}], lr=1e-4)\n",
    "\n",
    "    a = time.time()\n",
    "    saved = False\n",
    "    for t in tqdm_notebook(range(1, EPOCH+1), leave=False) :\n",
    "        loss, loss1, loss2, loss3, err = train(u_model, trainloader=train_loader,\n",
    "                                               optimizer=optimizer, loss_f=nn.MSELoss())\n",
    "        errs.append(err)\n",
    "        # Record only the first epoch meeting the tolerance, but keep training so\n",
    "        # every run contributes a full-length error curve.\n",
    "        if err < ERR_TOL and not saved :\n",
    "            print('In EPOCH {}, max L2 error is under 1e-5'.format(t))\n",
    "            epochs.append(t)\n",
    "            times.append(time.time() - a)\n",
    "            saved = True\n",
    "    errs_list.append(errs)\n",
    "    # File handle is not named `f`: that would rebind the global f() used by analytic().\n",
    "    with open('Burgers_initial_H2_GE_H1_1e-5_error_plot', 'wb') as fh:\n",
    "        pkl.dump([epochs, times, errs_list], fh)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pydata",
   "language": "python",
   "name": "pydata"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
