{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multitask Learning Example\n",
    "\n",
    "Source: https://github.com/Hui-Li/multi-task-learning-example-PyTorch/blob/master/multi-task-learning-example-PyTorch.ipynb\n",
    "\n",
    "- Paper: *Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics*\n",
    "- arXiv: https://arxiv.org/pdf/1705.07115.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import pylab\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(N):\n",
    "    X = np.random.randn(N, 1)\n",
    "    w1 = 2.\n",
    "    b1 = 8.\n",
    "    sigma1 = 1e1  # ground truth\n",
    "    Y1 = X.dot(w1) + b1 + sigma1 * np.random.randn(N, 1)\n",
    "    w2 = 3\n",
    "    b2 = 3.\n",
    "    sigma2 = 1e0  # ground truth\n",
    "    Y2 = X.dot(w2) + b2 + sigma2 * np.random.randn(N, 1)\n",
    "    return X, Y1, Y2\n",
    "\n",
    "\n",
    "class TrainData(Dataset):\n",
    "\n",
    "    def __init__(self, feature_num, X, Y1, Y2):\n",
    "\n",
    "        self.feature_num = feature_num\n",
    "\n",
    "        self.X = torch.tensor(X, dtype=torch.float32)\n",
    "        self.Y1 = torch.tensor(Y1, dtype=torch.float32)\n",
    "        self.Y2 = torch.tensor(Y2, dtype=torch.float32)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.feature_num\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx,:], self.Y1[idx,:], self.Y2[idx,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiTaskLossWrapper(nn.Module):\n",
    "    def __init__(self, task_num, model):\n",
    "        super(MultiTaskLossWrapper, self).__init__()\n",
    "        self.model = model\n",
    "        self.task_num = task_num\n",
    "        self.log_vars = nn.Parameter(torch.zeros((task_num)))\n",
    "\n",
    "    def forward(self, input, targets):\n",
    "\n",
    "        outputs = self.model(input)\n",
    "\n",
    "        precision1 = torch.exp(-self.log_vars[0])\n",
    "        loss = torch.sum(precision1 * (targets[0] - outputs[0]) ** 2. + self.log_vars[0], -1)\n",
    "\n",
    "        precision2 = torch.exp(-self.log_vars[1])\n",
    "        loss += torch.sum(precision2 * (targets[1] - outputs[1]) ** 2. + self.log_vars[1], -1)\n",
    "\n",
    "        loss = torch.mean(loss)\n",
    "        \n",
    "        print(f\"precision1: {precision1}\")\n",
    "        print(f\"precision2: {precision2}\")\n",
    "\n",
    "        return loss, self.log_vars.data.tolist()\n",
    "\n",
    "\n",
    "class MTLModel(torch.nn.Module):\n",
    "    def __init__(self, n_hidden, n_output):\n",
    "        super(MTLModel, self).__init__()\n",
    "\n",
    "        self.net1 = nn.Sequential(nn.Linear(1, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_output))\n",
    "        self.net2 = nn.Sequential(nn.Linear(1, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_output))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return [self.net1(x), self.net2(x)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x14f653e8cd0>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMkAAABwCAYAAAC0A1S4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAT/0lEQVR4nO2df3Bc1XXHP0erVVhhsGxMg5FtTBmPaQHXqh2GGdykxDEGirHGCXYJATqUUIaUYGDErzDYpLQ2aIoxneEPlzAhjWdAJI4wxNThV9PAjCkWBoMBUyY0IJkkNlgGo8VaaU//eLvS/ng/d9/bt5LuZ0az2vd237tv3/3ee865594nqorBYHCmIe4CGAz1jhGJweCBEYnB4IERicHggRGJweBBYy1PNm3aNJ09e3YtT2kw+KKnp+eAqh5vt6+mIpk9ezY7d+6s5SnHBN27+ujcvpd9/WlObEnRsXQu7W2tcRdrQiEiv3PaV1ORGMrp3tXHbVveIJ0ZBqCvP81tW94AqFgo4010cV+P8UlipnP73hGB5Elnhuncvrei4+VF19efRhkVXfeuvhBKW3vq4XqMSGJmX3860HYvwhZd3NTD9RiRxMyJLalA270IW3RxUw/XY0QSMx1L55JKJoq2pZIJOpbOreh4YYsuburheoxIYqa9rZV1K86gtSWFAK0tKdatOAOAs9c/z8m3/pKz1z9fZoN37+qz3X/OqbZRTGYfNzZFEnYjUglSyyzghQsXqgkBe1Ma8QKrYnxzQSsvvLOfvv40AhTeuVQywboVZ9C5fS99NqaIABtWzR+TUa5aRLdEpEdVF9ruMyIJhzBv5Nnrn3es6G53q7Ulxb5cFMhp/0u3fj2SShd3mLZa3ERixklCIOyxDien1Ks5y1dQO4Hl90c1LhP2MesJ45OEQNhhykqd0nwLLi77owip1kOYNkqMSEIg7DClnbPqVPHz5J3Z9rZWLj1rVtnn8/udyuTU+/ihHsK0UWJEEgJhhyntIl6XnjXLUTj5iFjetLm7/Qw2rJpfFjFrb2t1LJNAxaPY9RCmjRLjuIeAUzSqsOKGdZ5qnePuXX3c8Nhrtv5N3rGvpFy1uP5K8PubGcc9YvI/etTRnfa21qqP2d7WyurHXrPdV6nJVavrD0pYAQUjkpAIowLXioQIwzYWREK8PB9n6vH63QIKRiQGV+wE4rYdohsHiXJ8JayAgqdIRGQm8BPgBCALbFLVjSIyFXgMmA38H7BSVQ8GOvsEop4G21odxlJaHRztqMZB7uh+g807Phjxj5yOW+lv5zRmFDSg4Ce6NQTcpKp/BpwFfE9E/hy4FXhOVecAz+XeG2yo1ZwIp3yuUoLmQwUdB/FTju5dfUUCcTpuNb9dWHlfnj2Jqn4EfJT7/zMReRtoBZYDf5372CPAfwG3BDr7BCEs29iNIK19EEe7e1ef6wh+peXo3L7XMYOg8LheAnW7hrACCoF8EhGZDbQBLwNfzgkIVf1IRP7E4TtXA1cDzJo1K1DhxiJ2pkEtBtuCCtGPo52v8E7YmS1+y+F27ZNTSc/P5cXnJcYwAgq+BxNFZBLwc2C1qn7q93uquklVF6rqwuOPt0/jHi84mQaFN72QMAfbohCiXYXP42S2+C2H27V/Pjg0Yk45fS4hUrNUGF8iEZEklkA2q+qW3OY/iMj03P7pwB9DL90Yw6kVFcGXbezXp7DD76h3kHO4CcxpoNCpHApF5+tYOpdkg33IOTOsI5Xdya9wisRFkQrjKRIREeBHwNuqel/Brq3AFbn/rwCeCL10YwynG3RwIMNRyQZaUsmyNJE8dr3QDY+9xh3dzuZO4XcHBofKtpcKMagT7FThW1tSjiaMXaXOU3i+9rZWJh3lbO3nf0unSWlOkbgoUmH8+CRnA5cBb4jIa7lttwPrgS4R+XvgA+Di0Es3xnBLUz84kCGVTDhOfLLrhRTYvOMDFp401bFS2qWEADQnG2hqbOCG
x16jc/teOpbODey3dCyda5tu4hYdKnSW7X6LwvMdHMg4Hqewsjv5FUHLVimePYmqvqiqoqrzVHV+7m+bqn6sqotVdU7u9ZPQSzfGcGtFwd1mdptD4mZnO/kN6UyW/nSmqMcIEqUC51bcyxFub2vlpVu/7pi5nJ/X4rRfwLOyV1q2SjAj7iGSv0E3db0e2Gb2mizlhN8JWk4OOOAYWIDqokNug3luIeBLz5rl65y1SoUxqfIh097WStYlvWNyKmnrOHtNlnIiDBv80y8ykSz2ZuecJxvENSwOVqp/PWFEEgFe4U07x9lrspQTlUzQKiWr7iZdIYEjcKWFyb13CwrUG0YkEeBUcY9uSpAZLu5lCv0Ut8lSTgSdoOWEn9Bp0OhY5/a9ZdebD+/Ww1JBfjEiiQC7irth1XwGBu39gsIK2t7WSsfSuZyYW/mkc/tez9Y67yhvWDUfgJ/u+IAjQ6PnakklbXupQvyYbUFzuNwGFmvpeFeLcdwjws6pdAqLFlbQSjNuS7+XLWjAjwxlWXjSVMASUCl5P8GLoKP6Xlm4NXG8d3fBcz+EQ70weQYsvhPmrQx0CNOT1BA/JkalK4+4pZDkv393+xncv2o+U5pHo1ktqSSdF/+F79TzINtjN6l2d8GT34dDHwJqvT75fWt7AExPUkKU8z78ZKVWmoPld381rXfQwcVIpvXu7oKnb4F0blguNRXOv8e+d3juh5Ap+V0yaWt7gN7EiKSAWiyy5lVJK50o5DbO4uf7fqik0odiUo2YTB+W70t/Alu+awmnVCyHeu2P57TdAbNaSgFOy4tWuopIJbitPALOFbR7Vx8dj79OJlt+P/PLo7bWakZkrlLroV4OMYnhrDKl4XO+SJ1A8/klrbiXz7C7C574HgwPep83mYJlD4x+f8Pp9sJKTYVb3i/aZFZL8Uk9LLLm1FoD3r2cQ/jKa3qsE4FMz6LW3pKlAC18NuL5Nqc/YuiJ66xKN28lPHIRvP/r0WMc+tASRH4/WD2EH4FAuSm1+E7ovhayJTlig4et8vo0ueqiJ6mX+d/10JM44VS2Kc1JmpsaAy0H5Od6PNfSshGFbybPhDnnws4f2e9vOhpu32f9v3ay/+OCVZa1/aNv7zl51H8pLcMNb45+q557kloutuwlxkqyXmuFWxq+WzbtRQ0vcnNjFyfKAfqZhCpMSR+GDTN55ZTrWP3WHNvfo3P7XpYM/5qbm7polQNkERpQa0LEtqOt1nmkhQ/Y0B7qhZ4fO+8f/DxQS1/E5BnF79MOa5ME8EtiF0kt5n9Duc3e15+m4/HXgVExhhWNiaJnbGlOuooBigWxT6fxXHY+Fyf+m2axKvNUDo+aZIc+5PSeO1iQuYo+FpU1Tgs/fYZ1yYdGvpsoFELm86quhdQU+9a9kLzZlJrq/dk8yZRlYhUyeYa9X1IqJhdiF0mt/IC1W/eUObWZrLJ2655Q50RH1TN6WcUXNbzI+oJKPUMOcJk8i8PkPwBSMsjG5INs5EE+0UncNXQ5ndubaG9r5a6m/6AZn75AJUgC1DkzeaSlP/+ecr+iIQntD1r/ewwUvnLKdZzecwcpGb2WtDbx5inX8RWfRY1dJGGtjQTuLXh/2r4VdtpeKVH1jIc8ynlzY9eIQPK4CSRPftHG4+Qw9ycf5KeH34XdB5jMZ5UW1Zv0QVh4pbNPAqMtfb7SO4nBwyRb/dYcFmSuyvWwH7NPj+PeoZX0vDWHly7yV9zYRRKWHxBGCx6GmRRVz3hiS4oFnz5TZE7dO7TSmqDUaPkN1dIg8J3GZ+HpVwNnEgdi8gy4MDcTfOfDlPk0pWbTvJWV+SdYv3sfi9g6uKhouwS4H7GLJCw/wKsFn+Jg0+dTNMIyk8LqGbt39bF26x6+euQFbm7s4jdyAEmOtvwz5AAbkw+CBE+Nd6MB/PsAlZIXwIX3WX8h5Fc5Ecb9iF0kEM6orFcLvmbZaXT87PWi1O1kQliz7DTAv5kUeYRsdxdHnuxg+WA/
ywEKhFFKFetbR0tqai6qZONIpaaWC6CKnsKLMCyVuhBJGPjJOAXnHsuPmeSnt3E8T+Il2GCNQv+BaawbvJidxy4Z3VeQdvElCLd7CIgCkkyV5z01HQ1Dg0VOtFJS1GTKcrY/2FFuSuX31ZAwLJW6GEwMg2ofJONnINH3YGNpEp5N5RrQJh4f/irLEi8zRT6LUxNl9HMMLSvuGzGBBlIncG9mFY8cPpMrJv0PNycfozn9e8s0mnMu/O+v7E2lCM2osInsEdUich6wEUgAD6nqerfPR527VY3j7WeEufdntxU5zcCII93QMtOqBB/scI/aFKBafybTkDZwU+YaNv7LOqD+nmIVVXZGJCIRkQTwLrAE6AVeAS5R1becvlO3CY4FCXllplA+BaMkVp//2YoqeaLJf55RnVB4+w8yibWZy+k5dknw3rOEqJ4VH5Vgo0pLORN4T1V/mzvJo1grzTuKpC7JT8zJpBHgBPazselBkEchcQ+w0jKdSpLkbHuAmAQyDCSajrbSOQIwoE3cmrmKrdnR8GgyIXQWOLWVhLSjGlCtVXZGKdXMTGwFCsf7e3PbihCRq0Vkp4js3L9/fxWnC5ndXVYq9ZbvljuokJuncDU8dSMadUi0Cga0iRsHr7USAlf8uzWS7cAgjVZ0CWEgNZ1/kmuKBDKlOUnnt6xZivlVUZzsjHxAxG71lKie6x5XlnY1PYldW1r2m6rqJmATWOZWFefzxstRfOpGK7HOLR2iCM1FaOoPVUZSSXqOXWJtzF3r0BPX0Tj8RdFnDzKJZ0+6kY1/bKPvizTyRfHNSiUTrFl22ohA7JZOLfxsx9K5jj2G0/fCGFANKzsjCNX0JL3AzIL3M4B91RWnCrzmMz91o+VQ+xZIHiWr4XjXbi2Eanl+VunnFWuBh97sNK7PXMuCwU08k/haccx/3koal/8bA6npZBF6s9O4K7ma+9r+kzXvnzZSydyeMOU2X75wVROnHsPpAaXVVua45sxX05O8AswRkZOBPuBvgW+HUqpK8JrP7Jaa7YGIMqiNNEn5yu1+UIXPOYotw4u4PPlCmVCPaIKOzD8AjOQYfdF8As2nXcDAnm0clf49+7LH8VDTdxg6/Vu88M5+9vWnnWcazltJ87yVRc5zw8sfuj44FEZbeqcWX6DIWXeawzKsSiqZCH3KQSRz5n1QsUhUdUhE/hHYjhUCflhV94RWsqB4zWcO3IMUHCL5ZdZ8/k3WJn/CFA4XOe1uYVxV6MuFi7dmF9HakuLyC1aNjKEocIhjuDNzGb/Uv2JYlZ7m0aha964+bnv53NHKNgipnj5f0ZxSU8hLIDDa0vsxa9zWAkuIsG7FGZFU5lKh5Hu/KIVS1Yi7qm4DtoVUlurwmjfglZo9Qsksu2SKjXoJW7NnsvXIooI5G1ZG6XPZ+SxL7CgTT1qbuKUgcjTSks77Ot3DZxdXoOVzecDn4xj8RnPcTCY7konRtbe8Ujm8HhM3rBrZmlq1nKSXpz7W3cpHmta2WK8B10UCLCc9WWLzFmaTLvg772NMngkrNlmviPW67AEeOXzmyEe2ZhexaPAB/vTIZhYNPsDaoSv5yyObuCu5moHU9JHvvbngbnqOXVK2OqHXUqGF0aJKVpkP8pkiCtoFr9UVvQQY5Xq+UUXO3Ig/d6tgnAIYdbghWAqD17yDfGp2z49RHYYSM2lkIo5Nst2J2+wH1BIi/OvK/MJufwPcNbLvK2A7X8HrJrtFh0bK46MSOplMCRFb0yuTVW7qGp2p6dYTuAkwakc6jjBw/D2Jm8MdlHkrrcn9a/ut11KRXXgfrPmERUf9gusz19KbnUZWrQjQLZmrWP3WHNvDOkVVRgXiH7eb7MdE8lsJ3crsFKsbVvX1jHS3h31Gna4SdBXJMIhfJCEtIBaEff3pMrNpa3ZR6E98ssPtJru1hkHP61Zmtwrlx3QJs9EIShxh4PjNrRAm6gelkkGpsBxRN6fYaUHtSpc0ciqzXRkK8TJd4grF
xnXu+EWy+M5inwTsV70IkXNOPd52dfVzTg3+nPmgiXxeN7kWSxp5PbbOj+lSq0ex1cO54xeJl8MdAS+8Y59D5rTdiUrDkU43uZatZC1FOdaJXyQQ6fRNO8KKkESRlVrLVjJOs2ksUR8iqTFhJcrVw9rB1RKn2TRWiD+6FQNhRUjiCEcaas+EFElYId3Yn+RkqAkT0tyCcMwMY9NPDCasSMLC2PTjnwlpbhkMQTAiMRg8MCIxGDwwIjEYPDAiMRg8MNEt6ufBpob6ZMKLJI4504axxYQ3t+KYM20YW0x4kYyHJEVDtEx4kZgkRYMXE14kJknR4MWEd9xNkqLBi5o+Dk5E9gO/i/g004Dqn9c8NjHXXjknqartIgc1FUktEJGdTk8sGu+Ya4/m2ie8T2IweGFEYjB4MB5FsinuAsSIufYIGHc+icEQNuOxJzEYQsWIxGDwYFyKREQ6ReQdEdktIr8QkZa4yxQ1InKeiOwVkfdE5Na4y1MrRGSmiLwgIm+LyB4RuT70c4xHn0REzgWezz3X8R4AVb0l5mJFhogkgHeBJVhPRX4FuERV34q1YDVARKYD01X1VRE5BugB2sO89nHZk6jqr1Q1/6jcHViPzx7PnAm8p6q/VdVB4FFgecxlqgmq+pGqvpr7/zPgbSDUnKJxKZISrgSejrsQEdMKFD7kpZeQK8pYQERmA23Ay2Eed8wmOIrIs8AJNrt+oKpP5D7zA2AI2FzLssWA3RPexp8d7YKITAJ+DqxW1U/DPPaYFYmqfsNtv4hcAVwILNbx6HgV0wvMLHg/A9gXU1lqjogksQSyWVW3hH788Vh/ROQ84D7ga6oa7Mk8YxARacRy3BcDfViO+7dVdU+sBasBIiLAI8Anqro6knOMU5G8B3wJ+Di3aYeqXhNjkSJHRC4A7gcSwMOq+s/xlqg2iMgi4DfAG0A2t/l2Vd0W2jnGo0gMhjCZCNEtg6EqjEgMBg+MSAwGD4xIDAYPjEgMBg+MSAwGD4xIDAYP/h88myxk5O89RAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 216x108 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "# Seed every RNG in play so Restart & Run All is reproducible: NumPy drives\n",
    "# gen_data, PyTorch drives weight initialisation and any shuffling it performs.\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "feature_num = 100  # number of samples passed to gen_data (not a feature count)\n",
    "nb_epoch = 100\n",
    "batch_size = 20\n",
    "hidden_dim = 1024\n",
    "\n",
    "X, Y1, Y2 = gen_data(feature_num)\n",
    "\n",
    "# Overlay both targets against the shared input.\n",
    "pylab.figure(figsize=(3, 1.5))\n",
    "pylab.scatter(X[:, 0], Y1[:, 0], label=\"Y1\")\n",
    "pylab.scatter(X[:, 0], Y2[:, 0], label=\"Y2\")\n",
    "pylab.xlabel(\"X\")\n",
    "pylab.legend(fontsize=\"x-small\");  # trailing ';' keeps the artist repr out of the cell output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision1: 1.0\n",
      "precision2: 1.0\n",
      "precision1: 0.9990004897117615\n",
      "precision2: 0.9990004897117615\n",
      "precision1: 0.9980084896087646\n",
      "precision2: 0.9980095624923706\n",
      "precision1: 0.997039258480072\n",
      "precision2: 0.9970176219940186\n",
      "precision1: 0.9960571527481079\n",
      "precision2: 0.9960400462150574\n",
      "precision1: 0.9950671195983887\n",
      "precision2: 0.9950805902481079\n",
      "precision1: 0.9940718412399292\n",
      "precision2: 0.9941424131393433\n",
      "precision1: 0.9930749535560608\n",
      "precision2: 0.9932420253753662\n",
      "precision1: 0.9921287298202515\n",
      "precision2: 0.9923914074897766\n",
      "precision1: 0.991202175617218\n",
      "precision2: 0.9915634989738464\n",
      "precision1: 0.9902649521827698\n",
      "precision2: 0.9907726645469666\n",
      "precision1: 0.98931884765625\n",
      "precision2: 0.990025520324707\n",
      "precision1: 0.9883562922477722\n",
      "precision2: 0.9893231391906738\n",
      "precision1: 0.9874255061149597\n",
      "precision2: 0.9886627793312073\n",
      "precision1: 0.9865036606788635\n",
      "precision2: 0.98805832862854\n",
      "precision1: 0.985619843006134\n",
      "precision2: 0.9875041842460632\n",
      "precision1: 0.9847414493560791\n",
      "precision2: 0.986994743347168\n",
      "precision1: 0.9838556051254272\n",
      "precision2: 0.9865248203277588\n",
      "precision1: 0.9829981327056885\n",
      "precision2: 0.9861028790473938\n",
      "precision1: 0.9821428060531616\n",
      "precision2: 0.9857221841812134\n",
      "precision1: 0.9812762141227722\n",
      "precision2: 0.9853711724281311\n",
      "precision1: 0.9804201126098633\n",
      "precision2: 0.9850552082061768\n",
      "precision1: 0.9795586466789246\n",
      "precision2: 0.9847646951675415\n",
      "precision1: 0.9787052273750305\n",
      "precision2: 0.9844924807548523\n",
      "precision1: 0.9778727293014526\n",
      "precision2: 0.9842408299446106\n",
      "precision1: 0.9770323038101196\n",
      "precision2: 0.9840144515037537\n",
      "precision1: 0.9761919975280762\n",
      "precision2: 0.9838094115257263\n",
      "precision1: 0.9753555655479431\n",
      "precision2: 0.9836138486862183\n",
      "precision1: 0.974543571472168\n",
      "precision2: 0.9834206700325012\n",
      "precision1: 0.9737430810928345\n",
      "precision2: 0.9832439422607422\n",
      "precision1: 0.9729242920875549\n",
      "precision2: 0.9830742478370667\n",
      "precision1: 0.9720977544784546\n",
      "precision2: 0.9829034805297852\n",
      "precision1: 0.9712821245193481\n",
      "precision2: 0.9827460646629333\n",
      "precision1: 0.9704797863960266\n",
      "precision2: 0.9825954437255859\n",
      "precision1: 0.9696928858757019\n",
      "precision2: 0.9824500679969788\n",
      "precision1: 0.9688925743103027\n",
      "precision2: 0.9823192358016968\n",
      "precision1: 0.9680904746055603\n",
      "precision2: 0.9821878671646118\n",
      "precision1: 0.9672977924346924\n",
      "precision2: 0.9820697903633118\n",
      "precision1: 0.9665073752403259\n",
      "precision2: 0.9819572567939758\n",
      "precision1: 0.9657215476036072\n",
      "precision2: 0.9818539619445801\n",
      "precision1: 0.9649325013160706\n",
      "precision2: 0.9817637205123901\n",
      "precision1: 0.9641324877738953\n",
      "precision2: 0.9816866517066956\n",
      "precision1: 0.9633453488349915\n",
      "precision2: 0.9816203117370605\n",
      "precision1: 0.962554931640625\n",
      "precision2: 0.9815600514411926\n",
      "precision1: 0.9617778062820435\n",
      "precision2: 0.9815055131912231\n",
      "precision1: 0.9609999656677246\n",
      "precision2: 0.9814513921737671\n",
      "precision1: 0.9602221846580505\n",
      "precision2: 0.9814056754112244\n",
      "precision1: 0.9594388008117676\n",
      "precision2: 0.98136305809021\n",
      "precision1: 0.9586625099182129\n",
      "precision2: 0.9813178181648254\n",
      "precision1: 0.9578936100006104\n",
      "precision2: 0.9812847375869751\n",
      "precision1: 0.9571219682693481\n",
      "precision2: 0.9812594652175903\n",
      "precision1: 0.9563301205635071\n",
      "precision2: 0.9812297224998474\n",
      "precision1: 0.9555450677871704\n",
      "precision2: 0.9812020659446716\n",
      "precision1: 0.9547653794288635\n",
      "precision2: 0.9811864495277405\n",
      "precision1: 0.9540036916732788\n",
      "precision2: 0.9811726808547974\n",
      "precision1: 0.9532378315925598\n",
      "precision2: 0.9811647534370422\n",
      "precision1: 0.952467679977417\n",
      "precision2: 0.9811620712280273\n",
      "precision1: 0.951701819896698\n",
      "precision2: 0.9811656475067139\n",
      "precision1: 0.9509038925170898\n",
      "precision2: 0.9811694025993347\n",
      "precision1: 0.9501224756240845\n",
      "precision2: 0.9811766147613525\n",
      "precision1: 0.9493645429611206\n",
      "precision2: 0.9811742305755615\n",
      "precision1: 0.9485747814178467\n",
      "precision2: 0.9811834692955017\n",
      "precision1: 0.9477978944778442\n",
      "precision2: 0.9811908602714539\n",
      "precision1: 0.9470281004905701\n",
      "precision2: 0.9812027812004089\n",
      "precision1: 0.9462645649909973\n",
      "precision2: 0.981209933757782\n",
      "precision1: 0.9455122947692871\n",
      "precision2: 0.9812114238739014\n",
      "precision1: 0.9447827339172363\n",
      "precision2: 0.981209397315979\n",
      "precision1: 0.944044828414917\n",
      "precision2: 0.9812101125717163\n",
      "precision1: 0.9433088898658752\n",
      "precision2: 0.9812135696411133\n",
      "precision1: 0.9425530433654785\n",
      "precision2: 0.9812251329421997\n",
      "precision1: 0.9417905807495117\n",
      "precision2: 0.9812361001968384\n",
      "precision1: 0.9410434365272522\n",
      "precision2: 0.9812443256378174\n",
      "precision1: 0.9402686357498169\n",
      "precision2: 0.981242299079895\n",
      "precision1: 0.9395073652267456\n",
      "precision2: 0.9812474250793457\n",
      "precision1: 0.9387285113334656\n",
      "precision2: 0.9812659621238708\n",
      "precision1: 0.9379688501358032\n",
      "precision2: 0.9812847375869751\n",
      "precision1: 0.9372142553329468\n",
      "precision2: 0.9813023805618286\n",
      "precision1: 0.936488926410675\n",
      "precision2: 0.9813173413276672\n",
      "precision1: 0.9357512593269348\n",
      "precision2: 0.9813421964645386\n",
      "precision1: 0.9349839091300964\n",
      "precision2: 0.9813668131828308\n",
      "precision1: 0.934219479560852\n",
      "precision2: 0.9813878536224365\n",
      "precision1: 0.9334696531295776\n",
      "precision2: 0.9813961982727051\n",
      "precision1: 0.9327079653739929\n",
      "precision2: 0.981406033039093\n",
      "precision1: 0.9319233298301697\n",
      "precision2: 0.9814189076423645\n",
      "precision1: 0.931133508682251\n",
      "precision2: 0.9814382791519165\n",
      "precision1: 0.9303765296936035\n",
      "precision2: 0.9814649820327759\n",
      "precision1: 0.9296256899833679\n",
      "precision2: 0.9814959764480591\n",
      "precision1: 0.9288817048072815\n",
      "precision2: 0.9815270900726318\n",
      "precision1: 0.9281036257743835\n",
      "precision2: 0.9815630316734314\n",
      "precision1: 0.9273263812065125\n",
      "precision2: 0.9815994501113892\n",
      "precision1: 0.9265737533569336\n",
      "precision2: 0.9816234707832336\n",
      "precision1: 0.9258349537849426\n",
      "precision2: 0.9816406965255737\n",
      "precision1: 0.9250981211662292\n",
      "precision2: 0.9816568493843079\n",
      "precision1: 0.9243737459182739\n",
      "precision2: 0.9816787242889404\n",
      "precision1: 0.9236631989479065\n",
      "precision2: 0.9816938042640686\n",
      "precision1: 0.9229008555412292\n",
      "precision2: 0.981722354888916\n",
      "precision1: 0.9221230745315552\n",
      "precision2: 0.9817543029785156\n",
      "precision1: 0.9213362336158752\n",
      "precision2: 0.9817873239517212\n",
      "precision1: 0.9205796718597412\n",
      "precision2: 0.9818300604820251\n",
      "precision1: 0.919834315776825\n",
      "precision2: 0.981855571269989\n",
      "precision1: 0.9190809726715088\n",
      "precision2: 0.9818830490112305\n",
      "precision1: 0.918302059173584\n",
      "precision2: 0.9818968772888184\n",
      "precision1: 0.9175122976303101\n",
      "precision2: 0.9819222688674927\n",
      "precision1: 0.9167529940605164\n",
      "precision2: 0.9819522500038147\n",
      "precision1: 0.9159981608390808\n",
      "precision2: 0.9819883704185486\n",
      "precision1: 0.9152562022209167\n",
      "precision2: 0.982016921043396\n",
      "precision1: 0.9144999980926514\n",
      "precision2: 0.9820477366447449\n",
      "precision1: 0.9137448072433472\n",
      "precision2: 0.98208087682724\n",
      "precision1: 0.9130118489265442\n",
      "precision2: 0.9821112751960754\n",
      "precision1: 0.9122819304466248\n",
      "precision2: 0.9821373820304871\n",
      "precision1: 0.9115389585494995\n",
      "precision2: 0.9821670055389404\n",
      "precision1: 0.9107778072357178\n",
      "precision2: 0.9822014570236206\n",
      "precision1: 0.9100269079208374\n",
      "precision2: 0.9822483062744141\n",
      "precision1: 0.9092615842819214\n",
      "precision2: 0.9822787046432495\n",
      "precision1: 0.9085358381271362\n",
      "precision2: 0.9823161363601685\n",
      "precision1: 0.9077971577644348\n",
      "precision2: 0.9823441505432129\n",
      "precision1: 0.9070661664009094\n",
      "precision2: 0.9823732972145081\n",
      "precision1: 0.9062944054603577\n",
      "precision2: 0.9824029207229614\n",
      "precision1: 0.9055361151695251\n",
      "precision2: 0.982428252696991\n",
      "precision1: 0.9048149585723877\n",
      "precision2: 0.9824551939964294\n",
      "precision1: 0.9040805101394653\n",
      "precision2: 0.9824864268302917\n",
      "precision1: 0.903353750705719\n",
      "precision2: 0.9825131893157959\n",
      "precision1: 0.902603268623352\n",
      "precision2: 0.9825431108474731\n",
      "precision1: 0.9018639326095581\n",
      "precision2: 0.9825769066810608\n",
      "precision1: 0.9011293649673462\n",
      "precision2: 0.9826140999794006\n",
      "precision1: 0.9003956913948059\n",
      "precision2: 0.9826467633247375\n",
      "precision1: 0.8996426463127136\n",
      "precision2: 0.9826739430427551\n",
      "precision1: 0.8988826274871826\n",
      "precision2: 0.9826995134353638\n",
      "precision1: 0.898134708404541\n",
      "precision2: 0.9827193021774292\n",
      "precision1: 0.8973873853683472\n",
      "precision2: 0.9827532172203064\n",
      "precision1: 0.8966622948646545\n",
      "precision2: 0.9827895760536194\n",
      "precision1: 0.8959611058235168\n",
      "precision2: 0.9828251004219055\n",
      "precision1: 0.8952452540397644\n",
      "precision2: 0.982874870300293\n",
      "precision1: 0.8945503830909729\n",
      "precision2: 0.9829114079475403\n",
      "precision1: 0.8938324451446533\n",
      "precision2: 0.9829447269439697\n",
      "precision1: 0.8930976986885071\n",
      "precision2: 0.9829804301261902\n",
      "precision1: 0.892349898815155\n",
      "precision2: 0.9830142259597778\n",
      "precision1: 0.8915953040122986\n",
      "precision2: 0.9830541610717773\n",
      "precision1: 0.890863299369812\n",
      "precision2: 0.9830859303474426\n",
      "precision1: 0.8901352882385254\n",
      "precision2: 0.9831257462501526\n",
      "precision1: 0.8894066214561462\n",
      "precision2: 0.9831603765487671\n",
      "precision1: 0.8886907696723938\n",
      "precision2: 0.9831937551498413\n",
      "precision1: 0.8879879713058472\n",
      "precision2: 0.9832161068916321\n",
      "precision1: 0.8873069882392883\n",
      "precision2: 0.9832453727722168\n",
      "precision1: 0.8866026401519775\n",
      "precision2: 0.9832828640937805\n",
      "precision1: 0.8858652710914612\n",
      "precision2: 0.9833215475082397\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision1: 0.8851442933082581\n",
      "precision2: 0.9833611249923706\n",
      "precision1: 0.8844395279884338\n",
      "precision2: 0.9833981990814209\n",
      "precision1: 0.8837080597877502\n",
      "precision2: 0.9834343791007996\n",
      "precision1: 0.8829765915870667\n",
      "precision2: 0.9834690690040588\n",
      "precision1: 0.8822411298751831\n",
      "precision2: 0.983507513999939\n",
      "precision1: 0.8815180659294128\n",
      "precision2: 0.9835580587387085\n",
      "precision1: 0.8807879686355591\n",
      "precision2: 0.9836093187332153\n",
      "precision1: 0.880073606967926\n",
      "precision2: 0.9836482405662537\n",
      "precision1: 0.8793362379074097\n",
      "precision2: 0.9836751222610474\n",
      "precision1: 0.8786056637763977\n",
      "precision2: 0.9837127923965454\n",
      "precision1: 0.8778829574584961\n",
      "precision2: 0.9837526679039001\n",
      "precision1: 0.8771747946739197\n",
      "precision2: 0.9838045835494995\n",
      "precision1: 0.8764663934707642\n",
      "precision2: 0.9838457107543945\n",
      "precision1: 0.8757494688034058\n",
      "precision2: 0.9838976860046387\n",
      "precision1: 0.8750190734863281\n",
      "precision2: 0.9839337468147278\n",
      "precision1: 0.8742901682853699\n",
      "precision2: 0.983969509601593\n",
      "precision1: 0.8735700845718384\n",
      "precision2: 0.9840161204338074\n",
      "precision1: 0.8728635311126709\n",
      "precision2: 0.9840582013130188\n",
      "precision1: 0.8721540570259094\n",
      "precision2: 0.9840996861457825\n",
      "precision1: 0.8714259266853333\n",
      "precision2: 0.9841359853744507\n",
      "precision1: 0.8707157969474792\n",
      "precision2: 0.9841839075088501\n",
      "precision1: 0.8700118064880371\n",
      "precision2: 0.984226405620575\n",
      "precision1: 0.8692952394485474\n",
      "precision2: 0.9842675924301147\n",
      "precision1: 0.8685765266418457\n",
      "precision2: 0.9843037128448486\n",
      "precision1: 0.8678508996963501\n",
      "precision2: 0.9843413829803467\n",
      "precision1: 0.8671401143074036\n",
      "precision2: 0.9843709468841553\n",
      "precision1: 0.866392970085144\n",
      "precision2: 0.9843864440917969\n",
      "precision1: 0.8656680583953857\n",
      "precision2: 0.9844152927398682\n",
      "precision1: 0.8649446368217468\n",
      "precision2: 0.98445725440979\n",
      "precision1: 0.8642295598983765\n",
      "precision2: 0.9845013618469238\n",
      "precision1: 0.8635178208351135\n",
      "precision2: 0.9845275282859802\n",
      "precision1: 0.862812340259552\n",
      "precision2: 0.9845500588417053\n",
      "precision1: 0.8621074557304382\n",
      "precision2: 0.9845833778381348\n",
      "precision1: 0.8613964319229126\n",
      "precision2: 0.9846258759498596\n",
      "precision1: 0.8606840372085571\n",
      "precision2: 0.9846752285957336\n",
      "precision1: 0.8600046634674072\n",
      "precision2: 0.9847241044044495\n",
      "precision1: 0.8593094348907471\n",
      "precision2: 0.9847710728645325\n",
      "precision1: 0.8586080074310303\n",
      "precision2: 0.9848228693008423\n",
      "precision1: 0.8578957915306091\n",
      "precision2: 0.9848813414573669\n",
      "precision1: 0.8571831583976746\n",
      "precision2: 0.9849279522895813\n",
      "precision1: 0.8564751744270325\n",
      "precision2: 0.9849646091461182\n",
      "precision1: 0.8557884693145752\n",
      "precision2: 0.9850162863731384\n",
      "precision1: 0.8550917506217957\n",
      "precision2: 0.9850579500198364\n",
      "precision1: 0.8544024229049683\n",
      "precision2: 0.9850949048995972\n",
      "precision1: 0.8536902666091919\n",
      "precision2: 0.9851415157318115\n",
      "precision1: 0.8529700636863708\n",
      "precision2: 0.9851871132850647\n",
      "precision1: 0.8522546887397766\n",
      "precision2: 0.9852478504180908\n",
      "precision1: 0.8515465259552002\n",
      "precision2: 0.9853034019470215\n",
      "precision1: 0.8508080244064331\n",
      "precision2: 0.9853622913360596\n",
      "precision1: 0.8501089811325073\n",
      "precision2: 0.9854059219360352\n",
      "precision1: 0.8494102358818054\n",
      "precision2: 0.9854599237442017\n",
      "precision1: 0.8487234115600586\n",
      "precision2: 0.9854986667633057\n",
      "precision1: 0.8480178117752075\n",
      "precision2: 0.9855314493179321\n",
      "precision1: 0.8473061323165894\n",
      "precision2: 0.9855688810348511\n",
      "precision1: 0.846612274646759\n",
      "precision2: 0.9856166243553162\n",
      "precision1: 0.8459091782569885\n",
      "precision2: 0.9856709241867065\n",
      "precision1: 0.8452191352844238\n",
      "precision2: 0.9857253432273865\n",
      "precision1: 0.8445091843605042\n",
      "precision2: 0.9857913255691528\n",
      "precision1: 0.8438087105751038\n",
      "precision2: 0.9858437776565552\n",
      "precision1: 0.8431207537651062\n",
      "precision2: 0.9858877658843994\n",
      "precision1: 0.8424347043037415\n",
      "precision2: 0.9859211444854736\n",
      "precision1: 0.8417243361473083\n",
      "precision2: 0.9859752655029297\n",
      "precision1: 0.8410420417785645\n",
      "precision2: 0.9860266447067261\n",
      "precision1: 0.8403839468955994\n",
      "precision2: 0.9860728979110718\n",
      "precision1: 0.8396969437599182\n",
      "precision2: 0.9861191511154175\n",
      "precision1: 0.8390195965766907\n",
      "precision2: 0.9861574769020081\n",
      "precision1: 0.8383464217185974\n",
      "precision2: 0.986188530921936\n",
      "precision1: 0.8376609683036804\n",
      "precision2: 0.9862171411514282\n",
      "precision1: 0.8369905948638916\n",
      "precision2: 0.9862493872642517\n",
      "precision1: 0.8363025188446045\n",
      "precision2: 0.9863041639328003\n",
      "precision1: 0.8356246948242188\n",
      "precision2: 0.9863565564155579\n",
      "precision1: 0.834932804107666\n",
      "precision2: 0.9863951206207275\n",
      "precision1: 0.8342409729957581\n",
      "precision2: 0.9864333271980286\n",
      "precision1: 0.833554208278656\n",
      "precision2: 0.9864756464958191\n",
      "precision1: 0.8328708410263062\n",
      "precision2: 0.9865338206291199\n",
      "precision1: 0.8321872353553772\n",
      "precision2: 0.9866098761558533\n",
      "precision1: 0.8314926624298096\n",
      "precision2: 0.9866774678230286\n",
      "precision1: 0.8308051228523254\n",
      "precision2: 0.9867336750030518\n",
      "precision1: 0.8301373720169067\n",
      "precision2: 0.9867803454399109\n",
      "precision1: 0.8294586539268494\n",
      "precision2: 0.9868348240852356\n",
      "precision1: 0.8287796378135681\n",
      "precision2: 0.9869041442871094\n",
      "precision1: 0.8280885815620422\n",
      "precision2: 0.9869743585586548\n",
      "precision1: 0.82740318775177\n",
      "precision2: 0.9870293736457825\n",
      "precision1: 0.8267113566398621\n",
      "precision2: 0.9870786666870117\n",
      "precision1: 0.8260396122932434\n",
      "precision2: 0.9871310591697693\n",
      "precision1: 0.8253725171089172\n",
      "precision2: 0.9871764779090881\n",
      "precision1: 0.8247242569923401\n",
      "precision2: 0.9872075915336609\n",
      "precision1: 0.8240622878074646\n",
      "precision2: 0.9872556328773499\n",
      "precision1: 0.823422908782959\n",
      "precision2: 0.9873141050338745\n",
      "precision1: 0.8227461576461792\n",
      "precision2: 0.9873709082603455\n",
      "precision1: 0.8220617771148682\n",
      "precision2: 0.9874415993690491\n",
      "precision1: 0.821367621421814\n",
      "precision2: 0.9875131845474243\n",
      "precision1: 0.8206762075424194\n",
      "precision2: 0.9875744581222534\n",
      "precision1: 0.8200163841247559\n",
      "precision2: 0.9876497387886047\n",
      "precision1: 0.8193483948707581\n",
      "precision2: 0.9876964688301086\n",
      "precision1: 0.8186888098716736\n",
      "precision2: 0.9877399206161499\n",
      "precision1: 0.8180139064788818\n",
      "precision2: 0.9877954721450806\n",
      "precision1: 0.8173485994338989\n",
      "precision2: 0.9878481030464172\n",
      "precision1: 0.8166957497596741\n",
      "precision2: 0.987888514995575\n",
      "precision1: 0.8160292506217957\n",
      "precision2: 0.9879400134086609\n",
      "precision1: 0.8153525590896606\n",
      "precision2: 0.9879955649375916\n",
      "precision1: 0.8146981000900269\n",
      "precision2: 0.988044261932373\n",
      "precision1: 0.8140302896499634\n",
      "precision2: 0.9881008267402649\n",
      "precision1: 0.8133556842803955\n",
      "precision2: 0.9881357550621033\n",
      "precision1: 0.812694251537323\n",
      "precision2: 0.9881911873817444\n",
      "precision1: 0.8120707273483276\n",
      "precision2: 0.9882593154907227\n",
      "precision1: 0.8114485144615173\n",
      "precision2: 0.9883430004119873\n",
      "precision1: 0.8107866644859314\n",
      "precision2: 0.9884307384490967\n",
      "precision1: 0.8101210594177246\n",
      "precision2: 0.9885116219520569\n",
      "precision1: 0.8094562292098999\n",
      "precision2: 0.9885584712028503\n",
      "precision1: 0.8087974786758423\n",
      "precision2: 0.988619327545166\n",
      "precision1: 0.8081237077713013\n",
      "precision2: 0.9886826872825623\n",
      "precision1: 0.8074577450752258\n",
      "precision2: 0.9887530207633972\n",
      "precision1: 0.8067905902862549\n",
      "precision2: 0.988799512386322\n",
      "precision1: 0.8061326146125793\n",
      "precision2: 0.9888476729393005\n",
      "precision1: 0.8054817914962769\n",
      "precision2: 0.9889044165611267\n",
      "precision1: 0.8048258423805237\n",
      "precision2: 0.9889605641365051\n",
      "precision1: 0.8041868209838867\n",
      "precision2: 0.9890137910842896\n",
      "precision1: 0.8035320043563843\n",
      "precision2: 0.9890774488449097\n",
      "precision1: 0.8028761744499207\n",
      "precision2: 0.9891284108161926\n",
      "precision1: 0.8022133708000183\n",
      "precision2: 0.9891878366470337\n",
      "precision1: 0.801544725894928\n",
      "precision2: 0.9892353415489197\n",
      "precision1: 0.8008663654327393\n",
      "precision2: 0.9892818331718445\n",
      "precision1: 0.8002076745033264\n",
      "precision2: 0.9893310070037842\n",
      "precision1: 0.7995622754096985\n",
      "precision2: 0.9893870949745178\n",
      "precision1: 0.7989178895950317\n",
      "precision2: 0.9894484281539917\n",
      "precision1: 0.7982802987098694\n",
      "precision2: 0.9894877672195435\n",
      "precision1: 0.7976529002189636\n",
      "precision2: 0.989531934261322\n",
      "precision1: 0.7970250844955444\n",
      "precision2: 0.9895796179771423\n",
      "precision1: 0.796377420425415\n",
      "precision2: 0.9896410703659058\n",
      "precision1: 0.7957020401954651\n",
      "precision2: 0.989708662033081\n",
      "precision1: 0.7950460910797119\n",
      "precision2: 0.9897804856300354\n",
      "precision1: 0.7943819761276245\n",
      "precision2: 0.9898532629013062\n",
      "precision1: 0.7937262654304504\n",
      "precision2: 0.9899347424507141\n",
      "precision1: 0.7930895090103149\n",
      "precision2: 0.9899911880493164\n",
      "precision1: 0.7924519181251526\n",
      "precision2: 0.9900467395782471\n",
      "precision1: 0.7918233275413513\n",
      "precision2: 0.9901033639907837\n",
      "precision1: 0.7912062406539917\n",
      "precision2: 0.9901406764984131\n",
      "precision1: 0.7905483841896057\n",
      "precision2: 0.9901924133300781\n",
      "precision1: 0.7899124026298523\n",
      "precision2: 0.9902543425559998\n",
      "precision1: 0.7892917394638062\n",
      "precision2: 0.990314781665802\n",
      "precision1: 0.7886663675308228\n",
      "precision2: 0.9904001951217651\n",
      "precision1: 0.7880476713180542\n",
      "precision2: 0.9904789924621582\n",
      "precision1: 0.7874067425727844\n",
      "precision2: 0.9905581474304199\n",
      "precision1: 0.7867693901062012\n",
      "precision2: 0.990613579750061\n",
      "precision1: 0.7861148715019226\n",
      "precision2: 0.9906734824180603\n",
      "precision1: 0.7854811549186707\n",
      "precision2: 0.9907073974609375\n",
      "precision1: 0.784862220287323\n",
      "precision2: 0.990750253200531\n",
      "precision1: 0.7842434048652649\n",
      "precision2: 0.9907991290092468\n",
      "precision1: 0.7836077213287354\n",
      "precision2: 0.9908631443977356\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision1: 0.7829639911651611\n",
      "precision2: 0.9909343123435974\n",
      "precision1: 0.7823233008384705\n",
      "precision2: 0.9910069704055786\n",
      "precision1: 0.781687319278717\n",
      "precision2: 0.9910894632339478\n",
      "precision1: 0.7810598611831665\n",
      "precision2: 0.9911645650863647\n",
      "precision1: 0.7804303765296936\n",
      "precision2: 0.9912243485450745\n",
      "precision1: 0.7798037528991699\n",
      "precision2: 0.9912834167480469\n",
      "precision1: 0.7791709303855896\n",
      "precision2: 0.9913327693939209\n",
      "precision1: 0.7785284519195557\n",
      "precision2: 0.9913761019706726\n",
      "precision1: 0.7778874039649963\n",
      "precision2: 0.9914178252220154\n",
      "precision1: 0.777264416217804\n",
      "precision2: 0.9914847612380981\n",
      "precision1: 0.7766100168228149\n",
      "precision2: 0.9915511608123779\n",
      "precision1: 0.7759723663330078\n",
      "precision2: 0.991595447063446\n",
      "precision1: 0.7753431797027588\n",
      "precision2: 0.9916426539421082\n",
      "precision1: 0.7747346758842468\n",
      "precision2: 0.991701066493988\n",
      "precision1: 0.774118185043335\n",
      "precision2: 0.9917724132537842\n",
      "precision1: 0.7735124230384827\n",
      "precision2: 0.9918582439422607\n",
      "precision1: 0.7729262709617615\n",
      "precision2: 0.9919519424438477\n",
      "precision1: 0.7723216414451599\n",
      "precision2: 0.9920273423194885\n",
      "precision1: 0.7717213034629822\n",
      "precision2: 0.9921038150787354\n",
      "precision1: 0.7710989117622375\n",
      "precision2: 0.9921672344207764\n",
      "precision1: 0.7704752683639526\n",
      "precision2: 0.9922360777854919\n",
      "precision1: 0.7698460221290588\n",
      "precision2: 0.9922991991043091\n",
      "precision1: 0.7692311406135559\n",
      "precision2: 0.9923587441444397\n",
      "precision1: 0.7686218619346619\n",
      "precision2: 0.992412269115448\n",
      "precision1: 0.7680031657218933\n",
      "precision2: 0.9924800992012024\n",
      "precision1: 0.7673848271369934\n",
      "precision2: 0.9925590753555298\n",
      "precision1: 0.766772985458374\n",
      "precision2: 0.992634117603302\n",
      "precision1: 0.7661430239677429\n",
      "precision2: 0.9927096366882324\n",
      "precision1: 0.7655332088470459\n",
      "precision2: 0.9927663207054138\n",
      "precision1: 0.7649195194244385\n",
      "precision2: 0.9928327798843384\n",
      "precision1: 0.7643227577209473\n",
      "precision2: 0.9929021000862122\n",
      "precision1: 0.7637000679969788\n",
      "precision2: 0.9929717183113098\n",
      "precision1: 0.7630952596664429\n",
      "precision2: 0.9930521845817566\n",
      "precision1: 0.7624718546867371\n",
      "precision2: 0.9931123852729797\n",
      "precision1: 0.7618618011474609\n",
      "precision2: 0.9931783080101013\n",
      "precision1: 0.7612583637237549\n",
      "precision2: 0.9932354092597961\n",
      "precision1: 0.7606139779090881\n",
      "precision2: 0.9933075308799744\n",
      "precision1: 0.7599948048591614\n",
      "precision2: 0.9933909773826599\n",
      "precision1: 0.7593857645988464\n",
      "precision2: 0.9934651851654053\n",
      "precision1: 0.7587836980819702\n",
      "precision2: 0.9935300946235657\n",
      "precision1: 0.7581563591957092\n",
      "precision2: 0.9936114549636841\n",
      "precision1: 0.7575289607048035\n",
      "precision2: 0.9936996102333069\n",
      "precision1: 0.7569130659103394\n",
      "precision2: 0.993781566619873\n",
      "precision1: 0.7562994956970215\n",
      "precision2: 0.9938598275184631\n",
      "precision1: 0.7557075023651123\n",
      "precision2: 0.9939191937446594\n",
      "precision1: 0.75513756275177\n",
      "precision2: 0.9939784407615662\n",
      "precision1: 0.7545741200447083\n",
      "precision2: 0.994023859500885\n",
      "precision1: 0.7540037035942078\n",
      "precision2: 0.9940778613090515\n",
      "precision1: 0.753410816192627\n",
      "precision2: 0.9941186308860779\n",
      "precision1: 0.7528101801872253\n",
      "precision2: 0.9941842555999756\n",
      "precision1: 0.752210795879364\n",
      "precision2: 0.9942690134048462\n",
      "precision1: 0.7516045570373535\n",
      "precision2: 0.9943380355834961\n",
      "precision1: 0.7510003447532654\n",
      "precision2: 0.994419276714325\n",
      "precision1: 0.7504020929336548\n",
      "precision2: 0.9944907426834106\n",
      "precision1: 0.7498065829277039\n",
      "precision2: 0.9945536255836487\n",
      "precision1: 0.7491969466209412\n",
      "precision2: 0.9945858716964722\n",
      "precision1: 0.7485749125480652\n",
      "precision2: 0.9946366548538208\n",
      "precision1: 0.7479669451713562\n",
      "precision2: 0.9946922659873962\n",
      "precision1: 0.7473785877227783\n",
      "precision2: 0.9947459101676941\n",
      "precision1: 0.746791422367096\n",
      "precision2: 0.994818925857544\n",
      "precision1: 0.7461920380592346\n",
      "precision2: 0.994917094707489\n",
      "precision1: 0.7455914616584778\n",
      "precision2: 0.9950087070465088\n",
      "precision1: 0.7449963688850403\n",
      "precision2: 0.9950960874557495\n",
      "precision1: 0.7443976402282715\n",
      "precision2: 0.9951568245887756\n",
      "precision1: 0.7438157200813293\n",
      "precision2: 0.9952235221862793\n",
      "precision1: 0.7432553172111511\n",
      "precision2: 0.9952722191810608\n",
      "precision1: 0.7426873445510864\n",
      "precision2: 0.9953321218490601\n",
      "precision1: 0.7420861721038818\n",
      "precision2: 0.9953879117965698\n",
      "precision1: 0.7414945960044861\n",
      "precision2: 0.9954500198364258\n",
      "precision1: 0.7409120798110962\n",
      "precision2: 0.995522141456604\n",
      "precision1: 0.7403226494789124\n",
      "precision2: 0.9956001043319702\n",
      "precision1: 0.7397302389144897\n",
      "precision2: 0.9956430792808533\n",
      "precision1: 0.7391406297683716\n",
      "precision2: 0.995681881904602\n",
      "precision1: 0.7385610342025757\n",
      "precision2: 0.9957411885261536\n",
      "precision1: 0.7379821538925171\n",
      "precision2: 0.995821475982666\n",
      "precision1: 0.7374134063720703\n",
      "precision2: 0.9959079623222351\n",
      "precision1: 0.7368464469909668\n",
      "precision2: 0.9959885478019714\n",
      "precision1: 0.7362691164016724\n",
      "precision2: 0.9960733652114868\n",
      "precision1: 0.7356934547424316\n",
      "precision2: 0.9961345791816711\n",
      "precision1: 0.735112726688385\n",
      "precision2: 0.9962114095687866\n",
      "precision1: 0.7345376014709473\n",
      "precision2: 0.9962751865386963\n",
      "precision1: 0.7339687943458557\n",
      "precision2: 0.9963474273681641\n",
      "precision1: 0.7334122657775879\n",
      "precision2: 0.9964208602905273\n",
      "precision1: 0.732861340045929\n",
      "precision2: 0.9965037107467651\n",
      "precision1: 0.7322772145271301\n",
      "precision2: 0.9965810179710388\n",
      "precision1: 0.7316939234733582\n",
      "precision2: 0.9966686367988586\n",
      "precision1: 0.7311160564422607\n",
      "precision2: 0.9967393279075623\n",
      "precision1: 0.7305297255516052\n",
      "precision2: 0.9968283772468567\n",
      "precision1: 0.7299429178237915\n",
      "precision2: 0.9969101548194885\n",
      "precision1: 0.7293655276298523\n",
      "precision2: 0.9969823956489563\n",
      "precision1: 0.7287907600402832\n",
      "precision2: 0.9970800876617432\n",
      "precision1: 0.7282077074050903\n",
      "precision2: 0.9971411228179932\n",
      "precision1: 0.7276290059089661\n",
      "precision2: 0.9972116351127625\n",
      "precision1: 0.7270533442497253\n",
      "precision2: 0.9973047971725464\n",
      "precision1: 0.7264799475669861\n",
      "precision2: 0.9973760843276978\n",
      "precision1: 0.7259202003479004\n",
      "precision2: 0.9974333047866821\n",
      "precision1: 0.7253546714782715\n",
      "precision2: 0.9974733591079712\n",
      "precision1: 0.7247878313064575\n",
      "precision2: 0.997524619102478\n",
      "precision1: 0.7242200970649719\n",
      "precision2: 0.99759441614151\n",
      "precision1: 0.7236477732658386\n",
      "precision2: 0.9976729154586792\n",
      "precision1: 0.7230715155601501\n",
      "precision2: 0.997766375541687\n",
      "precision1: 0.72249835729599\n",
      "precision2: 0.9978421926498413\n",
      "precision1: 0.7219154834747314\n",
      "precision2: 0.9979342818260193\n",
      "precision1: 0.7213470935821533\n",
      "precision2: 0.998037576675415\n",
      "precision1: 0.7207813858985901\n",
      "precision2: 0.9981074929237366\n",
      "precision1: 0.7202043533325195\n",
      "precision2: 0.9981831312179565\n",
      "precision1: 0.7196171283721924\n",
      "precision2: 0.9982683062553406\n",
      "precision1: 0.7190490961074829\n",
      "precision2: 0.9983572363853455\n",
      "precision1: 0.7184848189353943\n",
      "precision2: 0.9984216690063477\n",
      "precision1: 0.7179263830184937\n",
      "precision2: 0.9984886050224304\n",
      "precision1: 0.7173686027526855\n",
      "precision2: 0.9985557198524475\n",
      "precision1: 0.7168251276016235\n",
      "precision2: 0.9986340403556824\n",
      "precision1: 0.7162637710571289\n",
      "precision2: 0.9987194538116455\n",
      "precision1: 0.7157031893730164\n",
      "precision2: 0.998802661895752\n",
      "precision1: 0.7151450514793396\n",
      "precision2: 0.9988669157028198\n",
      "precision1: 0.7145937085151672\n",
      "precision2: 0.9989460110664368\n",
      "precision1: 0.7140215039253235\n",
      "precision2: 0.9990254640579224\n",
      "precision1: 0.7134491205215454\n",
      "precision2: 0.9991074800491333\n",
      "precision1: 0.7128928899765015\n",
      "precision2: 0.9991713762283325\n",
      "precision1: 0.7123397588729858\n",
      "precision2: 0.9992398023605347\n",
      "precision1: 0.7117906808853149\n",
      "precision2: 0.999316930770874\n",
      "precision1: 0.7112570405006409\n",
      "precision2: 0.9993886947631836\n",
      "precision1: 0.7107245922088623\n",
      "precision2: 0.999451756477356\n",
      "precision1: 0.7101652026176453\n",
      "precision2: 0.9995286464691162\n",
      "precision1: 0.7096102833747864\n",
      "precision2: 0.9995995759963989\n",
      "precision1: 0.709051787853241\n",
      "precision2: 0.9996588826179504\n",
      "precision1: 0.7084969282150269\n",
      "precision2: 0.9997215867042542\n",
      "precision1: 0.7079280614852905\n",
      "precision2: 0.9998084306716919\n",
      "precision1: 0.7073515057563782\n",
      "precision2: 0.999894380569458\n",
      "precision1: 0.7068049907684326\n",
      "precision2: 0.9999645352363586\n",
      "precision1: 0.7062593102455139\n",
      "precision2: 1.0000463724136353\n",
      "precision1: 0.70569908618927\n",
      "precision2: 1.0001205205917358\n",
      "precision1: 0.7051443457603455\n",
      "precision2: 1.000187635421753\n",
      "precision1: 0.7046083211898804\n",
      "precision2: 1.0002413988113403\n",
      "precision1: 0.7040644288063049\n",
      "precision2: 1.0003165006637573\n",
      "precision1: 0.703521192073822\n",
      "precision2: 1.0003920793533325\n",
      "precision1: 0.7029753923416138\n",
      "precision2: 1.0004863739013672\n",
      "precision1: 0.7024205327033997\n",
      "precision2: 1.0005773305892944\n",
      "precision1: 0.7018764615058899\n",
      "precision2: 1.0006539821624756\n",
      "precision1: 0.7013352513313293\n",
      "precision2: 1.000728726387024\n",
      "precision1: 0.7007828950881958\n",
      "precision2: 1.0008023977279663\n",
      "precision1: 0.7002384066581726\n",
      "precision2: 1.0008745193481445\n",
      "precision1: 0.6996833086013794\n",
      "precision2: 1.0009565353393555\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "precision1: 0.6991401314735413\n",
      "precision2: 1.0010231733322144\n",
      "precision1: 0.6986037492752075\n",
      "precision2: 1.0010995864868164\n",
      "precision1: 0.6980698108673096\n",
      "precision2: 1.001198649406433\n",
      "precision1: 0.6975240111351013\n",
      "precision2: 1.0012983083724976\n",
      "precision1: 0.6969711184501648\n",
      "precision2: 1.0013936758041382\n",
      "precision1: 0.696437418460846\n",
      "precision2: 1.0014736652374268\n",
      "precision1: 0.6959038972854614\n",
      "precision2: 1.0015472173690796\n",
      "precision1: 0.6953637599945068\n",
      "precision2: 1.001603126525879\n",
      "precision1: 0.6948279142379761\n",
      "precision2: 1.0016616582870483\n",
      "precision1: 0.6942878365516663\n",
      "precision2: 1.0017337799072266\n",
      "precision1: 0.6937370300292969\n",
      "precision2: 1.0018123388290405\n",
      "precision1: 0.6932076811790466\n",
      "precision2: 1.0018917322158813\n",
      "precision1: 0.6926832795143127\n",
      "precision2: 1.0019667148590088\n",
      "precision1: 0.6921514868736267\n",
      "precision2: 1.0020581483840942\n",
      "precision1: 0.6916290521621704\n",
      "precision2: 1.0021735429763794\n",
      "precision1: 0.6910802721977234\n",
      "precision2: 1.002263069152832\n",
      "precision1: 0.6905521750450134\n",
      "precision2: 1.0023373365402222\n",
      "precision1: 0.6900414824485779\n",
      "precision2: 1.0024088621139526\n",
      "precision1: 0.6895241141319275\n",
      "precision2: 1.0024681091308594\n",
      "precision1: 0.6889995336532593\n",
      "precision2: 1.0025452375411987\n",
      "precision1: 0.6884755492210388\n",
      "precision2: 1.0026084184646606\n",
      "precision1: 0.6879448294639587\n",
      "precision2: 1.0026912689208984\n",
      "precision1: 0.687416136264801\n",
      "precision2: 1.0027925968170166\n",
      "precision1: 0.6868909597396851\n",
      "precision2: 1.0028724670410156\n",
      "precision1: 0.6863591074943542\n",
      "precision2: 1.0029659271240234\n",
      "precision1: 0.685839056968689\n",
      "precision2: 1.0030567646026611\n",
      "precision1: 0.6853099465370178\n",
      "precision2: 1.003137469291687\n",
      "precision1: 0.6847924590110779\n",
      "precision2: 1.0031964778900146\n",
      "precision1: 0.684272825717926\n",
      "precision2: 1.0032427310943604\n",
      "precision1: 0.6837433576583862\n",
      "precision2: 1.003306269645691\n",
      "precision1: 0.6832025051116943\n",
      "precision2: 1.0033906698226929\n",
      "precision1: 0.6826765537261963\n",
      "precision2: 1.0034793615341187\n",
      "precision1: 0.6821463108062744\n",
      "precision2: 1.0035682916641235\n",
      "precision1: 0.6816182732582092\n",
      "precision2: 1.0036466121673584\n",
      "precision1: 0.6810943484306335\n",
      "precision2: 1.0037420988082886\n",
      "precision1: 0.6805766820907593\n",
      "precision2: 1.0038299560546875\n",
      "precision1: 0.6800543069839478\n",
      "precision2: 1.0039176940917969\n",
      "precision1: 0.6795480251312256\n",
      "precision2: 1.0039950609207153\n",
      "precision1: 0.6790581345558167\n",
      "precision2: 1.0040863752365112\n",
      "precision1: 0.67857426404953\n",
      "precision2: 1.004186749458313\n",
      "precision1: 0.6780712604522705\n",
      "precision2: 1.0042860507965088\n",
      "precision1: 0.6775417327880859\n",
      "precision2: 1.004370927810669\n",
      "precision1: 0.6770089268684387\n",
      "precision2: 1.004447102546692\n",
      "precision1: 0.6764841079711914\n",
      "precision2: 1.0045504570007324\n",
      "precision1: 0.6759567856788635\n",
      "precision2: 1.004659652709961\n",
      "precision1: 0.6754302978515625\n",
      "precision2: 1.0047708749771118\n",
      "precision1: 0.6749088168144226\n",
      "precision2: 1.0048500299453735\n",
      "precision1: 0.6743992567062378\n",
      "precision2: 1.004914402961731\n",
      "precision1: 0.6739002466201782\n",
      "precision2: 1.005003571510315\n",
      "precision1: 0.6733972430229187\n",
      "precision2: 1.0050792694091797\n",
      "precision1: 0.6728969216346741\n",
      "precision2: 1.005140781402588\n",
      "precision1: 0.6723737120628357\n",
      "precision2: 1.0052261352539062\n",
      "precision1: 0.6718361973762512\n",
      "precision2: 1.005302906036377\n",
      "precision1: 0.6712925434112549\n",
      "precision2: 1.005362868309021\n",
      "precision1: 0.6707779765129089\n",
      "precision2: 1.0054484605789185\n",
      "precision1: 0.6702606678009033\n",
      "precision2: 1.0055370330810547\n",
      "precision1: 0.6697477102279663\n",
      "precision2: 1.0056250095367432\n",
      "precision1: 0.6692342758178711\n",
      "precision2: 1.0057158470153809\n",
      "precision1: 0.668737530708313\n",
      "precision2: 1.005791425704956\n",
      "precision1: 0.668246328830719\n",
      "precision2: 1.0058695077896118\n",
      "precision1: 0.6677465438842773\n",
      "precision2: 1.0059574842453003\n"
     ]
    }
   ],
   "source": [
    "train_data = TrainData(feature_num, X, Y1, Y2)\n",
    "train_data_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)\n",
    "\n",
    "model = MTLModel(hidden_dim, 1)\n",
    "\n",
    "mtl = MultiTaskLossWrapper(2, model)\n",
    "mtl\n",
    "\n",
    "# https://github.com/keras-team/keras/blob/master/keras/optimizers.py\n",
    "# k.epsilon() = keras.backend.epsilon()\n",
    "optimizer = torch.optim.Adam(mtl.parameters(), lr=0.001, eps=1e-07)\n",
    "\n",
    "loss_list = []\n",
    "for t in range(nb_epoch):\n",
    "    cumulative_loss = 0\n",
    "    \n",
    "    for X, Y1, Y2 in train_data_loader:\n",
    "\n",
    "        loss, log_vars = mtl(X, [Y1, Y2])\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        cumulative_loss += loss.item()\n",
    "\n",
    "    loss_list.append(cumulative_loss/batch_size)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
