{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-06T00:31:57.043573Z",
     "start_time": "2020-08-06T00:31:55.641974Z"
    }
   },
   "outputs": [],
   "source": [
     "# import modules\n",
     "import os\n",
     "import pickle as pkl\n",
     "\n",
     "import numpy as np\n",
     "import pandas as pd\n",
     "from tqdm import tqdm_notebook\n",
     "\n",
     "# deep learning modules\n",
     "import torch\n",
     "from torch.autograd import Variable\n",
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
     "import torch.optim as optim\n",
     "from torch.utils.data import DataLoader, TensorDataset\n",
     "\n",
     "# Plot modules\n",
     "import matplotlib.pyplot as plt\n",
     "from mpl_toolkits.mplot3d import Axes3D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-06T00:31:57.063322Z",
     "start_time": "2020-08-06T00:31:57.044829Z"
    }
   },
   "outputs": [],
   "source": [
    "# Use Gpu\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-06T00:31:57.069013Z",
     "start_time": "2020-08-06T00:31:57.064483Z"
    }
   },
   "outputs": [],
   "source": [
    "dim = 100\n",
    "xmin, xmax = 0,1\n",
    "\n",
    "try :\n",
    "    test_data = pd.read_pickle('poisson_{}_test'.format(dim))\n",
    "except :\n",
    "    test_data = np.random.uniform(0,1,[1000,dim])\n",
    "    with open('poisson_{}_test'.format(dim), 'wb') as f :\n",
    "        pkl.dump(test_data, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-06T00:31:57.104172Z",
     "start_time": "2020-08-06T00:31:57.070016Z"
    }
   },
   "outputs": [],
   "source": [
    "def analytic(data) :\n",
    "    return torch.sin((np.pi/2) * data).sum(dim=1).view(-1,1)\n",
    "\n",
    "def boundary_condition(data) : #data : N X dim\n",
    "    return torch.sin((np.pi/2) * data).sum(dim=1).view(-1,1)\n",
    "\n",
    "def sample(N):\n",
    "    return torch.rand([N,dim])\n",
    "\n",
    "def sample_boundary(N):\n",
    "    N = int(N/2)\n",
    "    interior = torch.rand([N,dim])\n",
    "    ind = int(N/dim)\n",
    "    for i in range(dim) :\n",
    "        interior[i*ind:(i+1)*ind, i] = 0\n",
    "    interior[(ind)*dim:, -1] = 0\n",
    "    \n",
    "    interior2 = torch.rand([N,dim])\n",
    "    ind = int(N/dim)\n",
    "    for i in range(dim) :\n",
    "        interior2[i*ind:(i+1)*ind, i] = 1\n",
    "    interior2[ind*dim:, -1] = 1\n",
    "    \n",
    "    return torch.cat([interior, interior2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-06T00:32:00.723326Z",
     "start_time": "2020-08-06T00:31:57.105054Z"
    }
   },
   "outputs": [],
   "source": [
    "answer_x = torch.FloatTensor(test_data).to(device)\n",
    "answer_y = analytic(answer_x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-06T00:32:00.731342Z",
     "start_time": "2020-08-06T00:32:00.725834Z"
    }
   },
   "outputs": [],
   "source": [
    "# Neural network, Weight Sharing\n",
    "\n",
    "class u_Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(u_Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(dim, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)    \n",
    "        self.fc3 = nn.Linear(256, 256)   \n",
    "        self.fc6 = nn.Linear(256, 1)\n",
    "        self.act1 = nn.Tanh()\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.act1(self.fc1(x))\n",
    "        x = self.act1(self.fc2(x))\n",
    "        x = self.act1(self.fc3(x))\n",
    "        x = self.fc6(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-06T00:32:00.735129Z",
     "start_time": "2020-08-06T00:32:00.732857Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "EPOCH = 1000000\n",
    "BATCH_SIZE = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-06T00:32:00.740654Z",
     "start_time": "2020-08-06T00:32:00.737536Z"
    }
   },
   "outputs": [],
   "source": [
    "def calculate_derivative(y, x) :\n",
    "    return torch.autograd.grad(y, x, create_graph=True,\\\n",
    "                        grad_outputs=torch.ones(y.size()).to(device))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T09:02:46.324355Z",
     "start_time": "2020-09-01T09:02:46.315923Z"
    }
   },
   "outputs": [],
   "source": [
    "# Training function\n",
    "\n",
    "def train(uv_model, trainloader, optimizer, loss_f) :\n",
    "    uv_model.train()\n",
    "    loss_list, loss_list1, loss_list2, err_list = [], [], [], []\n",
    "    \n",
    "    for i, (X_v, X_bdry, u_bdry) in enumerate(trainloader) :\n",
    "        optimizer.zero_grad()\n",
    "        X_v = Variable(X_v, requires_grad=True).to(device)\n",
    "        X_bdry, u_bdry = Variable(X_bdry, requires_grad=True).to(device), u_bdry.to(device)\n",
    "        mask = torch.where(X_bdry==0, X_bdry, torch.ones_like(X_bdry))\n",
    "        \n",
    "        output = uv_model(X_v)  \n",
    "        output_bdry = uv_model(X_bdry)\n",
    "        \n",
    "        del_u = calculate_derivative(output, X_v)\n",
    "        u_xx = sum(calculate_derivative(del_u[:,j], X_v)[:,j] for j in range(dim)).view(-1,1)\n",
    "        f = u_xx + ((np.pi**2)/4)*torch.sin((np.pi/2)*X_v).sum(dim=1).view(-1,1)\n",
    "        del_f = calculate_derivative(f, X_v)\n",
    "        del_bdry = calculate_derivative(output_bdry, X_bdry)\n",
    "        \n",
    "        loss1 = loss_f(f, torch.zeros_like(u_xx)) + loss_f(del_f, torch.zeros_like(del_f))#((u_t-u_xx)**2).mean()#\n",
    "        loss2 = loss_f(output_bdry, u_bdry) + loss_f(del_bdry*mask, (np.pi/2) * torch.cos((np.pi/2) * X_bdry)*mask)\n",
    "        loss = loss1 + loss2\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        #err = (torch.abs(answer_y - output_answer)/answer_y).mean().item()\n",
    "        loss_list.append((loss).item())\n",
    "        loss_list1.append(loss1.item())\n",
    "        loss_list2.append(loss2.item())\n",
    "        #err_list.append(err)\n",
    "    return np.mean(loss_list), np.mean(loss_list1), np.mean(loss_list2)#, np.mean(err_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T09:02:46.779856Z",
     "start_time": "2020-09-01T09:02:46.770006Z"
    }
   },
   "outputs": [],
   "source": [
    "u_model =  u_Net()      \n",
    "total_loss = []\n",
    "loss1s, loss2s, errs = [], [], []\n",
    "u_model = u_model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T09:02:47.005267Z",
     "start_time": "2020-09-01T09:02:47.001738Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_results(output_answer) :\n",
    "    fig = plt.figure(figsize=[8,4])\n",
    "    plt.plot(answer_x[:,1].cpu().detach().numpy(), answer_y.cpu().detach().numpy(), 'ro', label='True')\n",
    "    plt.plot(answer_x[:,1].cpu().detach().numpy(), output_answer.cpu().detach().numpy(), 'bo', label='Pred')\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-01T09:02:35.238917Z",
     "start_time": "2020-09-01T09:02:32.390749Z"
    }
   },
   "outputs": [],
   "source": [
    "optimizer=torch.optim.Adam([{'params': u_model.parameters()}], lr=1e-5)\n",
    "EPOCH=100000\n",
    "for t in tqdm_notebook(range(1, EPOCH+1)) :\n",
    "    # Make dataloader\n",
    "    data = sample(BATCH_SIZE)\n",
    "    data_bdry = sample_boundary(BATCH_SIZE)\n",
    "    u_bdry = boundary_condition(data_bdry)\n",
    "    data_train = TensorDataset(data, data_bdry, u_bdry)\n",
    "    train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False)\n",
    "    \n",
    "    loss, loss1, loss2 = train(u_model, trainloader=train_loader, \\\n",
    "                              optimizer=optimizer, loss_f=nn.MSELoss())\n",
    "    \n",
    "    output_answer = u_model(answer_x)\n",
    "    err = (torch.abs(answer_y - output_answer)/answer_y).mean().item()\n",
    "    # Print Log\n",
    "    if t%100 == 0 :\n",
    "        print(\"%s/%s | loss: %06.6f | loss1: %06.6f | loss2: %06.6f| rel error : %06.6f \" % \\\n",
    "              (t, EPOCH, loss, loss1, loss2, err))\n",
    "    total_loss.append(loss)\n",
    "    loss1s.append(loss1)\n",
    "    loss2s.append(loss2)\n",
    "    errs.append(err)\n",
    "    # Save Modelview(-1)\n",
    "    if t % 1000 == 0:\n",
    "        torch.save([u_model, total_loss, loss1s, loss2s, errs],\n",
    "                   '../../models/{}_Poisson_H1_H1_{}.pt'.format(dim, xmax))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-05T21:06:15.354957Z",
     "start_time": "2020-09-01T09:03:30.732677Z"
    }
   },
   "outputs": [],
   "source": [
    "EPOCH=10000\n",
    "\n",
    "errs_list = []\n",
    "for m in tqdm_notebook(range(100)) :\n",
    "    u_model =  u_Net().to(device)\n",
    "    optimizer=torch.optim.Adam([{'params': u_model.parameters()}], lr=1e-4)\n",
    "    errs = []\n",
    "    for t in tqdm_notebook(range(1, EPOCH+1)) :\n",
    "        # Make dataloader\n",
    "        data = sample(BATCH_SIZE)\n",
    "        data_bdry = sample_boundary(BATCH_SIZE)\n",
    "        u_bdry = boundary_condition(data_bdry)\n",
    "        data_train = TensorDataset(data, data_bdry, u_bdry)\n",
    "        train_loader = DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "        loss, loss1, loss2 = train(u_model, trainloader=train_loader, \\\n",
    "                                  optimizer=optimizer, loss_f=nn.MSELoss())\n",
    "\n",
    "        output_answer = u_model(answer_x)\n",
    "        err = (torch.abs(answer_y - output_answer)/answer_y).mean().item()\n",
    "        \n",
    "        errs.append(err)\n",
    "    errs_list.append(errs)    \n",
    "    with open('{}_Poisson_H1_H1_{}_error_mean.pt'.format(dim, xmax), 'wb') as f :\n",
    "        pkl.dump(np.array(errs_list), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-05T21:06:15.393493Z",
     "start_time": "2020-09-05T21:06:15.356388Z"
    }
   },
   "outputs": [],
   "source": [
     "# Mean test relative error per epoch, averaged over the runs stored in errs_list.\n",
     "# NOTE(review): errs_list holds whichever experiment last ran in this kernel; the\n",
     "# label below presumably names a variant of the training code - confirm.\n",
     "np.array(errs_list).mean(axis=0) #remove direction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-07T03:21:12.144323Z",
     "start_time": "2020-09-07T03:21:12.104559Z"
    }
   },
   "outputs": [],
   "source": [
     "# Mean test relative error per epoch, averaged over the runs stored in errs_list.\n",
     "# NOTE(review): errs_list holds whichever experiment last ran in this kernel; the\n",
     "# label below presumably marks the unmodified training code - confirm.\n",
     "np.array(errs_list).mean(axis=0) #original"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pydata",
   "language": "python",
   "name": "pydata"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
