{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T05:49:56.886747Z",
     "start_time": "2020-05-05T05:49:55.552862Z"
    }
   },
   "outputs": [],
   "source": [
    "# import modules\n",
    "import os\n",
    "import pickle as pkl\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm import tqdm_notebook\n",
    "\n",
    "# deep learning modules\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Plot modules\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:55.267254Z",
     "start_time": "2020-05-05T06:17:55.260453Z"
    }
   },
   "outputs": [],
   "source": [
    "# Use Gpu\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:55.441753Z",
     "start_time": "2020-05-05T06:17:55.436955Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Toy 1-D regression data: the ReLU target max(x, 0) sampled on uniform grids over [-1, 1].\n",
    "X_train = torch.FloatTensor(np.linspace(-1, 1, 100)).to(device).view(-1, 1)\n",
    "X_test = torch.FloatTensor(np.linspace(-1, 1, 51)).to(device).view(-1, 1)\n",
    "\n",
    "y_train = X_train * (X_train > 0).to(X_train.dtype)\n",
    "y_test = X_test * (X_test > 0).to(X_test.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.245996Z",
     "start_time": "2020-05-05T06:17:58.241527Z"
    }
   },
   "outputs": [],
   "source": [
    "# Neural network, Weight Sharing\n",
    "\n",
    "class u_Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(u_Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(1, 256)\n",
    "        self.fc2 = nn.Linear(256, 256)    \n",
    "        self.fc3 = nn.Linear(256, 256)    \n",
    "        self.fc4 = nn.Linear(256, 256)    \n",
    "        self.fc5 = nn.Linear(256, 1)\n",
    "        self.act1 = nn.Tanh()\n",
    "    def forward(self, x):\n",
    "        x = self.act1(self.fc1(x))\n",
    "        x = self.act1(self.fc2(x))\n",
    "        x = self.act1(self.fc3(x))\n",
    "        x = self.act1(self.fc4(x))\n",
    "        x = self.fc5(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.248577Z",
     "start_time": "2020-05-05T06:17:58.246944Z"
    }
   },
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "# NOTE: the training cells below reassign EPOCH (to 100000 and 10000), so this is only a default.\n",
    "EPOCH = 1000000\n",
    "# Full-batch training: one batch holds the whole training set.\n",
    "# (The previous literal 2601 was larger than the 100-sample X_train, so it already meant full batch.)\n",
    "BATCH_SIZE = len(X_train)\n",
    "\n",
    "# Seed the RNGs so weight initialisation, and therefore every training run, is reproducible.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.251546Z",
     "start_time": "2020-05-05T06:17:58.249443Z"
    }
   },
   "outputs": [],
   "source": [
    "# Wrap the training tensors in a DataLoader; batches keep the original sample order.\n",
    "data_train = TensorDataset(X_train, y_train)\n",
    "train_loader = DataLoader(dataset=data_train,\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.262409Z",
     "start_time": "2020-05-05T06:17:58.256729Z"
    }
   },
   "outputs": [],
   "source": [
    "# Training function\n",
    "\n",
    "def train(uv_model, trainloader, optimizer, loss_f) :\n",
    "    uv_model.train()\n",
    "    loss_list, err_list  = [], []\n",
    "    \n",
    "    for i, (data, y) in enumerate(trainloader) :\n",
    "        optimizer.zero_grad()\n",
    "        output = uv_model(data)  \n",
    "        loss = loss_f(output, y)\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        err_list.append(loss_f(uv_model(X_test), y_test).item())\n",
    "        loss_list.append((loss).item())\n",
    "    return np.mean(loss_list), np.mean(err_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T06:17:58.271852Z",
     "start_time": "2020-05-05T06:17:58.263364Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_results(u_model, show, x=None, y=None) :\n",
    "    \"\"\"Plot the model prediction against the input grid, with the target curve overlaid.\n",
    "\n",
    "    ``x`` / ``y`` default to the notebook globals X_train / y_train. When ``show`` is true\n",
    "    the figure is displayed and None is returned; otherwise the Figure is returned.\n",
    "    \"\"\"\n",
    "    if x is None:\n",
    "        x = X_train\n",
    "    if y is None:\n",
    "        y = y_train\n",
    "    # Inference only: no autograd graph is needed to draw the curve.\n",
    "    with torch.no_grad():\n",
    "        prediction = u_model(x).cpu().numpy()\n",
    "    x_np = x.detach().cpu().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 5))\n",
    "    ax.plot(x_np, y.detach().cpu().numpy(), 'k--', label='target')\n",
    "    ax.plot(x_np, prediction, label='prediction')\n",
    "    ax.set(title='Model prediction vs. target', xlabel='x', ylabel='u(x)')\n",
    "    ax.legend()\n",
    "    if show :\n",
    "        plt.show()\n",
    "    else :\n",
    "        return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T08:22:34.873952Z",
     "start_time": "2020-05-05T08:22:34.849481Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fresh model on the selected device, plus the per-epoch training-loss history.\n",
    "u_model = u_Net().to(device)\n",
    "total_loss = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T08:22:57.203176Z",
     "start_time": "2020-05-05T08:22:35.054371Z"
    }
   },
   "outputs": [],
   "source": [
    "# Single long training run with periodic logging, plotting and checkpointing.\n",
    "optimizer = torch.optim.Adam([{'params': u_model.parameters()}], lr=1e-3)\n",
    "loss_f = nn.MSELoss()  # stateless, so build it once rather than every epoch\n",
    "CHECKPOINT_PATH = '../../models/ReLU.pt'\n",
    "os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)\n",
    "\n",
    "# NOTE: overrides the EPOCH default set in the hyperparameter cell.\n",
    "EPOCH = 100000\n",
    "for t in tqdm_notebook(range(1, EPOCH+1)) :\n",
    "    loss, err = train(u_model, trainloader=train_loader,\n",
    "                      optimizer=optimizer, loss_f=loss_f)\n",
    "    total_loss.append(loss)\n",
    "\n",
    "    # Print log\n",
    "    if t % 1000 == 0 :\n",
    "        print(\"%s/%s | loss: %06.6f | err : %06.6f \" % (t, EPOCH, loss, err))\n",
    "\n",
    "    if t % 10000 == 0 :\n",
    "        plot_results(u_model, show=True)\n",
    "\n",
    "    converged = err < 1e-5\n",
    "    # Checkpoint periodically, and also on early stop so the converged model is saved\n",
    "    # (otherwise the last checkpoint could be up to 999 epochs stale).\n",
    "    # torch.save pickles the whole module object here, so loading it needs the u_Net class.\n",
    "    if t % 1000 == 0 or converged :\n",
    "        torch.save([u_model, total_loss], CHECKPOINT_PATH)\n",
    "\n",
    "    if converged :\n",
    "        print('In EPOCH {}, test error is under 1e-5, training finished'.format(t))\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat training from fresh initialisations. Per trial, record the first epoch at which the\n",
    "# test error drops below TOL, the wall-clock time to get there, and the full error curve.\n",
    "N_TRIALS = 100\n",
    "EPOCH = 10000\n",
    "TOL = 1e-4\n",
    "RESULTS_PATH = 'ReLU_L2_1e-4_test_err'\n",
    "\n",
    "continue_ = False  # set True to resume from RESULTS_PATH\n",
    "\n",
    "if continue_ :\n",
    "    epochs, times, errs_list = pd.read_pickle(RESULTS_PATH)\n",
    "else :\n",
    "    epochs, times, errs_list = [], [], []\n",
    "\n",
    "loss_f = nn.MSELoss()\n",
    "for trial in tqdm_notebook(range(N_TRIALS - len(epochs))) :\n",
    "    u_model = u_Net().to(device)\n",
    "    optimizer = torch.optim.Adam([{'params': u_model.parameters()}], lr=1e-5)\n",
    "    errs = []\n",
    "    start = time.perf_counter()\n",
    "    saved = False\n",
    "    # leave=False: drop each finished inner bar instead of stacking one per trial in the output.\n",
    "    for t in tqdm_notebook(range(1, EPOCH+1), leave=False) :\n",
    "        loss, err = train(u_model, trainloader=train_loader,\n",
    "                          optimizer=optimizer, loss_f=loss_f)\n",
    "        errs.append(err)\n",
    "        # Record only the first crossing; training continues so the full error curve is kept.\n",
    "        if err < TOL and not saved :\n",
    "            print('In EPOCH {}, test error is under {:g}'.format(t, TOL))\n",
    "            epochs.append(t)\n",
    "            times.append(time.perf_counter() - start)\n",
    "            saved = True\n",
    "\n",
    "    if not saved :\n",
    "        # Never converged: record the full budget for both epochs and time so that\n",
    "        # epochs, times and errs_list stay aligned (one entry per trial).\n",
    "        epochs.append(EPOCH)\n",
    "        times.append(time.perf_counter() - start)\n",
    "\n",
    "    errs_list.append(errs)\n",
    "    with open(RESULTS_PATH, 'wb') as f:\n",
    "        pkl.dump([epochs, times, errs_list], f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pydata",
   "language": "python",
   "name": "pydata"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
