{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T05:49:56.886747Z",
     "start_time": "2020-05-05T05:49:55.552862Z"
    }
   },
   "outputs": [],
   "source": [
    "# import modules\n",
    "# standard library\n",
    "import os\n",
    "import pickle as pkl\n",
    "import time\n",
    "\n",
    "# scientific stack\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm_notebook\n",
    "\n",
    "# deep learning modules\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Plot modules\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:55.267254Z",
     "start_time": "2020-05-05T06:17:55.260453Z"
    }
   },
   "outputs": [],
   "source": [
    "# Use GPU: pick the first CUDA device when one is available, else the CPU.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device('cuda:0') if use_cuda else torch.device('cpu')\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:55.441753Z",
     "start_time": "2020-05-05T06:17:55.436955Z"
    }
   },
   "outputs": [],
   "source": [
    "# Make grid points\n",
    "# Space-time domain: t in [tmin, tmax], x in [xmin, xmax].\n",
    "tmin, tmax = 0, np.pi**2\n",
    "xmin, xmax = 0, np.pi\n",
    "# Points per axis on the training (collocation) grid and on the test grid.\n",
    "Ns, Nx = 51, 51\n",
    "Nt_test, Nx_test = 101, 101\n",
    "\n",
    "\n",
    "def make_grid(n_t, n_x):\n",
    "    \"\"\"Return an (n_t * n_x, 2) array of (t, x) points; t varies slowest.\"\"\"\n",
    "    return np.mgrid[tmin:tmax:n_t * 1j, xmin:xmax:n_x * 1j].reshape(2, -1).T\n",
    "\n",
    "\n",
    "df_real = pd.DataFrame(make_grid(Ns, Nx), columns=['t', 'x'])\n",
    "X_train = df_real.values\n",
    "\n",
    "\n",
    "def initial_condition(row):\n",
    "    \"\"\"Initial profile u(0, x) = sin(x).\n",
    "\n",
    "    `row` is a (t, x) pair; a (2, N) array also works and gives N values.\n",
    "    \"\"\"\n",
    "    _, x = row\n",
    "    return np.sin(x)\n",
    "\n",
    "\n",
    "# Initial-condition points: the grid starts exactly at tmin, so == is exact.\n",
    "ini_points = df_real.loc[df_real['t'] == tmin].values\n",
    "X_ini = torch.FloatTensor(ini_points).to(device)\n",
    "u_ini = torch.FloatTensor(initial_condition(ini_points.T)).to(device).view(-1, 1)\n",
    "\n",
    "\n",
    "def analytic(row):\n",
    "    \"\"\"Exact solution u(t, x) = sin(x) * exp(-t) of u_t = u_xx.\n",
    "\n",
    "    It matches u(0, x) = sin(x) and vanishes at x = 0 and x = pi.\n",
    "    `row` is a (t, x) pair; a (2, N) array also works and gives N values.\n",
    "    \"\"\"\n",
    "    t, x = row\n",
    "    return np.sin(x) * np.exp(-t)\n",
    "\n",
    "\n",
    "def analytic_x(row):\n",
    "    \"\"\"Spatial derivative u_x = cos(x) * exp(-t) of the exact solution (not used below).\"\"\"\n",
    "    t, x = row\n",
    "    return np.cos(x) * np.exp(-t)\n",
    "\n",
    "\n",
    "# Dense test grid and the exact solution on it. The vectorised call gives the\n",
    "# same values as a row-wise DataFrame.apply, in the same t-major order.\n",
    "test_points = make_grid(Nt_test, Nx_test)\n",
    "y_test = analytic(test_points.T)\n",
    "X_test = torch.FloatTensor(test_points).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:55.441753Z",
     "start_time": "2020-05-05T06:17:55.436955Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Generate boundary data for the homogeneous Dirichlet boundary condition\n",
    "# u(t, xmin) = u(t, xmax) = 0, enforced by the boundary loss in `train`.\n",
    "# isclose is used because the last mgrid point need not equal xmax bit-for-bit.\n",
    "on_boundary = np.isclose(df_real['x'], xmin) | np.isclose(df_real['x'], xmax)\n",
    "# requires_grad is kept so boundary derivatives (e.g. for a Neumann variant)\n",
    "# remain computable.\n",
    "X_bdry = torch.FloatTensor(df_real.loc[on_boundary].values).to(device).requires_grad_(True)\n",
    "# Zero targets, one per boundary point (train() builds its own zeros).\n",
    "u_bdry = torch.zeros(len(X_bdry), 1, device=device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.245996Z",
     "start_time": "2020-05-05T06:17:58.241527Z"
    }
   },
   "outputs": [],
   "source": [
    "# Neural network for u(t, x). No weights are shared between layers; the\n",
    "# Dirichlet conditions u(t, x_min) = u(t, x_max) = 0 are built into the output.\n",
    "\n",
    "class u_Net(nn.Module):\n",
    "    \"\"\"Fully connected tanh network approximating u(t, x).\n",
    "\n",
    "    The raw output is multiplied by (x - x_min) * (x - x_max), so the model is\n",
    "    exactly zero on both spatial boundaries whatever its weights are.\n",
    "\n",
    "    Args:\n",
    "        hidden: width of the four hidden layers (default 256).\n",
    "        x_min, x_max: spatial boundaries where the output vanishes\n",
    "            (defaults 0 and pi).\n",
    "\n",
    "    Input is an (N, 2) tensor with columns (t, x); output has shape (N, 1).\n",
    "    \"\"\"\n",
    "\n",
    "    # Class-level defaults also cover models pickled before these were attributes.\n",
    "    x_min = 0.0\n",
    "    x_max = np.pi\n",
    "\n",
    "    def __init__(self, hidden=256, x_min=0.0, x_max=np.pi):\n",
    "        super(u_Net, self).__init__()\n",
    "        self.x_min = x_min\n",
    "        self.x_max = x_max\n",
    "        self.fc1 = nn.Linear(2, hidden)\n",
    "        self.fc2 = nn.Linear(hidden, hidden)\n",
    "        self.fc3 = nn.Linear(hidden, hidden)\n",
    "        self.fc4 = nn.Linear(hidden, hidden)\n",
    "        self.fc5 = nn.Linear(hidden, 1)\n",
    "        self.act1 = nn.Tanh()\n",
    "\n",
    "    def forward(self, x):\n",
    "        space = x[:, 1].view(-1, 1)\n",
    "        x = self.act1(self.fc1(x))\n",
    "        x = self.act1(self.fc2(x))\n",
    "        x = self.act1(self.fc3(x))\n",
    "        x = self.act1(self.fc4(x))\n",
    "        # Envelope that vanishes at x_min and x_max (hard boundary condition).\n",
    "        return self.fc5(x) * (space - self.x_min) * (space - self.x_max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.248577Z",
     "start_time": "2020-05-05T06:17:58.246944Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "# EPOCH is only a default: the training cells below set their own value.\n",
    "EPOCH = 1000000\n",
    "# Full-batch training: a single batch holds every collocation point.\n",
    "BATCH_SIZE = len(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.251546Z",
     "start_time": "2020-05-05T06:17:58.249443Z"
    }
   },
   "outputs": [],
   "source": [
    "# Make dataloader over the (t, x) collocation points; batch order stays fixed.\n",
    "train_points = torch.FloatTensor(X_train)\n",
    "data_train = TensorDataset(train_points)\n",
    "train_loader = DataLoader(dataset=data_train, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.255550Z",
     "start_time": "2020-05-05T06:17:58.252428Z"
    }
   },
   "outputs": [],
   "source": [
    "def calculate_derivative(y, x):\n",
    "    \"\"\"Return the gradient of y with respect to x, keeping the graph.\n",
    "\n",
    "    grad_outputs of ones gives the gradient of sum(y); when each row of y\n",
    "    depends only on the matching row of x this is the per-point derivative.\n",
    "    create_graph=True lets the result be differentiated again (for u_xx).\n",
    "    ones_like puts the weights on y's device and dtype, so no global device\n",
    "    is needed.\n",
    "    \"\"\"\n",
    "    return torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y),\n",
    "                               create_graph=True)[0]\n",
    "\n",
    "\n",
    "def calculate_all_partial(u, x):\n",
    "    \"\"\"Return the partial derivatives (u_t, u_x, u_xx) of u at the points x.\n",
    "\n",
    "    x must require grad and u must be computed from it; column 0 of x is t\n",
    "    and column 1 is the spatial coordinate. Each result is a 1-D tensor with\n",
    "    one entry per point.\n",
    "    \"\"\"\n",
    "    del_u = calculate_derivative(u, x)\n",
    "    u_t, u_x = del_u[:, 0], del_u[:, 1]\n",
    "    u_xx = calculate_derivative(u_x, x)[:, 1]\n",
    "    return u_t, u_x, u_xx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_L2(output, target=None, n_chunks=101):\n",
    "    \"\"\"Return the largest per-chunk mean squared error of output vs target.\n",
    "\n",
    "    The flattened squared error is split into n_chunks equal consecutive\n",
    "    chunks and averaged within each; the maximum chunk mean is returned.\n",
    "    On this notebook's test grid (t varies slowest, 101 x-points per t) each\n",
    "    chunk is one time slice, so the result is the worst-in-time spatial MSE\n",
    "    (a squared, averaged error rather than an L2 norm).\n",
    "\n",
    "    Args:\n",
    "        output: prediction tensor (any shape; it is flattened).\n",
    "        target: reference array of matching size; defaults to global y_test.\n",
    "        n_chunks: number of chunks; must divide the number of elements.\n",
    "    Raises:\n",
    "        ValueError: if the number of elements is not divisible by n_chunks.\n",
    "    \"\"\"\n",
    "    if target is None:\n",
    "        target = y_test\n",
    "    # Detach before moving to the CPU so no autograd op is recorded.\n",
    "    sq_err = (output.detach().cpu().numpy().reshape(-1) - target) ** 2\n",
    "    return sq_err.reshape(n_chunks, -1).mean(axis=1).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.262409Z",
     "start_time": "2020-05-05T06:17:58.256729Z"
    }
   },
   "outputs": [],
   "source": [
    "# Training function\n",
    "\n",
    "def train(uv_model, trainloader, optimizer, loss_f):\n",
    "    \"\"\"Run one epoch of physics-informed training; return mean losses and error.\n",
    "\n",
    "    Per batch the loss combines\n",
    "      loss1: residual u_t - u_xx at the batch's collocation points,\n",
    "      loss2: mismatch with u_ini at the initial points X_ini,\n",
    "      loss3: deviation from 0 at the boundary points X_bdry,\n",
    "    weighted by tmax*xmax, xmax and 2*tmax (the sizes of the interior, the\n",
    "    t = 0 line and the two x-boundaries when tmin = xmin = 0) and divided by 3.\n",
    "    After each optimizer step the test error max_L2(model(X_test)) is recorded.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (loss, loss1, loss2, loss3, err), each averaged over batches.\n",
    "    \"\"\"\n",
    "    uv_model.train()\n",
    "    loss_list, loss_list1, loss_list2, loss_list3, err_list = [], [], [], [], []\n",
    "\n",
    "    for (data,) in trainloader:\n",
    "        optimizer.zero_grad()\n",
    "        # Fresh leaf tensor that tracks gradients so u_t and u_xx can be taken.\n",
    "        X_v = data.to(device).detach().requires_grad_(True)\n",
    "        output = uv_model(X_v)\n",
    "        output_ini = uv_model(X_ini)\n",
    "        output_bdry = uv_model(X_bdry)\n",
    "\n",
    "        u_t, _, u_xx = calculate_all_partial(output, X_v)\n",
    "        loss1 = loss_f(u_t - u_xx, torch.zeros_like(u_t))\n",
    "        loss2 = loss_f(output_ini, u_ini)\n",
    "        loss3 = loss_f(output_bdry, torch.zeros_like(output_bdry))\n",
    "\n",
    "        loss = (1/3)*(tmax*xmax*loss1 + xmax*loss2 + 2*tmax*loss3)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Evaluation only: skip building an autograd graph over the test grid.\n",
    "        with torch.no_grad():\n",
    "            err = max_L2(uv_model(X_test))\n",
    "        loss_list.append(loss.item())\n",
    "        loss_list1.append(loss1.item())\n",
    "        loss_list2.append(loss2.item())\n",
    "        loss_list3.append(loss3.item())\n",
    "        err_list.append(err)\n",
    "    return (np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2),\n",
    "            np.mean(loss_list3), np.mean(err_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.271852Z",
     "start_time": "2020-05-05T06:17:58.263364Z"
    }
   },
   "outputs": [],
   "source": [
    "# Evaluation points for plotting: the training grid df_real (t varies slowest).\n",
    "data = torch.FloatTensor(df_real.values).to(device)\n",
    "\n",
    "\n",
    "def plot_results(u_model, show):\n",
    "    \"\"\"Draw the model's prediction on the training grid as a 3D surface.\n",
    "\n",
    "    Args:\n",
    "        u_model: network mapping (N, 2) points (t, x) to (N, 1) values.\n",
    "        show: if True, display with plt.show() and return None;\n",
    "            otherwise return the matplotlib Figure.\n",
    "    \"\"\"\n",
    "    # Inference only: no autograd graph is needed for plotting.\n",
    "    with torch.no_grad():\n",
    "        prediction = u_model(data).cpu().numpy()\n",
    "    # Rows of df_real are ordered t-major, so each row of this array is one t.\n",
    "    prediction_u = prediction[:, 0].reshape(Ns, Nx)\n",
    "    fig = plt.figure(figsize=(10, 5))\n",
    "\n",
    "    s = np.linspace(tmin, tmax, Ns)  # discretization of time\n",
    "    x = np.linspace(xmin, xmax, Nx)  # discretization of space\n",
    "    S, X = np.meshgrid(s, x)\n",
    "\n",
    "    ax = fig.add_subplot(1, 1, 1, projection='3d')\n",
    "    surf = ax.plot_surface(S, X, prediction_u.T, alpha=0.7)\n",
    "    # Copy private 3D colour attributes to their 2D names (presumably a legend\n",
    "    # workaround). Newer matplotlib renamed these private attributes, so only\n",
    "    # do it when they exist instead of raising AttributeError.\n",
    "    if hasattr(surf, '_facecolors3d') and hasattr(surf, '_edgecolors3d'):\n",
    "        surf._facecolors2d = surf._facecolors3d\n",
    "        surf._edgecolors2d = surf._edgecolors3d\n",
    "\n",
    "    ax.zaxis._axinfo['juggled'] = (1, 2, 2)\n",
    "    ax.set_title('Neural Network Solution')\n",
    "    ax.set_xlabel('t')\n",
    "    ax.set_ylabel('x')\n",
    "    ax.set_zlabel('u')\n",
    "    ax.locator_params(axis='z', nbins=6)\n",
    "    ax.locator_params(axis='x', nbins=5)\n",
    "    ax.locator_params(axis='y', nbins=5)\n",
    "\n",
    "    if show:\n",
    "        plt.show()\n",
    "        return None\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T08:22:34.873952Z",
     "start_time": "2020-05-05T08:22:34.849481Z"
    }
   },
   "outputs": [],
   "source": [
    "# Seed torch so the network initialisation is reproducible.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "u_model = u_Net().to(device)\n",
    "# Per-epoch histories appended to by the training loop.\n",
    "total_loss = []\n",
    "loss1s, loss2s, loss3s, errs = [], [], [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T08:22:57.203176Z",
     "start_time": "2020-05-05T08:22:35.054371Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train one model until the test error drops below ERR_TOL or EPOCH runs out.\n",
    "optimizer = torch.optim.Adam([{'params': u_model.parameters()}], lr=1e-5)\n",
    "loss_fn = nn.MSELoss()  # stateless, so one instance serves every epoch\n",
    "\n",
    "EPOCH = 50000\n",
    "ERR_TOL = 1e-5  # threshold on max_L2 (worst time-slice MSE on the test grid)\n",
    "MODEL_PATH = '../../models/1D_Heat_Dirichlet.pt'\n",
    "# torch.save does not create missing directories.\n",
    "os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)\n",
    "\n",
    "\n",
    "def save_checkpoint():\n",
    "    \"\"\"Save the model together with all recorded loss and error histories.\"\"\"\n",
    "    torch.save([u_model, total_loss, loss1s, loss2s, loss3s, errs], MODEL_PATH)\n",
    "\n",
    "\n",
    "for t in tqdm_notebook(range(1, EPOCH + 1)):\n",
    "    loss, loss1, loss2, loss3, err = train(u_model, trainloader=train_loader,\n",
    "                                           optimizer=optimizer, loss_f=loss_fn)\n",
    "\n",
    "    # Print Log\n",
    "    if t % 100 == 0:\n",
    "        print('%s/%s | loss: %06.6f | loss1: %06.6f | loss2: %06.6f | loss3 : %06.6f | l2 error : %06.6f ' %\n",
    "              (t, EPOCH, loss, loss1, loss2, loss3, err))\n",
    "\n",
    "    if t % 10000 == 0:\n",
    "        plot_results(u_model, show=True)\n",
    "\n",
    "    total_loss.append(loss)\n",
    "    loss1s.append(loss1)\n",
    "    loss2s.append(loss2)\n",
    "    loss3s.append(loss3)\n",
    "    errs.append(err)\n",
    "\n",
    "    converged = err < ERR_TOL\n",
    "    # Save Model periodically, and always on convergence so the final state\n",
    "    # is not lost between checkpoints.\n",
    "    if t % 1000 == 0 or converged:\n",
    "        save_checkpoint()\n",
    "\n",
    "    if converged:\n",
    "        print('In EPOCH {}, test error is under 1e-5, training finished'.format(t))\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### For Error Plot\n",
    "# Train N_RUNS freshly initialised models for EPOCH epochs each. For every run\n",
    "# record the first epoch whose test error is below ERR_TOL (EPOCH if never\n",
    "# reached), the wall time until then (the full run time if never reached) and\n",
    "# the whole per-epoch error curve. Results are pickled after every run.\n",
    "\n",
    "N_RUNS = 100\n",
    "EPOCH = 10000\n",
    "ERR_TOL = 1e-5\n",
    "RESULTS_PATH = 'Heat_initial_L2_GE_L2_1e-5_error_plot'\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "epochs = []\n",
    "times = []\n",
    "errs_list = []\n",
    "\n",
    "for run in tqdm_notebook(range(N_RUNS)):\n",
    "    torch.manual_seed(run)  # distinct but reproducible initialisation per run\n",
    "    u_model = u_Net().to(device)\n",
    "    errs = []\n",
    "    optimizer = torch.optim.Adam([{'params': u_model.parameters()}], lr=1e-5)\n",
    "\n",
    "    start = time.time()\n",
    "    converged = False\n",
    "    for t in tqdm_notebook(range(1, EPOCH + 1), leave=False):\n",
    "        loss, loss1, loss2, loss3, err = train(u_model, trainloader=train_loader,\n",
    "                                               optimizer=optimizer, loss_f=loss_fn)\n",
    "        errs.append(err)\n",
    "        # Record only the first crossing, but keep training so the whole\n",
    "        # error curve is stored.\n",
    "        if err < ERR_TOL and not converged:\n",
    "            print('In EPOCH {}, test error is under 1e-5'.format(t))\n",
    "            epochs.append(t)\n",
    "            times.append(time.time() - start)\n",
    "            converged = True\n",
    "    if not converged:\n",
    "        # Keep `epochs` and `times` aligned: exactly one entry per run in each.\n",
    "        epochs.append(EPOCH)\n",
    "        times.append(time.time() - start)\n",
    "    errs_list.append(errs)\n",
    "    with open(RESULTS_PATH, 'wb') as f:\n",
    "        pkl.dump([epochs, times, errs_list], f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pydata",
   "language": "python",
   "name": "pydata"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
