{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import os\n",
    "import math\n",
    "import numpy as np\n",
    "import json\n",
    "import time\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from make_fn_data import load_fn_data\n",
    "from neural_net import Model, NpClassDataset\n",
    "from transformers import BertTokenizer, BertModel, BertForMaskedLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "bert_model = BertModel.from_pretrained('bert-base-uncased')\n",
    "bert_model.eval()\n",
    "bert_model.to('cuda')\n",
    "print(bert_model.config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load and prepare data\n",
    "data = load_fn_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create datapoints from data\n",
    "\n",
    "frame_dict = {}\n",
    "frame_dict_rev = {}\n",
    "\n",
    "inputs = []\n",
    "labels = []\n",
    "\n",
    "for lu in data:\n",
    "    frame =  lu[\"frame\"]\n",
    "    if not frame in frame_dict.keys():\n",
    "        frame_dict[frame] = len(frame_dict.keys())\n",
    "        frame_dict_rev[frame_dict[frame]] = frame\n",
    "    frame_id = frame_dict[frame]\n",
    "    \n",
    "    for sentence in lu[\"sentences\"]:\n",
    "        text = sentence[\"text\"]\n",
    "        indexes = sentence[\"indexes\"]\n",
    "        if len(indexes) > 0:\n",
    "            start = min([int(i[0]) for i in indexes])\n",
    "            end = max([int(i[1]) for i in indexes])\n",
    "            inputs.append((text, start, end))\n",
    "            labels.append(frame_id)\n",
    "        \n",
    "print(\"# datapoints = \", len(labels))\n",
    "print(\"max labels = \", max(labels))\n",
    "print(len(frame_dict.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You should build your custom dataset as below.\n",
    "class FnBertDataset(torch.utils.data.Dataset):\n",
    "    \n",
    "    def __init__(self, inputs, labels, frame_dict, tokenizer, bert_model):\n",
    "        \"\"\"\n",
    "        First two arguments should be lists with the format:\n",
    "        inputs: [(text1, start1, end1), ...]\n",
    "        labels: [label_id1, ...]\n",
    "        \"\"\"\n",
    "        self.inputs = inputs\n",
    "        self.labels = labels\n",
    "        \n",
    "        self.tokenizer = tokenizer\n",
    "        self.bert_model = bert_model\n",
    "        \n",
    "        self.MAX_LEN = 4\n",
    "        self.INPUT_DIM = self.MAX_LEN * self.bert_model.config.hidden_size\n",
    "        self.OUTPUT_DIM = len(frame_dict.keys())\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        text, start, end = self.inputs[index]\n",
    "        x = self.get_bert_hidden_state(text, start, end)\n",
    "        y = torch.tensor(self.labels[index]).long()        \n",
    "        return x, y\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "    \n",
    "    def get_bert_hidden_state(self, text, start, end):\n",
    "        text = \"[CLS] \" + text + \" [SEP]\"\n",
    "        start += len(\"[CLS] \")\n",
    "        end += len(\"[CLS] \")\n",
    "        \n",
    "        # Compute start end end using token indexes\n",
    "        tk_start, tk_end = self.pos_to_token_idx(text, start, end)\n",
    "        tk_end = min(tk_start + self.MAX_LEN, tk_end)\n",
    "        # Tokenize input\n",
    "        tokenized_text = self.tokenizer.tokenize(text)\n",
    "    \n",
    "        # Convert token to vocabulary indices\n",
    "        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)\n",
    "        # Convert inputs to PyTorch tensors\n",
    "        tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')\n",
    "        # Predict hidden states features for each layer\n",
    "        with torch.no_grad():\n",
    "            outputs = self.bert_model(tokens_tensor)\n",
    "            # Hidden state of the last layer of the Bert model\n",
    "            hidden = torch.squeeze(outputs[0], dim = 0)\n",
    "            # Slice hidden state to hidden[start:end]\n",
    "            hidden = hidden.narrow(0, tk_start, tk_end-tk_start)\n",
    "            # Add padding\n",
    "            pad = torch.zeros(self.MAX_LEN, hidden.size()[1])            \n",
    "            pad[0:hidden.size()[0],:] = hidden\n",
    "            hidden = torch.flatten(pad)\n",
    "            return hidden\n",
    "\n",
    "    def pos_to_token_idx(self, text, start, end):\n",
    "        target_prefix = self.tokenizer.tokenize(text[:start])\n",
    "        target = self.tokenizer.tokenize(text[start:end+1])\n",
    "        tk_start = len(target_prefix)\n",
    "        tk_end = tk_start + len(target)\n",
    "        return tk_start, tk_end\n",
    "    \n",
    "dataset = FnBertDataset(inputs, labels, frame_dict, tokenizer, bert_model)\n",
    "print(\"dataset in = \", dataset[100][0])\n",
    "print(\"dataset out = \", dataset[100][1], dataset[100][1].type())\n",
    "print(\"dimensions: in =\", dataset.INPUT_DIM, \" out = \", dataset.OUTPUT_DIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def create_net(input_dim, output_dim):\n",
    "    layers = [\n",
    "        nn.Dropout(),\n",
    "        nn.Linear(input_dim, 400),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(),\n",
    "        nn.Linear(400, output_dim),\n",
    "    ]\n",
    "    model = nn.Sequential(*layers)\n",
    "    return model\n",
    "\n",
    "# Run training & testing\n",
    "net = create_net(input_dim = dataset.INPUT_DIM, output_dim = dataset.OUTPUT_DIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load saved network\n",
    "net.load_state_dict(torch.load('C:\\\\Users\\\\danil\\\\Documents\\\\Northwestern\\\\Research\\\\projects\\\\frame_classification\\\\state_dict_3.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.train()\n",
    "model = Model(net, criterion = nn.CrossEntropyLoss(),\n",
    "              optimizer=optim.Adam(net.parameters(), lr=10e-5))\n",
    "model.fit(dataset, n_epochs=10, batch_size=32, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "torch.save(\n",
    "    net.state_dict(), 'C:\\\\Users\\\\danil\\\\Documents\\\\Northwestern\\\\Research\\\\projects\\\\frame_classification\\\\state_dict_4.pth')\n",
    "torch.save(\n",
    "    net, 'C:\\\\Users\\\\danil\\\\Documents\\\\Northwestern\\\\Research\\\\projects\\\\frame_classification\\\\net_4.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_idxs = random.choices(range(len(inputs)), k=1000)\n",
    "dev_inputs = [inputs[idx] for idx in dev_idxs]\n",
    "dev_labels = [labels[idx] for idx in dev_idxs]\n",
    "\n",
    "net.eval()\n",
    "dev_dataset = FnBertDataset(dev_inputs, dev_labels, frame_dict, tokenizer, bert_model)\n",
    "print(\"length of dev set: \", len(dev_dataset))\n",
    "model.test(dev_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def predict_top_k_dataset(dataset, k, batch_size=1):\n",
    "        predicted_lst = []\n",
    "        probs_lst = []\n",
    "        data_loader = torch.utils.data.DataLoader(\n",
    "            dataset=dataset, batch_size=batch_size, shuffle=False)    \n",
    "        with torch.no_grad():\n",
    "            for (inputs, _) in data_loader:\n",
    "                inputs = inputs.to(\"cuda\")\n",
    "                predicted, probs = predict_top_k(inputs, k)\n",
    "                predicted_lst.append(predicted)\n",
    "                probs_lst.append(probs)\n",
    "        predicted_tensor = torch.cat(predicted_lst, 0)\n",
    "        probs_tensor = torch.cat(probs_lst, 0)\n",
    "        return predicted_tensor, probs_tensor\n",
    "    \n",
    "def predict_top_k(inputs, k, batch_size=1):\n",
    "    inputs = inputs.to(\"cuda\")\n",
    "    with torch.no_grad():\n",
    "        outputs = net(inputs)\n",
    "        logits, predicted = torch.topk(outputs.data, k, dim = 1)\n",
    "        softmax = nn.Softmax(dim=1)\n",
    "        probs = softmax(logits)\n",
    "        return predicted, probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_in = [\n",
    "    (\"the problem is telling which is the original document and which the copy\", 68, 71),\n",
    "    (\"the cause of the accident is not clear\", 4, 8),\n",
    "    (\"Rubella, also known as German measles or three-day measles, is an infection caused by the rubella virus.\", 0, 6),\n",
    "    (\"he died after a long illness\", 21, 27),\n",
    "    (\"for a time revolution was a strong probability\", 35, 45),\n",
    "]\n",
    "dev_lab = [\n",
    "    frame_dict[\"Duplication\"], frame_dict[\"Causation\"], \n",
    "    frame_dict[\"Medical_conditions\"], frame_dict[\"Medical_conditions\"],\n",
    "    frame_dict[\"Probability\"]\n",
    "]\n",
    "dev_dataset = FnBertDataset(dev_in, dev_lab, frame_dict, tokenizer, bert_model)\n",
    "preds, probs = predict_top_k_dataset(dev_dataset, 5)\n",
    "preds = preds.tolist()\n",
    "probs = probs.tolist()\n",
    "for pred, prob in zip(preds, probs):\n",
    "    print([(frame_dict_rev[x], round(y, 2)) for x, y in zip(pred, prob)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
