{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchmetrics as metrics\n",
    "\n",
    "from minicons import cwe\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_wic(file = \"train\"):\n",
    "    row = [x.strip().split(\"\\t\") for x in open(f\"../data/WiC_dataset/{file}/{file}.data.txt\", \"r\").readlines()]\n",
    "    if not file == \"test\":\n",
    "        gold = [x.strip() for x in open(f\"../data/WiC_dataset/{file}/{file}.gold.txt\", \"r\").readlines()]\n",
    "    dataset = []\n",
    "    for i, data in enumerate(row):\n",
    "        word, pos, idx, sentence1, sentence2 = data\n",
    "        idx1, idx2 = idx.split('-')\n",
    "        idx1, idx2 = int(idx1), int(idx2)\n",
    "        \n",
    "        context1 = [sentence1, idx1]\n",
    "        context2 = [sentence2, idx2]\n",
    "        \n",
    "        if not file == \"test\":\n",
    "            label = gold[i]\n",
    "            dataset.append((context1, context2, pos, label))\n",
    "        else:\n",
    "            dataset.append((context1, context2, pos))\n",
    "            \n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = metrics.Accuracy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wic_model import WiCModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = WiCModel(hparams = {'input_size': 1536, 'hidden_size': 512, 'dropout': 0.5, 'hidden_layers': 2, 'lr': 1e-3, 'approximator': 0, 'layer': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `bert` was never assigned anywhere in this notebook, so this cell raised NameError on a fresh kernel.\n",
    "# NOTE(review): the checkpoint name is an assumption (the classifier input is 1536 = 2 x 768) -- confirm before relying on it.\n",
    "bert = cwe.CWE('bert-base-uncased')\n",
    "\n",
    "c1 = bert.extract_representation([['You must carry your camping gear .', (2, 3)], ['Messages must go through diplomatic channels .', (2, 3)]])\n",
    "c2 = bert.extract_representation([['Sound carries well over water .', (1, 2)], ['Do you think the sofa will go through the door ?', (6, 7)]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.sigmoid(model(torch.cat((c1, c2), dim = 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc(torch.tensor([1, 0, 1]), torch.tensor([1, 0, 1])).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from approximation_model import NonLinearApproximator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "approximator = NonLinearApproximator.load_from_checkpoint(f'{auth1_path}/makesense_logs/bert/12/version_2048_2_0-0001.ckpt')\n",
    "approximator.eval()\n",
    "\n",
    "for param in approximator.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding = approximator(bert.extract_representation([\"I went to close the door.\", (3, 4)]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding.requires_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_hat = torch.rand(3, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_hat.argmax(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_hat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.randint(2, (3,), dtype=torch.int64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.nn.functional.cross_entropy(y_hat, torch.randint(2, (3,), dtype=torch.int64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = load_wic()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dl = DataLoader(train, batch_size = 50, num_workers = 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = []\n",
    "labels = []\n",
    "for batch in tqdm(train_dl):\n",
    "    x, y = model._build_batch(batch, approximator=0)\n",
    "    embeddings.extend(x)\n",
    "    labels.extend(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_embeddings = torch.stack(embeddings)\n",
    "labels = torch.stack(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_embeddings.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(layer_embeddings, f'{auth1_path}/makesense_data/original/layer_12_embeddings.pt')\n",
    "torch.save(labels, f'{auth1_path}/makesense_data/labels.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paired = list(zip(layer_embeddings, labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "paired_dl = DataLoader(paired, batch_size = 64, num_workers = 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = model._build_batch(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.training_step(batch, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "label2id = {\n",
    "    'T': 1,\n",
    "    'F': 0\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.tensor([label2id[tag] for tag in batch[3]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "What we need:\n",
    "\n",
    "The input to the models will be a concatenated tensor with size bs x 1536 (where bs is the batch size).\n",
    "To get that tensor, we need to pass the sentence and the indices to the model.extract_representation method.\n",
    "\n",
    "For cases where we are using approximations, we will need to run the approximator on the loaded vectors.\n",
    "\n",
    "\n",
    "Another Todo: Writing a DataModule for WiC:\n",
    "(sentences1), tensor of indices1, \n",
    "'''"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}