{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from minicons import cwe\n",
    "\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "import csv\n",
    "\n",
    "from approximation_model import NonLinearApproximator\n",
    "\n",
    "from paths import auth1_path\n",
    "\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert = cwe.CWE('bert-base-uncased', 'cuda:1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_wic(file = \"train\"):\n",
    "    row = [x.strip().split(\"\\t\") for x in open(f\"../data/WiC_dataset/{file}/{file}.data.txt\", \"r\").readlines()]\n",
    "    if not file == \"test\":\n",
    "        gold = [x.strip() for x in open(f\"../data/WiC_dataset/{file}/{file}.gold.txt\", \"r\").readlines()]\n",
    "    dataset = []\n",
    "    for i, data in enumerate(row):\n",
    "        word, pos, idx, sentence1, sentence2 = data\n",
    "        idx1, idx2 = idx.split('-')\n",
    "        idx1, idx2 = int(idx1), int(idx2)\n",
    "        \n",
    "        context1 = [sentence1, idx1]\n",
    "        context2 = [sentence2, idx2]\n",
    "        \n",
    "        if not file == \"test\":\n",
    "            label = gold[i]\n",
    "            dataset.append((context1, context2, pos, label))\n",
    "        else:\n",
    "            dataset.append((context1, context2, pos))\n",
    "            \n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "validation = load_wic('dev')\n",
    "val_dl = DataLoader(validation, num_workers = 4, batch_size = 128)\n",
    "\n",
    "test = load_wic('test')\n",
    "test_dl = DataLoader(test, num_workers = 4, batch_size = 128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "label2id = {\n",
    "            'T': 1,\n",
    "            'F': 0\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def threshold_classifier(x1, x2, threshold):\n",
    "    cosines = torch.cosine_similarity(x1, x2)\n",
    "    y_hat = [1 if c >= threshold else 0 for c in cosines.tolist()]\n",
    "    \n",
    "    return y_hat\n",
    "\n",
    "def accuracy(y_hat, y):\n",
    "    y_hat = torch.tensor(y_hat)\n",
    "    y = torch.tensor(y)\n",
    "    return (y_hat == y).float().mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "threshold_space = torch.linspace(0.00, 1.00, steps = 50).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 0:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [03:10<00:00,  3.82s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 1:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [03:17<00:00,  3.94s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 2:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [03:17<00:00,  3.94s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 3:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [03:14<00:00,  3.88s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 4:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [03:13<00:00,  3.88s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 5:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [03:13<00:00,  3.87s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 6:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [03:03<00:00,  3.67s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 7:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:01<00:00,  1.23s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 8:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:00<00:00,  1.22s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 9:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:00<00:00,  1.21s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 10:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:01<00:00,  1.23s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 11:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:00<00:00,  1.21s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 12:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:01<00:00,  1.23s/it]\n"
     ]
    }
   ],
   "source": [
    "layer_stats = []\n",
    "\n",
    "for layer in range(bert.layers+1):\n",
    "    best = 0\n",
    "    best_threshold = 0\n",
    "    print(f\"Computing stats on WiC dev set for Layer {layer}:\")\n",
    "    for threshold in tqdm(threshold_space):\n",
    "        target = []\n",
    "        predicted = []\n",
    "        for batch in val_dl:\n",
    "            context1, context2, pos, labels = batch\n",
    "            context1, context2 = [list(zip(*x)) for x in [context1, context2]]\n",
    "            context1 = [(c, [i.item(), i.item()+1]) for c, i in context1]\n",
    "            context2 = [(c, [i.item(), i.item()+1]) for c, i in context2]\n",
    "\n",
    "            labels = list(map(lambda x: label2id[x], labels))\n",
    "\n",
    "            c1 = bert.extract_representation(context1, layer)\n",
    "            c2 = bert.extract_representation(context2, layer)\n",
    "\n",
    "            predicted.extend(threshold_classifier(c1, c2, threshold))\n",
    "            target.extend(labels)\n",
    "        acc = accuracy(predicted, target)\n",
    "        if acc > best:\n",
    "            best = acc\n",
    "            best_threshold = threshold\n",
    "    layer_stats.append((layer, best, best_threshold))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, 0.5579937100410461, 0.9591836929321289),\n",
       " (1, 0.5783699154853821, 0.8775510787963867),\n",
       " (2, 0.6081504821777344, 0.8163265585899353),\n",
       " (3, 0.6191222667694092, 0.7755101919174194),\n",
       " (4, 0.6394984126091003, 0.7346938848495483),\n",
       " (5, 0.6551724076271057, 0.7142857313156128),\n",
       " (6, 0.6645768284797668, 0.6734693646430969),\n",
       " (7, 0.6661441922187805, 0.6734693646430969),\n",
       " (8, 0.6724137663841248, 0.6326530575752258),\n",
       " (9, 0.6755486130714417, 0.6122449040412903),\n",
       " (10, 0.6786834001541138, 0.6326530575752258),\n",
       " (11, 0.6786834001541138, 0.6326530575752258),\n",
       " (12, 0.6755486130714417, 0.5714285969734192)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer_stats\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 0:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 0: Accuracy = 0.5501567125320435, Threshold = 0.9183673858642578\n",
      "Computing stats on WiC dev set for Layer 1:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:58<00:00,  1.17s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 1: Accuracy = 0.5721003413200378, Threshold = 0.8367347121238708\n",
      "Computing stats on WiC dev set for Layer 2:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:58<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 2: Accuracy = 0.5987460613250732, Threshold = 0.795918345451355\n",
      "Computing stats on WiC dev set for Layer 3:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:58<00:00,  1.17s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 3: Accuracy = 0.6316614151000977, Threshold = 0.795918345451355\n",
      "Computing stats on WiC dev set for Layer 4:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 4: Accuracy = 0.6347962617874146, Threshold = 0.7346938848495483\n",
      "Computing stats on WiC dev set for Layer 5:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 5: Accuracy = 0.6520376205444336, Threshold = 0.6530612111091614\n",
      "Computing stats on WiC dev set for Layer 6:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 6: Accuracy = 0.6567398309707642, Threshold = 0.6530612111091614\n",
      "Computing stats on WiC dev set for Layer 7:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.19s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 7: Accuracy = 0.6708464026451111, Threshold = 0.5918367505073547\n",
      "Computing stats on WiC dev set for Layer 8:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:01<00:00,  1.23s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 8: Accuracy = 0.6724137663841248, Threshold = 0.5306122303009033\n",
      "Computing stats on WiC dev set for Layer 9:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:01<00:00,  1.22s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 9: Accuracy = 0.7037617564201355, Threshold = 0.5306122303009033\n",
      "Computing stats on WiC dev set for Layer 10:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:01<00:00,  1.24s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 10: Accuracy = 0.7006269693374634, Threshold = 0.5306122303009033\n",
      "Computing stats on WiC dev set for Layer 11:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:00<00:00,  1.22s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 11: Accuracy = 0.6959247589111328, Threshold = 0.44897958636283875\n",
      "Computing stats on WiC dev set for Layer 12:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:00<00:00,  1.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 12: Accuracy = 0.6771159768104553, Threshold = 0.44897958636283875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "laser_layer_stats = [] \n",
    "\n",
    "for layer in range(bert.layers+1):\n",
    "    laser = NonLinearApproximator.load_from_checkpoint(f'{auth1_path}/makesense_logs/bert/{layer}/version_laser_2048_2_0-0001.ckpt')\n",
    "    laser.eval()\n",
    "    best = 0\n",
    "    best_threshold = 0\n",
    "    print(f\"Computing stats on WiC dev set for Layer {layer}:\")\n",
    "    for threshold in tqdm(threshold_space):\n",
    "        target = []\n",
    "        predicted = []\n",
    "        for batch in val_dl:\n",
    "            context1, context2, pos, labels = batch\n",
    "            context1, context2 = [list(zip(*x)) for x in [context1, context2]]\n",
    "            context1 = [(c, [i.item(), i.item()+1]) for c, i in context1]\n",
    "            context2 = [(c, [i.item(), i.item()+1]) for c, i in context2]\n",
    "\n",
    "            labels = list(map(lambda x: label2id[x], labels))\n",
    "\n",
    "            c1 = bert.extract_representation(context1, layer)\n",
    "            c2 = bert.extract_representation(context2, layer)\n",
    "            \n",
    "            c1 = laser(c1)\n",
    "            c2 = laser(c2)\n",
    "\n",
    "            predicted.extend(threshold_classifier(c1, c2, threshold))\n",
    "            target.extend(labels)\n",
    "        acc = accuracy(predicted, target)\n",
    "        if acc > best:\n",
    "            best = acc\n",
    "            best_threshold = threshold\n",
    "    print(f\"Numbers for layer {layer}: Accuracy = {best}, Threshold = {best_threshold}\")\n",
    "    laser_layer_stats.append((layer, best, best_threshold))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "## predict on test data\n",
    "predictions = []\n",
    "\n",
    "laser = NonLinearApproximator.load_from_checkpoint(f'{auth1_path}/makesense_logs/bert/10/version_laser_2048_2_0-0001.ckpt')\n",
    "laser.eval()\n",
    "\n",
    "# threshold = 0.5306122303009033\n",
    "threshold = 0.449"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "for batch in test_dl:\n",
    "    context1, context2, pos = batch\n",
    "    context1, context2 = [list(zip(*x)) for x in [context1, context2]]\n",
    "    context1 = [(c, [i.item(), i.item()+1]) for c, i in context1]\n",
    "    context2 = [(c, [i.item(), i.item()+1]) for c, i in context2]\n",
    "\n",
    "    c1 = bert.extract_representation(context1, 9)\n",
    "    c2 = bert.extract_representation(context2, 9)\n",
    "\n",
    "    c1 = laser(c1)\n",
    "    c2 = laser(c2)\n",
    "\n",
    "    predictions.extend(threshold_classifier(c1, c2, threshold))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_labels = ['T' if p == 1 else 'F' for p in predictions]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"output.txt\", \"w\") as f:\n",
    "    for label in predicted_labels:\n",
    "        f.write(f\"{label}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing stats on WiC dev set for Layer 0:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:00<00:00,  1.21s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 0: Accuracy = 0.5611284971237183, Threshold = 0.9387755393981934\n",
      "Computing stats on WiC dev set for Layer 1:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:01<00:00,  1.22s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 1: Accuracy = 0.5799372792243958, Threshold = 0.8979592323303223\n",
      "Computing stats on WiC dev set for Layer 2:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [01:01<00:00,  1.22s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 2: Accuracy = 0.6065830588340759, Threshold = 0.8571428656578064\n",
      "Computing stats on WiC dev set for Layer 3:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.19s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 3: Accuracy = 0.6238244771957397, Threshold = 0.8367347121238708\n",
      "Computing stats on WiC dev set for Layer 4:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.19s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 4: Accuracy = 0.6332288384437561, Threshold = 0.795918345451355\n",
      "Computing stats on WiC dev set for Layer 5:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 5: Accuracy = 0.6630094051361084, Threshold = 0.795918345451355\n",
      "Computing stats on WiC dev set for Layer 6:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 6: Accuracy = 0.6520376205444336, Threshold = 0.7755101919174194\n",
      "Computing stats on WiC dev set for Layer 7:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 7: Accuracy = 0.6645768284797668, Threshold = 0.7346938848495483\n",
      "Computing stats on WiC dev set for Layer 8:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 8: Accuracy = 0.6786834001541138, Threshold = 0.7551020383834839\n",
      "Computing stats on WiC dev set for Layer 9:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 9: Accuracy = 0.6833855509757996, Threshold = 0.7142857313156128\n",
      "Computing stats on WiC dev set for Layer 10:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:59<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 10: Accuracy = 0.6833855509757996, Threshold = 0.7551020383834839\n",
      "Computing stats on WiC dev set for Layer 11:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:58<00:00,  1.18s/it]\n",
      "  0%|          | 0/50 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 11: Accuracy = 0.6739811897277832, Threshold = 0.7755101919174194\n",
      "Computing stats on WiC dev set for Layer 12:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50/50 [00:58<00:00,  1.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numbers for layer 12: Accuracy = 0.653605043888092, Threshold = 0.6938775181770325\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "## add ser.\n",
    "ser_layer_stats = [] \n",
    "\n",
    "for layer in range(bert.layers+1):\n",
    "    laser = NonLinearApproximator.load_from_checkpoint(f'{auth1_path}/makesense_logs/bert/{layer}/version_ser_2048_2_0-0001-v1.ckpt')\n",
    "    laser.eval()\n",
    "    best = 0\n",
    "    best_threshold = 0\n",
    "    print(f\"Computing stats on WiC dev set for Layer {layer}:\")\n",
    "    for threshold in tqdm(threshold_space):\n",
    "        target = []\n",
    "        predicted = []\n",
    "        for batch in val_dl:\n",
    "            context1, context2, pos, labels = batch\n",
    "            context1, context2 = [list(zip(*x)) for x in [context1, context2]]\n",
    "            context1 = [(c, [i.item(), i.item()+1]) for c, i in context1]\n",
    "            context2 = [(c, [i.item(), i.item()+1]) for c, i in context2]\n",
    "\n",
    "            labels = list(map(lambda x: label2id[x], labels))\n",
    "\n",
    "            c1 = bert.extract_representation(context1, layer)\n",
    "            c2 = bert.extract_representation(context2, layer)\n",
    "            \n",
    "            c1 = laser(c1)\n",
    "            c2 = laser(c2)\n",
    "\n",
    "            predicted.extend(threshold_classifier(c1, c2, threshold))\n",
    "            target.extend(labels)\n",
    "        acc = accuracy(predicted, target)\n",
    "        if acc > best:\n",
    "            best = acc\n",
    "            best_threshold = threshold\n",
    "    print(f\"Numbers for layer {layer}: Accuracy = {best}, Threshold = {best_threshold}\")\n",
    "    ser_layer_stats.append((layer, best, best_threshold))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../data/original_layer_threshold.csv\", \"w\") as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow(['model', 'class', 'layer', 'accuracy', 'threshold'])\n",
    "    for result in layer_stats:\n",
    "        layer, accuracy, threshold = result\n",
    "        writer.writerow(['bert-base-uncased', 'original', layer, round(threshold, 4), round(accuracy, 4)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../data/laser_layer_threshold.csv\", \"w\") as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow(['model', 'class', 'layer', 'accuracy', 'threshold'])\n",
    "    for result in laser_layer_stats:\n",
    "        layer, accuracy, threshold = result\n",
    "        writer.writerow(['bert-base-uncased', 'laser', layer, round(threshold, 4), round(accuracy, 4)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"../data/ser_layer_threshold.csv\", \"w\") as f:\n",
    "    writer = csv.writer(f)\n",
    "    writer.writerow(['model', 'class', 'layer', 'accuracy', 'threshold'])\n",
    "    for result in ser_layer_stats:\n",
    "        layer, accuracy, threshold = result\n",
    "        writer.writerow(['bert-base-uncased', 'ser', layer, round(threshold, 4), round(accuracy, 4)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, 0.5501567125320435, 0.9183673858642578),\n",
       " (1, 0.5721003413200378, 0.8367347121238708),\n",
       " (2, 0.5987460613250732, 0.795918345451355),\n",
       " (3, 0.6316614151000977, 0.795918345451355),\n",
       " (4, 0.6347962617874146, 0.7346938848495483),\n",
       " (5, 0.6520376205444336, 0.6530612111091614),\n",
       " (6, 0.6567398309707642, 0.6530612111091614),\n",
       " (7, 0.6708464026451111, 0.5918367505073547),\n",
       " (8, 0.6724137663841248, 0.5306122303009033),\n",
       " (9, 0.7037617564201355, 0.5306122303009033),\n",
       " (10, 0.7006269693374634, 0.5306122303009033),\n",
       " (11, 0.6959247589111328, 0.44897958636283875),\n",
       " (12, 0.6771159768104553, 0.44897958636283875)]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "laser_layer_stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = load_wic(\"test\")\n",
    "test_dl = DataLoader(test, num_workers = 4, batch_size = 128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 11/11 [00:02<00:00,  4.20it/s]\n"
     ]
    }
   ],
   "source": [
    "laser = NonLinearApproximator.load_from_checkpoint(f'{auth1_path}/makesense_logs/bert/10/version_2048_2_0-0001.ckpt')\n",
    "predicted = []\n",
    "for batch in tqdm(test_dl):\n",
    "    context1, context2, pos = batch\n",
    "    context1, context2 = [list(zip(*x)) for x in [context1, context2]]\n",
    "    context1 = [(c, [i.item(), i.item()+1]) for c, i in context1]\n",
    "    context2 = [(c, [i.item(), i.item()+1]) for c, i in context2]\n",
    "\n",
    "    c1 = bert.extract_representation(context1, 10)\n",
    "    c2 = bert.extract_representation(context2, 10)\n",
    "\n",
    "    c1 = laser(c1)\n",
    "    c2 = laser(c2)\n",
    "\n",
    "    predicted.extend(threshold_classifier(c1, c2, 0.44897958636283875))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1400"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(predicted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_test_labels = ['T' if p == 1 else 'F' for p in predicted]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"layer_10_cosine_threshold.txt\", \"w\") as f:\n",
    "    f.write(\"\\n\".join(predicted_test_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}