{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Metal CDR Relation Extraction "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import metal\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch:  1.0.0\n",
      "MeTaL:    0.3.3\n",
      "Python:   3.6.7 (default, Dec  8 2018, 17:35:14) \n",
      "[GCC 5.4.0 20160609]\n",
      "Python:   sys.version_info(major=3, minor=6, micro=7, releaselevel='final', serial=0)\n"
     ]
    }
   ],
   "source": [
    "print('PyTorch: ', torch.__version__)\n",
    "print('MeTaL:   ', metal.__version__)\n",
    "print('Python:  ', sys.version)\n",
    "print('Python:  ', sys.version_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initalize CDR Dataset\n",
    "To uncompress the SQLite db: ```bzip2 -d cdr.db.bz2```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/backends/cdr.db\n",
      "Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/backends/cdr.db\n",
      "Connected to sqlite:////dfs/scratch0/vschen/metal/metal/contrib/backends/cdr.db\n",
      "[TRAIN] 8272\n",
      "[DEV]   888\n",
      "[TEST]  4620\n"
     ]
    }
   ],
   "source": [
    "from metal.contrib.backends.wrapper import SnorkelDataset\n",
    "import os\n",
    "\n",
    "db_conn_str   = os.path.join(os.getcwd(),\"cdr.db\")\n",
    "candidate_def = ['ChemicalDisease', ['chemical', 'disease']]\n",
    "\n",
    "train, dev, test = SnorkelDataset.splits(db_conn_str, \n",
    "                                         candidate_def, \n",
    "                                         max_seq_len=125)\n",
    "\n",
    "print(f'[TRAIN] {len(train)}')\n",
    "print(f'[DEV]   {len(dev)}')\n",
    "print(f'[TEST]  {len(test)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train End Model (Random Initalized Embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using randomly initialized embeddings.\n",
      "Embeddings shape = (9946, 50)\n",
      "The embeddings are NOT FROZEN\n",
      "Using lstm_reduction = 'attention'\n",
      "\n",
      "Network architecture:\n",
      "Sequential(\n",
      "  (0): Sequential(\n",
      "    (0): LSTMModule(\n",
      "      (embeddings): Embedding(9946, 50)\n",
      "      (lstm): LSTM(50, 100, batch_first=True, bidirectional=True)\n",
      "    )\n",
      "    (1): ReLU()\n",
      "  )\n",
      "  (1): Linear(in_features=200, out_features=2, bias=True)\n",
      ")\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from metal.end_model import EndModel\n",
    "from metal.contrib.modules import LSTMModule\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "lstm = LSTMModule(embed_size=50, \n",
    "                  hidden_size=100, \n",
    "                  vocab_size=train.word_dict.len(),\n",
    "                  lstm_reduction='attention', \n",
    "                  dropout=0,\n",
    "                  num_layers=1, \n",
    "                  freeze=False)\n",
    "\n",
    "end_model = EndModel([200, 2], input_module=lstm, seed=123, device=device)\n",
    "\n",
    "end_model.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01\n",
    "end_model.config['train_config']['validation_metric'] = 'f1'\n",
    "end_model.config['train_config']['batch_size'] = 32\n",
    "end_model.config['train_config']['n_epochs'] = 5\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using GPU...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "adf2da0a02a84056a7800fee6cd30d99",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Saving model at iteration 0 with best score 0.580\n",
      "[E:0]\tTrain Loss: 0.564\tDev f1: 0.580\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2a749686c77f4c9db903b44284f654b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[E:1]\tTrain Loss: 0.328\tDev f1: 0.549\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c2dc9a89ade3467ea7b2faa44029bfc4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Saving model at iteration 2 with best score 0.594\n",
      "[E:2]\tTrain Loss: 0.191\tDev f1: 0.594\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c479132d61da4853bd91d6f5c2c382a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[E:3]\tTrain Loss: 0.118\tDev f1: 0.571\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9b53c31330d54f5c88a99f84d2a0e4c5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[E:4]\tTrain Loss: 0.092\tDev f1: 0.589\n",
      "Restoring best model from iteration 2 with score 0.594\n",
      "Finished Training\n",
      "F1: 0.594\n",
      "        y=1    y=2   \n",
      " l=1    183    137   \n",
      " l=2    113    455   \n"
     ]
    }
   ],
   "source": [
    "end_model.train_model(train, valid_data=dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precision: 0.517\n",
      "Recall: 0.660\n",
      "F1: 0.580\n",
      "        y=1    y=2   \n",
      " l=1    993    927   \n",
      " l=2    512   2188   \n"
     ]
    }
   ],
   "source": [
    "score = end_model.score(test, metric=['precision', 'recall', 'f1'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train End Model (Pretrained Embeddings)\n",
    "\n",
    "Download [GloVe embeddings](http://nlp.stanford.edu/data/glove.6B.zip):\n",
    "`wget http://nlp.stanford.edu/data/glove.6B.zip \\\n",
    "&& mkdir -p glove.6B \\\n",
    "&& unzip glove.6B.zip -d glove.6B \\\n",
    "&& rm glove.6B.zip`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "import numpy as np\n",
    "\n",
    "import torch.nn.init as init\n",
    "\n",
    "class EmbeddingLoader(object):\n",
    "    \"\"\"\n",
    "    Simple text file embedding loader. Words with GloVe and FastText.\n",
    "    \"\"\"\n",
    "    def __init__(self, fpath, fmt='text', dim=None, normalize=True):\n",
    "        assert os.path.exists(fpath)\n",
    "        self.fpath = fpath\n",
    "        self.dim = dim\n",
    "        self.fmt = fmt\n",
    "        # infer dimension\n",
    "        if not self.dim:\n",
    "            header = open(self.fpath, \"rU\").readline().strip().split(' ')\n",
    "            self.dim = len(header) - 1 if len(header) != 2 else int(header[-1])\n",
    "\n",
    "        self.vocab, self.vectors = zip(*[(w,vec) for w,vec in self._read()])\n",
    "        self.vocab = {w:i for i,w in enumerate(self.vocab)}\n",
    "        self.vectors = np.vstack(self.vectors)\n",
    "        if normalize:\n",
    "            self.vectors = (self.vectors.T / np.linalg.norm(self.vectors, axis=1, ord=2)).T\n",
    "\n",
    "    def _read(self):\n",
    "        start = 0 if self.fmt == \"text\" else 1\n",
    "        for i, line in enumerate(open(self.fpath, \"rU\")):\n",
    "            if i < start:\n",
    "                continue\n",
    "            line = line.rstrip().split(' ')\n",
    "            vec = np.array([float(x) for x in line[1:]])\n",
    "            if len(vec) != self.dim:\n",
    "                errors += [line[0]]\n",
    "                continue\n",
    "            yield (line[0], vec)\n",
    "            \n",
    "\n",
    "def load_embeddings(vocab, embeddings):\n",
    "    \"\"\"\n",
    "    Load pretrained embeddings\n",
    "    \"\"\"\n",
    "    def get_word_match(w, word_dict):\n",
    "        if w in word_dict:\n",
    "            return word_dict[w]\n",
    "        elif w.lower() in word_dict:\n",
    "            return word_dict[w.lower()]\n",
    "        elif w.strip(string.punctuation) in word_dict:\n",
    "            return word_dict[w.strip(string.punctuation)]\n",
    "        elif w.strip(string.punctuation).lower() in word_dict:\n",
    "            return word_dict[w.strip(string.punctuation).lower()]\n",
    "        else:\n",
    "            return -1\n",
    "\n",
    "    num_words = vocab.len()\n",
    "    emb_dim   = embeddings.vectors.shape[1]\n",
    "    vecs      = init.xavier_normal_(torch.empty(num_words, emb_dim))\n",
    "    vecs[0]   = torch.zeros(emb_dim)\n",
    "\n",
    "    n = 0\n",
    "    for w in vocab.d:\n",
    "        idx = get_word_match(w, embeddings.vocab)\n",
    "        if idx == -1:\n",
    "            continue\n",
    "        i = vocab.lookup(w)\n",
    "        vecs[i] = torch.FloatTensor(embeddings.vectors[idx])\n",
    "        n += 1\n",
    "\n",
    "    print(\"Loaded {:2.1f}% ({}/{}) pretrained embeddings\".format(float(n) / vocab.len() * 100.0, n, vocab.len() ))\n",
    "    return vecs         "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dfs/scratch0/vschen/snorkel-pytorch/venv/lib/python3.6/site-packages/ipykernel_launcher.py:17: DeprecationWarning: 'U' mode is deprecated\n",
      "/dfs/scratch0/vschen/snorkel-pytorch/venv/lib/python3.6/site-packages/ipykernel_launcher.py:28: DeprecationWarning: 'U' mode is deprecated\n"
     ]
    }
   ],
   "source": [
    "emb_path  = \"glove.6B/glove.6B.50d.txt\"\n",
    "embs  = EmbeddingLoader(emb_path, fmt='text')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Connected to sqlite:///cdr.db\n",
      "Connected to sqlite:///cdr.db\n",
      "Connected to sqlite:///cdr.db\n",
      "[TRAIN] 8272\n",
      "[DEV]   888\n",
      "[TEST]  4620\n"
     ]
    }
   ],
   "source": [
    "from metal.contrib.backends.wrapper import SnorkelDataset\n",
    "\n",
    "db_conn_str   = \"cdr.db\"\n",
    "candidate_def = ['ChemicalDisease', ['chemical', 'disease']]\n",
    "\n",
    "train, dev, test = SnorkelDataset.splits(db_conn_str, \n",
    "                                         candidate_def, \n",
    "                                         pretrained_word_dict=embs.vocab, \n",
    "                                         max_seq_len=125)\n",
    "\n",
    "print(f'[TRAIN] {len(train)}')\n",
    "print(f'[DEV]   {len(dev)}')\n",
    "print(f'[TEST]  {len(test)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initalize pretrained embedding matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 79.9% (9116/11406) pretrained embeddings\n"
     ]
    }
   ],
   "source": [
    "wembs = load_embeddings(train.word_dict, embs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using pretrained embeddings.\n",
      "Embeddings shape = (11406, 50)\n",
      "The embeddings are NOT FROZEN\n",
      "Using lstm_reduction = 'attention'\n",
      "\n",
      "Network architecture:\n",
      "Sequential(\n",
      "  (0): Sequential(\n",
      "    (0): LSTMModule(\n",
      "      (embeddings): Embedding(11406, 50)\n",
      "      (lstm): LSTM(50, 100, batch_first=True, bidirectional=True)\n",
      "    )\n",
      "    (1): ReLU()\n",
      "  )\n",
      "  (1): Linear(in_features=200, out_features=2, bias=True)\n",
      ")\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from metal.end_model import EndModel\n",
    "from metal.contrib.modules import LSTMModule\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "lstm = LSTMModule(embed_size=50, \n",
    "                  hidden_size=100, \n",
    "                  embeddings=wembs,\n",
    "                  lstm_reduction='attention', \n",
    "                  dropout=0, \n",
    "                  num_layers=1, \n",
    "                  freeze=False)\n",
    "\n",
    "end_model = EndModel([200, 2], input_module=lstm, seed=123, device=device)\n",
    "\n",
    "end_model.config['train_config']['optimizer_config']['optimizer_common']['lr'] = 0.01\n",
    "end_model.config['train_config']['validation_metric'] = 'f1'\n",
    "end_model.config['train_config']['batch_size'] = 32\n",
    "end_model.config['train_config']['n_epochs'] = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using GPU...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e3949b032c04ba89ef1aaa95288f08c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Saving model at iteration 0 with best score 0.599\n",
      "[E:0]\tTrain Loss: 0.533\tDev f1: 0.599\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95ebf8e4082c4369aeba9d92c3d1967e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[E:1]\tTrain Loss: 0.293\tDev f1: 0.568\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2de2e35c0e624437bcfa294afba33f4e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[E:2]\tTrain Loss: 0.174\tDev f1: 0.560\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e4e525b5df764685bcb0a19fc75c5dae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[E:3]\tTrain Loss: 0.119\tDev f1: 0.585\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95467e1bb48040cfa52d157722508c3e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=259), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[E:4]\tTrain Loss: 0.078\tDev f1: 0.575\n",
      "Restoring best model from iteration 0 with score 0.599\n",
      "Finished Training\n",
      "F1: 0.599\n",
      "        y=1    y=2   \n",
      " l=1    191    151   \n",
      " l=2    105    441   \n"
     ]
    }
   ],
   "source": [
    "end_model.train_model(train, valid_data=dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Precision: 0.526\n",
      "Recall: 0.658\n",
      "F1: 0.585\n",
      "        y=1    y=2   \n",
      " l=1    990    892   \n",
      " l=2    515   2223   \n"
     ]
    }
   ],
   "source": [
    "score = end_model.score(test, metric=['precision', 'recall', 'f1'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
