{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T05:49:56.886747Z",
     "start_time": "2020-05-05T05:49:55.552862Z"
    }
   },
   "outputs": [],
   "source": [
    "# import modules\n",
    "from tqdm import tqdm_notebook\n",
    "import pickle as pkl\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# deep learning modules\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Plot modules\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:55.267254Z",
     "start_time": "2020-05-05T06:17:55.260453Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "dim = 100\n",
    "xmin, xmax = 0,1\n",
    "\n",
    "# Fixed evaluation set: 1000 points drawn uniformly from [xmin, xmax]^dim, cached on disk\n",
    "# so repeated runs are scored against the same points.\n",
    "# NOTE: unpickling can execute arbitrary code; only load a cache file you created yourself.\n",
    "TEST_PATH = 'poisson_{}_test'.format(dim)\n",
    "try :\n",
    "    test_data = pd.read_pickle(TEST_PATH)\n",
    "except FileNotFoundError :\n",
    "    # Only a missing cache triggers regeneration. Any other failure (corrupt file,\n",
    "    # permissions, Ctrl-C) now surfaces instead of silently replacing the test set.\n",
    "    test_data = np.random.uniform(xmin, xmax, [1000, dim])\n",
    "    with open(TEST_PATH, 'wb') as f :\n",
    "        pkl.dump(test_data, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def analytic(data) :\n",
    "    \"\"\"Exact solution u(x) = sum_i sin(pi/2 * x_i) for each row of a 2-D batch, as an (N, 1) column.\"\"\"\n",
    "    scaled = (np.pi / 2) * data\n",
    "    per_row = torch.sin(scaled).sum(dim=1)\n",
    "    return per_row.view(-1, 1)\n",
    "\n",
    "def boundary_condition(data) : #data : N X dim\n",
    "    \"\"\"Dirichlet values g(x) for a batch of boundary points (N x dim), as an (N, 1) column.\n",
    "\n",
    "    For this manufactured problem g is the exact solution itself, so delegate to `analytic`\n",
    "    instead of duplicating the formula (two copies could silently drift apart).\n",
    "    \"\"\"\n",
    "    return analytic(data)\n",
    "\n",
    "def sample(N):\n",
    "    \"\"\"Draw N interior collocation points uniformly from [0, 1)^dim as an (N, dim) tensor.\"\"\"\n",
    "    shape = (N, dim)\n",
    "    return torch.rand(shape)\n",
    "\n",
    "def sample_boundary(N):\n",
    "    \"\"\"Draw N points on the boundary of the unit cube [0, 1]^dim.\n",
    "\n",
    "    The first N // 2 rows lie on the faces {x_i = 0}, the remaining N - N // 2 on the\n",
    "    faces {x_i = 1}. Within each half, consecutive groups of (rows // dim) points have\n",
    "    coordinate 0, 1, ..., dim-1 pinned in turn; leftover rows get the last coordinate pinned.\n",
    "    \"\"\"\n",
    "    def _on_faces(n_rows, value):\n",
    "        # Uniform points in the cube with one coordinate per row pinned to `value`.\n",
    "        pts = torch.rand([n_rows, dim])\n",
    "        group = n_rows // dim\n",
    "        for i in range(dim) :\n",
    "            pts[i*group:(i+1)*group, i] = value\n",
    "        pts[group*dim:, -1] = value\n",
    "        return pts\n",
    "\n",
    "    total = int(N)\n",
    "    n_low = total // 2\n",
    "    # Use total - n_low (not n_low) for the second half so an odd N still yields exactly\n",
    "    # N rows; previously it returned N - 1, which breaks TensorDataset(sample(N), ...).\n",
    "    return torch.cat([_on_faces(n_low, 0), _on_faces(total - n_low, 1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation inputs as float32 on the compute device, and the exact solution at each point.\n",
    "answer_x = torch.tensor(test_data, dtype=torch.float32).to(device)\n",
    "answer_y = analytic(answer_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.245996Z",
     "start_time": "2020-05-05T06:17:58.241527Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fully connected network u(x): dim -> 256 -> 256 -> 256 -> 1 with tanh activations.\n",
    "# Every Linear layer has its own weights; only the parameter-free Tanh module is reused.\n",
    "\n",
    "class u_Net(nn.Module):\n",
    "    \"\"\"MLP approximating the PDE solution: inputs with last dimension `dim` -> one output.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(u_Net, self).__init__()\n",
    "        # Attribute names are kept as-is because whole pickled models are saved with torch.save.\n",
    "        self.fc1 = nn.Linear(dim, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)\n",
    "        self.fc3 = nn.Linear(256, 256)\n",
    "        self.fc6 = nn.Linear(256, 1)\n",
    "        self.act1 = nn.Tanh()\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = x\n",
    "        for layer in (self.fc1, self.fc2, self.fc3):\n",
    "            hidden = self.act1(layer(hidden))\n",
    "        return self.fc6(hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.248577Z",
     "start_time": "2020-05-05T06:17:58.246944Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "# Number of optimisation steps; each step draws a fresh batch of collocation points.\n",
    "# (Previously 1000000 here, but the training cells always set EPOCH = 10000 themselves.)\n",
    "EPOCH = 10000\n",
    "# Points per step: BATCH_SIZE interior points plus BATCH_SIZE boundary points.\n",
    "BATCH_SIZE = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_derivative(y, x) :\n",
    "    \"\"\"Return the gradient of sum(y) with respect to `x`, keeping the graph.\n",
    "\n",
    "    When row i of `y` depends only on row i of `x`, row i of the result is d y[i] / d x[i].\n",
    "    `create_graph=True` lets the result be differentiated again (used for u_xx).\n",
    "    \"\"\"\n",
    "    # ones_like matches y's dtype and device, so no global `device` is needed and\n",
    "    # float64 inputs or a model left on the CPU work too.\n",
    "    return torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y),\n",
    "                               create_graph=True)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.262409Z",
     "start_time": "2020-05-05T06:17:58.256729Z"
    }
   },
   "outputs": [],
   "source": [
    "# Training function\n",
    "\n",
    "def train(uv_model, trainloader, optimizer, loss_f) :\n",
    "    \"\"\"Run one pass over `trainloader`, taking one optimiser step per batch.\n",
    "\n",
    "    Each batch is (X_v, X_bdry, u_bdry): interior points, boundary points and the\n",
    "    target values at those boundary points. The minimised loss is loss1 + loss2 with\n",
    "        loss1 = loss_f(Laplacian(u)(X_v) + (pi^2/4) * sum_i sin(pi/2 * x_i), 0)  (PDE residual)\n",
    "        loss2 = loss_f(u(X_bdry), u_bdry)                                       (boundary fit)\n",
    "\n",
    "    Returns the batch-averaged (loss, loss1, loss2).\n",
    "    \"\"\"\n",
    "    uv_model.train()\n",
    "    loss_list, loss_list1, loss_list2 = [], [], []\n",
    "\n",
    "    for X_v, X_bdry, u_bdry in trainloader :\n",
    "        optimizer.zero_grad()\n",
    "        # Interior points must track gradients so the Laplacian can be taken w.r.t. them.\n",
    "        X_v = X_v.to(device).requires_grad_(True)\n",
    "        X_bdry, u_bdry = X_bdry.to(device), u_bdry.to(device)\n",
    "\n",
    "        output = uv_model(X_v)\n",
    "        output_bdry = uv_model(X_bdry)\n",
    "\n",
    "        # Laplacian = sum over coordinates j of d^2 u / d x_j^2 (one extra backward pass per j).\n",
    "        del_u = calculate_derivative(output, X_v)\n",
    "        n_dim = X_v.shape[1]\n",
    "        u_xx = sum(calculate_derivative(del_u[:, j], X_v)[:, j] for j in range(n_dim)).view(-1, 1)\n",
    "\n",
    "        residual = u_xx + ((np.pi**2)/4)*torch.sin((np.pi/2)*X_v).sum(dim=1).view(-1, 1)\n",
    "        loss1 = loss_f(residual, torch.zeros_like(residual))\n",
    "        loss2 = loss_f(output_bdry, u_bdry)\n",
    "        loss = loss1 + loss2\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        loss_list.append(loss.item())\n",
    "        loss_list1.append(loss1.item())\n",
    "        loss_list2.append(loss2.item())\n",
    "    return np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T08:22:34.873952Z",
     "start_time": "2020-05-05T08:22:34.849481Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fresh model on the compute device, plus empty per-step histories (total loss, PDE loss,\n",
    "# boundary loss, relative test error) that the training cell appends to and checkpoints.\n",
    "u_model = u_Net().to(device)\n",
    "total_loss = []\n",
    "loss1s, loss2s, errs = [], [], []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_results(output_answer, coord=1) :\n",
    "    \"\"\"Scatter exact vs predicted u on the evaluation set against one input coordinate.\n",
    "\n",
    "    output_answer : predictions at `answer_x` (presumably u_model(answer_x)); may require grad.\n",
    "    coord : column of `answer_x` used as the horizontal axis (default 1, the previous behaviour).\n",
    "    \"\"\"\n",
    "    xs = answer_x[:, coord].detach().cpu().numpy()\n",
    "    fig, ax = plt.subplots(figsize=[8, 4])\n",
    "    ax.plot(xs, answer_y.detach().cpu().numpy(), 'ro', label='True')\n",
    "    ax.plot(xs, output_answer.detach().cpu().numpy(), 'bo', label='Pred')\n",
    "    ax.set(xlabel='x[{}]'.format(coord), ylabel='u(x)', title='Exact vs predicted solution')\n",
    "    ax.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T08:22:57.203176Z",
     "start_time": "2020-05-05T08:22:35.054371Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "45ce7a60b57d43518b114769a54fea55",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100/10000 | loss: 26645.281250 | loss1: 23472.025391 | loss2: 3173.255127| rel error : 0.884877 \n",
      "200/10000 | loss: 16707.308594 | loss1: 13890.980469 | loss2: 2816.327637| rel error : 0.832489 \n",
      "300/10000 | loss: 14144.958984 | loss1: 11543.250977 | loss2: 2601.707520| rel error : 0.803365 \n",
      "400/10000 | loss: 13608.087891 | loss1: 11165.820312 | loss2: 2442.268066| rel error : 0.774209 \n",
      "500/10000 | loss: 11621.081055 | loss1: 9318.616211 | loss2: 2302.465088| rel error : 0.745995 \n",
      "600/10000 | loss: 10491.020508 | loss1: 8131.355957 | loss2: 2359.664307| rel error : 0.756258 \n",
      "700/10000 | loss: 7091.994629 | loss1: 4754.462402 | loss2: 2337.532227| rel error : 0.749571 \n",
      "800/10000 | loss: 3958.283691 | loss1: 1292.126831 | loss2: 2666.156738| rel error : 0.800078 \n",
      "900/10000 | loss: 3047.910645 | loss1: 501.182159 | loss2: 2546.728516| rel error : 0.787357 \n",
      "1000/10000 | loss: 3021.258301 | loss1: 615.738892 | loss2: 2405.519531| rel error : 0.772867 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hwijae/anaconda3/envs/pydata/lib/python3.7/site-packages/torch/serialization.py:402: UserWarning: Couldn't retrieve source code for container of type u_Net. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1100/10000 | loss: 2768.750488 | loss1: 354.721985 | loss2: 2414.028564| rel error : 0.762886 \n",
      "1200/10000 | loss: 2599.811279 | loss1: 448.643036 | loss2: 2151.168213| rel error : 0.729091 \n",
      "1300/10000 | loss: 2395.968994 | loss1: 431.544403 | loss2: 1964.424561| rel error : 0.699210 \n",
      "1400/10000 | loss: 2263.514648 | loss1: 389.800507 | loss2: 1873.714233| rel error : 0.682171 \n",
      "1500/10000 | loss: 2024.515259 | loss1: 302.084961 | loss2: 1722.430298| rel error : 0.655496 \n",
      "1600/10000 | loss: 1838.944336 | loss1: 269.801392 | loss2: 1569.142944| rel error : 0.617557 \n",
      "1700/10000 | loss: 1738.489990 | loss1: 384.536652 | loss2: 1353.953369| rel error : 0.591089 \n",
      "1800/10000 | loss: 1485.008911 | loss1: 162.194504 | loss2: 1322.814453| rel error : 0.575078 \n",
      "1900/10000 | loss: 1327.997925 | loss1: 144.551910 | loss2: 1183.446045| rel error : 0.528142 \n",
      "2000/10000 | loss: 1192.324219 | loss1: 130.755554 | loss2: 1061.568604| rel error : 0.512699 \n",
      "2100/10000 | loss: 1078.166870 | loss1: 93.464134 | loss2: 984.702759| rel error : 0.490407 \n",
      "2200/10000 | loss: 1003.298828 | loss1: 90.195038 | loss2: 913.103821| rel error : 0.476693 \n",
      "2300/10000 | loss: 907.984070 | loss1: 82.521530 | loss2: 825.462524| rel error : 0.459268 \n",
      "2400/10000 | loss: 839.275024 | loss1: 77.913612 | loss2: 761.361389| rel error : 0.430344 \n",
      "2500/10000 | loss: 807.214050 | loss1: 90.554771 | loss2: 716.659302| rel error : 0.408710 \n",
      "2600/10000 | loss: 752.816772 | loss1: 69.631866 | loss2: 683.184937| rel error : 0.398357 \n",
      "2700/10000 | loss: 664.626404 | loss1: 87.089737 | loss2: 577.536682| rel error : 0.374788 \n",
      "2800/10000 | loss: 619.440857 | loss1: 70.595268 | loss2: 548.845581| rel error : 0.362430 \n",
      "2900/10000 | loss: 581.307678 | loss1: 70.928825 | loss2: 510.378845| rel error : 0.348536 \n",
      "3000/10000 | loss: 527.201416 | loss1: 91.112099 | loss2: 436.089294| rel error : 0.331250 \n",
      "3100/10000 | loss: 473.039581 | loss1: 61.468212 | loss2: 411.571381| rel error : 0.310435 \n",
      "3200/10000 | loss: 447.216797 | loss1: 75.206886 | loss2: 372.009918| rel error : 0.302444 \n",
      "3300/10000 | loss: 419.344452 | loss1: 55.450634 | loss2: 363.893829| rel error : 0.285672 \n",
      "3400/10000 | loss: 369.034271 | loss1: 49.367901 | loss2: 319.666382| rel error : 0.263680 \n",
      "3500/10000 | loss: 343.033936 | loss1: 55.606556 | loss2: 287.427368| rel error : 0.261251 \n",
      "3600/10000 | loss: 307.698547 | loss1: 66.781334 | loss2: 240.917221| rel error : 0.236479 \n",
      "3700/10000 | loss: 280.133118 | loss1: 57.045067 | loss2: 223.088058| rel error : 0.226097 \n",
      "3800/10000 | loss: 247.497055 | loss1: 55.083271 | loss2: 192.413788| rel error : 0.216088 \n",
      "3900/10000 | loss: 223.494232 | loss1: 50.141869 | loss2: 173.352371| rel error : 0.207474 \n",
      "4000/10000 | loss: 192.309814 | loss1: 41.886147 | loss2: 150.423660| rel error : 0.176871 \n",
      "4100/10000 | loss: 181.626007 | loss1: 45.077934 | loss2: 136.548080| rel error : 0.170779 \n",
      "4200/10000 | loss: 162.527313 | loss1: 48.973850 | loss2: 113.553459| rel error : 0.165323 \n",
      "4300/10000 | loss: 152.952820 | loss1: 39.891483 | loss2: 113.061333| rel error : 0.147124 \n",
      "4400/10000 | loss: 129.085709 | loss1: 36.798737 | loss2: 92.286972| rel error : 0.140134 \n",
      "4500/10000 | loss: 119.983704 | loss1: 41.682102 | loss2: 78.301598| rel error : 0.128326 \n",
      "4600/10000 | loss: 115.553436 | loss1: 42.439671 | loss2: 73.113770| rel error : 0.116375 \n",
      "4700/10000 | loss: 103.910095 | loss1: 39.204647 | loss2: 64.705444| rel error : 0.110780 \n",
      "4800/10000 | loss: 89.680397 | loss1: 30.224375 | loss2: 59.456020| rel error : 0.097988 \n",
      "4900/10000 | loss: 89.209747 | loss1: 33.394291 | loss2: 55.815456| rel error : 0.097933 \n",
      "5000/10000 | loss: 75.365631 | loss1: 33.962166 | loss2: 41.403461| rel error : 0.088429 \n",
      "5100/10000 | loss: 67.226837 | loss1: 28.013172 | loss2: 39.213661| rel error : 0.080011 \n",
      "5200/10000 | loss: 62.319622 | loss1: 27.670059 | loss2: 34.649563| rel error : 0.078221 \n",
      "5300/10000 | loss: 55.096718 | loss1: 24.101179 | loss2: 30.995539| rel error : 0.072054 \n",
      "5400/10000 | loss: 50.428635 | loss1: 26.860815 | loss2: 23.567822| rel error : 0.068723 \n",
      "5500/10000 | loss: 51.720612 | loss1: 24.042122 | loss2: 27.678488| rel error : 0.068786 \n",
      "5600/10000 | loss: 40.957184 | loss1: 18.317080 | loss2: 22.640106| rel error : 0.063062 \n",
      "5700/10000 | loss: 37.970623 | loss1: 16.591818 | loss2: 21.378805| rel error : 0.059952 \n",
      "5800/10000 | loss: 38.651108 | loss1: 17.767481 | loss2: 20.883627| rel error : 0.059651 \n",
      "5900/10000 | loss: 38.292137 | loss1: 18.712864 | loss2: 19.579271| rel error : 0.056751 \n",
      "6000/10000 | loss: 33.857513 | loss1: 15.715070 | loss2: 18.142443| rel error : 0.057197 \n",
      "6100/10000 | loss: 34.100845 | loss1: 15.998484 | loss2: 18.102364| rel error : 0.055545 \n",
      "6200/10000 | loss: 33.679134 | loss1: 14.508182 | loss2: 19.170954| rel error : 0.054430 \n",
      "6300/10000 | loss: 31.195290 | loss1: 12.268161 | loss2: 18.927130| rel error : 0.054394 \n",
      "6400/10000 | loss: 27.658455 | loss1: 11.237618 | loss2: 16.420837| rel error : 0.052361 \n",
      "6500/10000 | loss: 28.603016 | loss1: 12.587274 | loss2: 16.015743| rel error : 0.053155 \n",
      "6600/10000 | loss: 27.320635 | loss1: 10.809719 | loss2: 16.510916| rel error : 0.052827 \n",
      "6700/10000 | loss: 28.053123 | loss1: 10.715318 | loss2: 17.337807| rel error : 0.052181 \n",
      "6800/10000 | loss: 24.875490 | loss1: 9.959928 | loss2: 14.915563| rel error : 0.051116 \n",
      "6900/10000 | loss: 26.249016 | loss1: 9.356227 | loss2: 16.892790| rel error : 0.050835 \n",
      "7000/10000 | loss: 26.763453 | loss1: 8.863390 | loss2: 17.900063| rel error : 0.050764 \n",
      "7100/10000 | loss: 26.239870 | loss1: 9.898651 | loss2: 16.341219| rel error : 0.050785 \n",
      "7200/10000 | loss: 21.773035 | loss1: 7.102176 | loss2: 14.670859| rel error : 0.049907 \n",
      "7300/10000 | loss: 24.227243 | loss1: 6.361454 | loss2: 17.865789| rel error : 0.049267 \n",
      "7400/10000 | loss: 20.884699 | loss1: 6.235080 | loss2: 14.649618| rel error : 0.049292 \n",
      "7500/10000 | loss: 20.071217 | loss1: 7.048325 | loss2: 13.022891| rel error : 0.048478 \n",
      "7600/10000 | loss: 20.585228 | loss1: 6.112188 | loss2: 14.473041| rel error : 0.047860 \n",
      "7700/10000 | loss: 21.236521 | loss1: 6.807753 | loss2: 14.428767| rel error : 0.047400 \n",
      "7800/10000 | loss: 18.922653 | loss1: 5.891226 | loss2: 13.031426| rel error : 0.047640 \n",
      "7900/10000 | loss: 18.746426 | loss1: 5.668363 | loss2: 13.078063| rel error : 0.046561 \n",
      "8000/10000 | loss: 18.073683 | loss1: 5.481761 | loss2: 12.591921| rel error : 0.045755 \n",
      "8100/10000 | loss: 19.706966 | loss1: 4.948488 | loss2: 14.758479| rel error : 0.044907 \n",
      "8200/10000 | loss: 18.123892 | loss1: 5.071542 | loss2: 13.052350| rel error : 0.044378 \n",
      "8300/10000 | loss: 16.259327 | loss1: 5.306818 | loss2: 10.952510| rel error : 0.043899 \n",
      "8400/10000 | loss: 15.730045 | loss1: 4.789636 | loss2: 10.940409| rel error : 0.042986 \n",
      "8500/10000 | loss: 15.236507 | loss1: 4.981130 | loss2: 10.255378| rel error : 0.042051 \n",
      "8600/10000 | loss: 14.743679 | loss1: 4.780600 | loss2: 9.963079| rel error : 0.041074 \n",
      "8700/10000 | loss: 15.893967 | loss1: 5.206734 | loss2: 10.687233| rel error : 0.040348 \n",
      "8800/10000 | loss: 14.341730 | loss1: 4.406151 | loss2: 9.935579| rel error : 0.039024 \n",
      "8900/10000 | loss: 14.073482 | loss1: 4.398414 | loss2: 9.675068| rel error : 0.038654 \n",
      "9000/10000 | loss: 13.734701 | loss1: 4.593353 | loss2: 9.141348| rel error : 0.037857 \n",
      "9100/10000 | loss: 15.956513 | loss1: 6.775208 | loss2: 9.181305| rel error : 0.037607 \n",
      "9200/10000 | loss: 12.698752 | loss1: 4.607981 | loss2: 8.090772| rel error : 0.036281 \n",
      "9300/10000 | loss: 10.729134 | loss1: 4.106951 | loss2: 6.622183| rel error : 0.035764 \n",
      "9400/10000 | loss: 12.939538 | loss1: 5.015699 | loss2: 7.923839| rel error : 0.035033 \n",
      "9500/10000 | loss: 12.819009 | loss1: 5.184131 | loss2: 7.634878| rel error : 0.035061 \n",
      "9600/10000 | loss: 10.423734 | loss1: 3.570281 | loss2: 6.853453| rel error : 0.033522 \n",
      "9700/10000 | loss: 11.024700 | loss1: 4.100948 | loss2: 6.923752| rel error : 0.033624 \n",
      "9800/10000 | loss: 11.779337 | loss1: 4.603868 | loss2: 7.175469| rel error : 0.032649 \n",
      "9900/10000 | loss: 10.739973 | loss1: 4.580992 | loss2: 6.158981| rel error : 0.031951 \n",
      "10000/10000 | loss: 10.393820 | loss1: 4.031130 | loss2: 6.362690| rel error : 0.031445 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "optimizer=torch.optim.Adam([{'params': u_model.parameters()}], lr=1e-4)\n",
    "loss_f = nn.MSELoss()\n",
    "EPOCH=10000\n",
    "for t in tqdm_notebook(range(1, EPOCH+1)) :\n",
    "    # Fresh collocation points every step (BATCH_SIZE interior + BATCH_SIZE boundary),\n",
    "    # served to train() as a single batch.\n",
    "    data = sample(BATCH_SIZE)\n",
    "    data_bdry = sample_boundary(BATCH_SIZE)\n",
    "    u_bdry = boundary_condition(data_bdry)\n",
    "    data_train = TensorDataset(data, data_bdry, u_bdry)\n",
    "    train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "    loss, loss1, loss2 = train(u_model, trainloader=train_loader,\n",
    "                              optimizer=optimizer, loss_f=loss_f)\n",
    "\n",
    "    # Mean pointwise relative error on the fixed test set. Evaluation needs no autograd\n",
    "    # graph, so don't build one on every step.\n",
    "    with torch.no_grad():\n",
    "        output_answer = u_model(answer_x)\n",
    "    err = (torch.abs(answer_y - output_answer)/answer_y).mean().item()\n",
    "    # Print Log\n",
    "    if t%100 == 0 :\n",
    "        print(\"%s/%s | loss: %06.6f | loss1: %06.6f | loss2: %06.6f| rel error : %06.6f \" %\n",
    "              (t, EPOCH, loss, loss1, loss2, err))\n",
    "    total_loss.append(loss)\n",
    "    loss1s.append(loss1)\n",
    "    loss2s.append(loss2)\n",
    "    errs.append(err)\n",
    "    # Checkpoint the whole model and the loss/error histories every 1000 steps.\n",
    "    if t % 1000 == 0:\n",
    "        torch.save([u_model, total_loss, loss1s, loss2s, errs],\n",
    "                   '../../models/{}_Poisson_L2_L2_{}.pt'.format(dim, xmax))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrain from scratch N_RUNS times to measure run-to-run variability of the test error.\n",
    "# Per-run names (run_model, run_optimizer, run_errs) keep u_model / optimizer / errs from\n",
    "# the single-run cells intact, since those are checkpointed together with total_loss.\n",
    "EPOCH=10000\n",
    "N_RUNS = 100\n",
    "loss_f = nn.MSELoss()\n",
    "\n",
    "errs_list = []\n",
    "for m in range(N_RUNS) :\n",
    "    run_model = u_Net().to(device)\n",
    "    run_optimizer = torch.optim.Adam([{'params': run_model.parameters()}], lr=1e-4)\n",
    "    run_errs = []\n",
    "    for t in tqdm_notebook(range(1, EPOCH+1)) :\n",
    "        # Make dataloader\n",
    "        data = sample(BATCH_SIZE)\n",
    "        data_bdry = sample_boundary(BATCH_SIZE)\n",
    "        u_bdry = boundary_condition(data_bdry)\n",
    "        data_train = TensorDataset(data, data_bdry, u_bdry)\n",
    "        train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "        loss, loss1, loss2 = train(run_model, trainloader=train_loader,\n",
    "                                  optimizer=run_optimizer, loss_f=loss_f)\n",
    "\n",
    "        # Evaluation needs no autograd graph.\n",
    "        with torch.no_grad():\n",
    "            output_answer = run_model(answer_x)\n",
    "        err = (torch.abs(answer_y - output_answer)/answer_y).mean().item()\n",
    "\n",
    "        run_errs.append(err)\n",
    "    errs_list.append(run_errs)\n",
    "    # Re-written after every run so partial results survive an interrupted sweep.\n",
    "    # Despite the .pt suffix this is a plain pickle of an (n_runs_so_far, EPOCH) array.\n",
    "    with open('{}_Poisson_L2_L2_{}_error_mean.pt'.format(dim, xmax), 'wb') as f :\n",
    "        pkl.dump(np.array(errs_list), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean relative test error at each training step, averaged over the independent runs.\n",
    "np.mean(errs_list, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pydata",
   "language": "python",
   "name": "pydata"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
