{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from minicons import cwe\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import csv\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sat Apr 10 17:52:49 2021       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 460.39       Driver Version: 460.39       CUDA Version: 11.2     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla V100-PCIE...  Off  | 00000000:3B:00.0 Off |                    0 |\n",
      "| N/A   27C    P0    24W / 250W |      8MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  Tesla V100-PCIE...  Off  | 00000000:D8:00.0 Off |                    0 |\n",
      "| N/A   27C    P0    24W / 250W |      8MiB / 32510MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|    0   N/A  N/A      1994      G   /usr/lib/xorg/Xorg                  4MiB |\n",
      "|    1   N/A  N/A      1994      G   /usr/lib/xorg/Xorg                  4MiB |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint name and target device, named so they are easy to find and change.\n",
    "MODEL_NAME = \"bert-base-uncased\"\n",
    "DEVICE = \"cuda:0\"\n",
    "\n",
    "# minicons extractor object; the embedding loop calls bert.extract_representation.\n",
    "bert = cwe.CWE(MODEL_NAME, DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build sense_data: one (dataset_id, sent_id, sentence, word, position, sense)\n",
    "# tuple per sense-labelled token, pooled from the pickled records and the CSV.\n",
    "sense_data = []\n",
    "# Self-numbering id tables: looking up an unseen key assigns it the next integer.\n",
    "words = defaultdict(lambda: len(words))\n",
    "sents = defaultdict(lambda: len(sents))\n",
    "\n",
    "# SECURITY: pickle.load can execute arbitrary code; only load a trusted file.\n",
    "with open(\"../sent_data/all_sent_data.pkl\", \"rb\") as f:\n",
    "    semeval = pickle.load(f)\n",
    "\n",
    "# The sent_id stored in each record is discarded; ids are re-assigned via sents.\n",
    "for dataset_id, _stored_sent_id, sentence, word, position, sense in semeval:\n",
    "    # Look up ids before filtering so records without a sense still consume an id.\n",
    "    sent_id = sents[sentence]\n",
    "    word_id = words[word]  # kept for its side effect of registering the word\n",
    "    if sense is not None:\n",
    "        sense_data.append((dataset_id, sent_id, sentence, word, position, sense))\n",
    "\n",
    "# newline='' lets the csv module handle line endings inside quoted fields itself.\n",
    "with open(\"../data/multi_sense.csv\", \"r\", newline=\"\") as f:\n",
    "    for line in csv.DictReader(f):\n",
    "        # Rows tagged NNP or RB are excluded.\n",
    "        if line['pos'] in ('NNP', 'RB'):\n",
    "            continue\n",
    "        context = line['context']\n",
    "        if not context:\n",
    "            # An empty context cannot contain the target word; skipping it also\n",
    "            # avoids the IndexError that context[-1] would raise.\n",
    "            continue\n",
    "        # Append a separate final period when the context does not end with one.\n",
    "        sentence = context if context.endswith(\".\") else context + \" .\"\n",
    "        sent_id = sents[sentence]\n",
    "        word_id = words[line['word']]  # registers the word; the id itself is unused\n",
    "        position = int(line['index'])\n",
    "        sense_data.append((\"semcor\", sent_id, sentence, line['word'], position, line['lex_sense']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "149731"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: total number of sense-labelled records collected in sense_data.\n",
    "len(sense_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the pooled records, with a header row, to a CSV file.\n",
    "# newline='' is what the csv module expects for files it writes; without it,\n",
    "# platforms that translate line endings emit an extra blank line after every row.\n",
    "with open(\"../data/sense_metadata.csv\", \"w\", newline=\"\") as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow([\"dataset_id\", \"sent_id\", \"sentence\", \"word\", \"position\", \"sense\"])\n",
    "    writer.writerows(sense_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of records per batch handed to the embedding extractor.\n",
    "BATCH_SIZE = 100\n",
    "\n",
    "# Wrap the record list for batched iteration; no shuffle argument is passed.\n",
    "sense_dl = DataLoader(sense_data, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1498/1498 [03:06<00:00,  8.02it/s]\n"
     ]
    }
   ],
   "source": [
    "# Accumulates one (position, word, sense, vector) tuple per record in sense_data.\n",
    "embedding_data = []\n",
    "for batch in tqdm(sense_dl):\n",
    "    # Each batch holds one collection per record field; the first two are unused here.\n",
    "    _dataset_ids, _sent_ids, batch_sentences, batch_words, batch_positions, batch_senses = batch\n",
    "    positions = batch_positions.tolist()\n",
    "    # Query format: (sentence, (start, end)) with a span of width one at each position.\n",
    "    # NOTE(review): confirm that position is in the unit minicons expects for spans.\n",
    "    queries = [(text, (start, start + 1)) for text, start in zip(batch_sentences, positions)]\n",
    "    # Detach from autograd and move to CPU before storing the vectors.\n",
    "    vectors = bert.extract_representation(queries).detach().cpu()\n",
    "    embedding_data.extend(zip(positions, batch_words, batch_senses, vectors))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "149731"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: compare with len(sense_data); the loop adds one entry per record.\n",
    "len(embedding_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}