{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import math\n",
    "import numpy as np\n",
    "import json\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from collections import defaultdict\n",
    "from make_fn_data import load_fn_data\n",
    "from transformers import BertTokenizer, BertModel, BertForMaskedLM\n",
    "from neural_net import Model, NpClassDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "statistics\n",
      "# lex units:  13572\n",
      "# frames:  1073\n",
      "# data points:  200751\n",
      "# lex units without data:  3271\n"
     ]
    }
   ],
   "source": [
    "# Load and prepare data\n",
    "# load_fn_data comes from the project-local make_fn_data module. The cell that\n",
    "# creates datapoints slices `data` and reads each entry's \"frame\" and\n",
    "# \"sentences\" keys, so `data` is used as a sequence of per-lexical-unit dicts.\n",
    "data = load_fn_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0616 04:16:14.863266 29344 tokenization_utils.py:1075] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at C:\\Users\\danil/.cache\\torch\\transformers\\26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n",
      "I0616 04:16:15.103193 29344 configuration_utils.py:265] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at C:\\Users\\danil/.cache\\torch\\transformers\\4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517\n",
      "I0616 04:16:15.104233 29344 configuration_utils.py:301] Model config BertConfig {\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "I0616 04:16:15.335124 29344 modeling_utils.py:650] loading weights file https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin from cache at C:\\Users\\danil/.cache\\torch\\transformers\\f2ee78bdd635b758cc0a12352586868bef80e47401abe4c4fcc3832421e7338b.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n"
     ]
    }
   ],
   "source": [
    "# Use the GPU when one is available and fall back to the CPU otherwise, so the\n",
    "# notebook does not crash on machines without CUDA.\n",
    "bert_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Single source of truth so the tokenizer and the model always match.\n",
    "BERT_NAME = 'bert-base-uncased'\n",
    "tokenizer = BertTokenizer.from_pretrained(BERT_NAME)\n",
    "bert_model = BertModel.from_pretrained(BERT_NAME)\n",
    "# BERT is only used as a frozen feature extractor: eval() turns dropout off.\n",
    "bert_model.eval()\n",
    "bert_model.to(bert_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# datapoints =  807\n",
      "max labels =  16\n",
      "17\n"
     ]
    }
   ],
   "source": [
    "# Create (text, start, end) datapoints and integer frame labels from data\n",
    "\n",
    "# Only the first N_LEX_UNITS lexical units are used, which keeps the dataset\n",
    "# small for this experiment.\n",
    "N_LEX_UNITS = 50\n",
    "\n",
    "frame_dict = {}      # frame name -> integer label id\n",
    "frame_dict_rev = {}  # integer label id -> frame name\n",
    "\n",
    "inputs = []  # (sentence text, smallest span start, largest span end)\n",
    "labels = []  # frame id of the datapoint at the same position in `inputs`\n",
    "\n",
    "for lu in data[:N_LEX_UNITS]:\n",
    "    frame = lu[\"frame\"]\n",
    "    # Ids are handed out in order of first appearance: 0, 1, 2, ...\n",
    "    if frame not in frame_dict:\n",
    "        frame_dict[frame] = len(frame_dict)\n",
    "        frame_dict_rev[frame_dict[frame]] = frame\n",
    "    frame_id = frame_dict[frame]\n",
    "\n",
    "    for sentence in lu[\"sentences\"]:\n",
    "        text = sentence[\"text\"]\n",
    "        indexes = sentence[\"indexes\"]\n",
    "        # `indexes` holds (start, end) pairs; keep the overall extent they cover.\n",
    "        start = min(int(i[0]) for i in indexes)\n",
    "        end = max(int(i[1]) for i in indexes)\n",
    "        inputs.append((text, start, end))\n",
    "        labels.append(frame_id)\n",
    "\n",
    "print(\"# datapoints = \", len(labels))\n",
    "print(\"max labels = \", max(labels))\n",
    "print(len(frame_dict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "elapsed time (s) to generate one datapoint =  1.2686066627502441\n",
      "dataset in =  tensor([ 0.9563, -0.0470, -0.1990,  ...,  0.0000,  0.0000,  0.0000])\n",
      "dataset out =  tensor(0) torch.LongTensor\n",
      "dimensions: in = 2304  out =  17\n"
     ]
    }
   ],
   "source": [
    "# You should build your custom dataset as below.\n",
    "class FnBertDataset(torch.utils.data.Dataset):\n",
    "    \n",
    "    def __init__(self, inputs, labels, frame_dict, tokenizer, bert_model):\n",
    "        \"\"\"\n",
    "        First two arguments should be lists with the format:\n",
    "        inputs: [(text1, start1, end1), ...]\n",
    "        labels: [label_id1, ...]\n",
    "        \"\"\"\n",
    "        self.inputs = inputs\n",
    "        self.labels = labels\n",
    "        \n",
    "        self.tokenizer = tokenizer\n",
    "        self.bert_model = bert_model\n",
    "        \n",
    "        self.MAX_LEN = 3\n",
    "        self.INPUT_DIM = self.MAX_LEN * self.bert_model.config.hidden_size\n",
    "        self.OUTPUT_DIM = len(frame_dict.keys())\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        text, start, end = self.inputs[index]\n",
    "        x = self.get_bert_hidden_state(text, start, end)\n",
    "        y = torch.tensor(self.labels[index]).long()        \n",
    "        return x, y\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "    \n",
    "    def get_bert_hidden_state(self, text, start, end):\n",
    "        text = \"[CLS] \" + text + \" [SEP]\"\n",
    "        start += len(\"[CLS] \")\n",
    "        end += len(\"[CLS] \")\n",
    "        \n",
    "        # Compute start end end using token indexes\n",
    "        tk_start, tk_end = self.pos_to_token_idx(text, start, end)\n",
    "        tk_end = min(tk_start + self.MAX_LEN, tk_end)\n",
    "        # Tokenize input\n",
    "        tokenized_text = self.tokenizer.tokenize(text)\n",
    "    \n",
    "        # Convert token to vocabulary indices\n",
    "        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)\n",
    "        # Convert inputs to PyTorch tensors\n",
    "        tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')\n",
    "        # Predict hidden states features for each layer\n",
    "        with torch.no_grad():\n",
    "            outputs = self.bert_model(tokens_tensor)\n",
    "            # Hidden state of the last layer of the Bert model\n",
    "            hidden = torch.squeeze(outputs[0], dim = 0)\n",
    "            # Slice hidden state to hidden[start:end]\n",
    "            hidden = hidden.narrow(0, tk_start, tk_end-tk_start)\n",
    "            # Add padding\n",
    "            pad = torch.zeros(self.MAX_LEN, hidden.size()[1])            \n",
    "            pad[0:hidden.size()[0],:] = hidden\n",
    "            hidden = torch.flatten(pad)\n",
    "            return hidden\n",
    "\n",
    "    def pos_to_token_idx(self, text, start, end):\n",
    "        target_prefix = self.tokenizer.tokenize(text[:start])\n",
    "        target = self.tokenizer.tokenize(text[start:end+1])\n",
    "        tk_start = len(target_prefix)\n",
    "        tk_end = tk_start + len(target)\n",
    "        return tk_start, tk_end\n",
    "    \n",
    "# Two-datapoint variant for quick debugging:\n",
    "# dataset = FnBertDataset([inputs[0], inputs[-1]], [labels[0], labels[-1]], frame_dict, tokenizer, bert_model)\n",
    "dataset = FnBertDataset(inputs, labels, frame_dict, tokenizer, bert_model)\n",
    "\n",
    "# Time a handful of datapoints and report the average, so the printed number\n",
    "# really is the cost of one datapoint. Never index past the end of the dataset.\n",
    "n_timed = min(10, len(dataset))\n",
    "start_time = time.time()\n",
    "for i in range(n_timed):\n",
    "    dataset[i]\n",
    "elapsed = time.time() - start_time\n",
    "print(\"elapsed time (s) to generate one datapoint = \", elapsed / max(n_timed, 1))\n",
    "\n",
    "# Fetch the first datapoint once: every dataset[...] access runs BERT again.\n",
    "x0, y0 = dataset[0]\n",
    "print(\"dataset in = \", x0)\n",
    "print(\"dataset out = \", y0, y0.type())\n",
    "print(\"dimensions: in =\", dataset.INPUT_DIM, \" out = \", dataset.OUTPUT_DIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1,     2] loss: 1.539\n",
      "[ 1,     4] loss: 1.182\n",
      "Epoch  1 finished. Loss: 1.361 (elapsed 4.389s)\n",
      "[ 2,     2] loss: 0.915\n",
      "[ 2,     4] loss: 0.829\n",
      "Epoch  2 finished. Loss: 0.872 (elapsed 4.700s)\n",
      "[ 3,     2] loss: 0.582\n",
      "[ 3,     4] loss: 0.334\n",
      "Epoch  3 finished. Loss: 0.458 (elapsed 4.703s)\n",
      "[ 4,     2] loss: 0.354\n",
      "[ 4,     4] loss: 0.157\n",
      "Epoch  4 finished. Loss: 0.256 (elapsed 4.486s)\n",
      "[ 5,     2] loss: 0.184\n",
      "[ 5,     4] loss: 0.072\n",
      "Epoch  5 finished. Loss: 0.128 (elapsed 4.573s)\n",
      "[ 6,     2] loss: 0.068\n",
      "[ 6,     4] loss: 0.140\n",
      "Epoch  6 finished. Loss: 0.104 (elapsed 5.466s)\n",
      "[ 7,     2] loss: 0.079\n",
      "[ 7,     4] loss: 0.027\n",
      "Epoch  7 finished. Loss: 0.053 (elapsed 6.068s)\n",
      "[ 8,     2] loss: 0.092\n",
      "[ 8,     4] loss: 0.016\n",
      "Epoch  8 finished. Loss: 0.054 (elapsed 5.767s)\n",
      "[ 9,     2] loss: 0.025\n",
      "[ 9,     4] loss: 0.042\n",
      "Epoch  9 finished. Loss: 0.033 (elapsed 6.049s)\n",
      "[10,     2] loss: 0.047\n",
      "[10,     4] loss: 0.016\n",
      "Epoch 10 finished. Loss: 0.032 (elapsed 5.841s)\n",
      "Training finished (elapsed 52.041s)\n"
     ]
    }
   ],
   "source": [
    "def create_net(input_dim, output_dim):\n",
    "    layers = [\n",
    "        nn.Dropout(),\n",
    "        nn.Linear(input_dim, 100),\n",
    "        nn.ReLU(),\n",
    "        nn.Dropout(),\n",
    "        nn.Linear(100, output_dim),    \n",
    "    ]\n",
    "    model = nn.Sequential(*layers)\n",
    "    return model\n",
    "\n",
    "# Build the classifier and train it on the BERT-feature dataset\n",
    "net = create_net(input_dim=dataset.INPUT_DIM, output_dim=dataset.OUTPUT_DIM)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "# 1e-3 is the same value the original spelled as 10e-4\n",
    "adam = optim.Adam(net.parameters(), lr=1e-3)\n",
    "model = Model(net, criterion=loss_fn, optimizer=adam)\n",
    "model.fit(dataset, n_epochs=10, batch_size=32, verbose=True, print_every=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Directory that receives the saved model files. Set the\n",
    "# FRAME_CLASSIFICATION_DIR environment variable to redirect it on another\n",
    "# machine; the default keeps the original location.\n",
    "MODEL_DIR = os.environ.get(\n",
    "    'FRAME_CLASSIFICATION_DIR',\n",
    "    'C:\\\\Users\\\\danil\\\\Documents\\\\Northwestern\\\\Research\\\\projects\\\\frame_classification')\n",
    "\n",
    "# Weights only: reload with create_net(...) followed by load_state_dict(...).\n",
    "torch.save(net.state_dict(), os.path.join(MODEL_DIR, 'state_dict_small'))\n",
    "# Whole pickled module: loading it requires the same class definitions to be\n",
    "# importable, and torch.load should only be used on trusted files.\n",
    "torch.save(net, os.path.join(MODEL_DIR, 'net_small'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch the network to evaluation mode (turns its dropout layers off)\n",
    "model.net.eval()\n",
    "\n",
    "# NOTE: this is built from the same `inputs`/`labels` lists as the training\n",
    "# dataset, so the score below is measured on the training data.\n",
    "dev_dataset = FnBertDataset(\n",
    "    inputs, labels, frame_dict, tokenizer, bert_model)\n",
    "# Two-datapoint variant for quick debugging:\n",
    "# dev_dataset = FnBertDataset([inputs[0], inputs[-1]], [labels[0], labels[-1]],\n",
    "#                             frame_dict, tokenizer, bert_model)\n",
    "model.test(dev_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Show the frame name -> label id mapping built while creating the datapoints\n",
    "print(frame_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
