{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jsonlines\n",
    "import json\n",
    "from collections import defaultdict\n",
    "\n",
    "from minicons import cwe\n",
    "\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.model_selection import GridSearchCV, PredefinedSplit\n",
    "\n",
    "from sklearn.metrics import classification_report, accuracy_score, f1_score, matthews_corrcoef\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from hypopt import GridSearch\n",
    "\n",
    "import os\n",
    "\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the second GPU (the device originally hard-coded here), but fall back so\n",
    "# the notebook still runs on single-GPU and CPU-only machines.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    DEVICE = 'cuda:1'\n",
    "elif torch.cuda.is_available():\n",
    "    DEVICE = 'cuda:0'\n",
    "else:\n",
    "    DEVICE = 'cpu'\n",
    "\n",
    "bert = cwe.CWE('bert-base-uncased', DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def punctuate(line):\n",
    "    \"\"\"Ensure `line` ends with a sentence-final mark that is set off by a space.\n",
    "\n",
    "    - ends in '.', '?' or '!': a space is inserted before that final mark\n",
    "    - ends in a single or double quote: ' .' is inserted just before the quote\n",
    "    - anything else: ' .' is appended\n",
    "    - empty string: returned unchanged (indexing line[-1] would raise IndexError)\n",
    "    \"\"\"\n",
    "    if not line:\n",
    "        return line\n",
    "\n",
    "    last = line[-1]\n",
    "    if last in ('.', '?', '!'):\n",
    "        return line[:-1] + ' ' + last\n",
    "    if last in (\"'\", '\"'):\n",
    "        return line[:-1] + ' .' + last\n",
    "    return line + ' .'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Edit distance = number of operations (deletion, insertion, replace)\n",
    "## needed to convert one string to the other\n",
    "\n",
    "def edit_distance(word1, word2):\n",
    "    \"\"\"Levenshtein distance between `word1` and `word2`.\n",
    "\n",
    "    Counts the fewest single-character insertions, deletions and replacements\n",
    "    that turn one string into the other. Only the previous row of the dynamic\n",
    "    programming table is kept while the next one is filled in.\n",
    "    \"\"\"\n",
    "    # Row 0: turning the empty prefix of word1 into word2[:j] costs j insertions.\n",
    "    prev_row = list(range(len(word2) + 1))\n",
    "\n",
    "    for row, ch1 in enumerate(word1, start=1):\n",
    "        # Column 0: turning word1[:row] into the empty string costs `row` deletions.\n",
    "        curr_row = [row]\n",
    "        for col, ch2 in enumerate(word2, start=1):\n",
    "            if ch1 == ch2:\n",
    "                cost = prev_row[col - 1]\n",
    "            else:\n",
    "                cost = 1 + min(curr_row[col - 1], prev_row[col], prev_row[col - 1])\n",
    "            curr_row.append(cost)\n",
    "        prev_row = curr_row\n",
    "\n",
    "    return prev_row[-1]\n",
    "\n",
    "def argmin(lst):\n",
    "    \"\"\"Return the position of the smallest element of `lst` (the first one on ties).\"\"\"\n",
    "    return min(range(len(lst)), key=lst.__getitem__)\n",
    "\n",
    "def find_index(context, word):\n",
    "    \"\"\"Find the whitespace-separated token of `context` closest to `word`.\n",
    "\n",
    "    Closeness is measured with `edit_distance`; the earliest token wins a tie.\n",
    "    Returns `(position, token)`.\n",
    "    \"\"\"\n",
    "    position, token = min(\n",
    "        enumerate(context.split()),\n",
    "        key=lambda pair: edit_distance(pair[1], word),\n",
    "    )\n",
    "    return position, token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def directionaldataset(whic_data = 'train'):\n",
    "    \"\"\"Collect the negative WHiC rows whose reversed pair is a positive row.\n",
    "\n",
    "    Reads `../data/whic/{whic_data}.tsv`, where each line carries five\n",
    "    tab-separated fields: context1, word1, context2, word2, label. A row labelled\n",
    "    '0' is kept only when the same row with its two sides swapped occurs with\n",
    "    label '1'.\n",
    "\n",
    "    Returns a list of `[[context1, idx1], [context2, idx2], label]` entries; each\n",
    "    context is passed through `punctuate` and each idx comes from `find_index`.\n",
    "    \"\"\"\n",
    "    with open(f\"../data/whic/{whic_data}.tsv\", \"r\") as f:\n",
    "        rows = [line.strip().split(\"\\t\") for line in f]\n",
    "\n",
    "    # Hash the positive rows once, so the reversed-pair check below is a constant-time\n",
    "    # lookup rather than a scan over every positive row for each negative row.\n",
    "    positive_samples = {\n",
    "        (c1, w1, c2, w2) for c1, w1, c2, w2, label in rows if label == '1'\n",
    "    }\n",
    "\n",
    "    whic = []\n",
    "    for c1, w1, c2, w2, label in rows:\n",
    "        if label != '0' or (c2, w2, c1, w1) not in positive_samples:\n",
    "            continue\n",
    "\n",
    "        # The label stays '0'; relabelling these rows as '-1' was left disabled.\n",
    "        idx1 = find_index(c1, w1)[0]\n",
    "        idx2 = find_index(c2, w2)[0]\n",
    "        whic.append([[punctuate(c1), idx1], [punctuate(c2), idx2], label])\n",
    "\n",
    "    return whic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test the loader; show a short preview instead of echoing the whole list.\n",
    "directionaldataset('test')[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_whic(dataset = 'train'):\n",
    "    \"\"\"Read one WHiC split into probe-ready examples.\n",
    "\n",
    "    Every line of `../data/whic/{dataset}.tsv` holds five tab-separated fields\n",
    "    (context1, word1, context2, word2, label) and becomes\n",
    "    `[[punctuate(context1), idx1], [punctuate(context2), idx2], label]`, where the\n",
    "    indices are the token positions reported by `find_index` on the raw contexts.\n",
    "    \"\"\"\n",
    "    examples = []\n",
    "    with open(f\"../data/whic/{dataset}.tsv\", \"r\") as f:\n",
    "        for raw in f:\n",
    "            left_ctx, left_word, right_ctx, right_word, label = raw.strip().split(\"\\t\")\n",
    "\n",
    "            # Positions are looked up before the contexts are punctuated.\n",
    "            left_idx = find_index(left_ctx, left_word)[0]\n",
    "            right_idx = find_index(right_ctx, right_word)[0]\n",
    "\n",
    "            examples.append([\n",
    "                [punctuate(left_ctx), left_idx],\n",
    "                [punctuate(right_ctx), right_idx],\n",
    "                label,\n",
    "            ])\n",
    "\n",
    "    return examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pairwise_direction(dataset = 'train'):\n",
    "    \"\"\"Build direction-swapped pairs from the positive WHiC rows.\n",
    "\n",
    "    For every row labelled '1' in `../data/whic/{dataset}.tsv`, emit the row as it\n",
    "    is (`[context1, context2, '1']`) and its mirror image\n",
    "    (`[context2, context1, '0']`). Each context is a\n",
    "    `[punctuated_text, word_index]` list.\n",
    "\n",
    "    Returns `(positive_samples, negative_samples)`, aligned index by index.\n",
    "    \"\"\"\n",
    "    positive_samples, negative_samples = [], []\n",
    "\n",
    "    with open(f\"../data/whic/{dataset}.tsv\", \"r\") as f:\n",
    "        for line in f:\n",
    "            c1, w1, c2, w2, label = line.strip().split(\"\\t\")\n",
    "\n",
    "            # Every row is parsed and indexed, including those that are not kept.\n",
    "            idx1 = find_index(c1, w1)[0]\n",
    "            idx2 = find_index(c2, w2)[0]\n",
    "            forward = [punctuate(c1), idx1]\n",
    "            backward = [punctuate(c2), idx2]\n",
    "\n",
    "            if label != '1':\n",
    "                continue\n",
    "\n",
    "            positive_samples.append([forward, backward, label])\n",
    "            negative_samples.append([backward, forward, '0'])\n",
    "\n",
    "    return positive_samples, negative_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direction-swapped pairs from the dev split: `p` holds the rows labelled '1',\n",
    "# `n` the same rows with their two contexts swapped and label '0'.\n",
    "p, n = pairwise_direction('dev')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combined number of examples across both lists.\n",
    "len(p + n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pairwise_context(dataset = 'train'):\n",
    "    \"\"\"Pair positive and negative WHiC rows that share their first side.\n",
    "\n",
    "    A positive row (label '1') is matched with every negative row (label '0') that\n",
    "    has the same first word and the same first context (punctuated text and word\n",
    "    index) but a different second word.\n",
    "\n",
    "    Returns `(pos, neg, word_idx)`:\n",
    "      - `pos[k]` is `[context1, positive_context2, '1']`\n",
    "      - `neg[k]` is `[context1, negative_context2, '0']`\n",
    "      - `word_idx` maps each first word to the list of positions k it contributed\n",
    "    \"\"\"\n",
    "    positive_samples = []\n",
    "    # Negatives grouped by their first side, so each positive row looks up its\n",
    "    # partners directly rather than scanning every negative row. File order is kept\n",
    "    # inside each group, so the output order matches a nested scan.\n",
    "    negatives_by_anchor = defaultdict(list)\n",
    "\n",
    "    with open(f\"../data/whic/{dataset}.tsv\", \"r\") as f:\n",
    "        for line in f:\n",
    "            c1, w1, c2, w2, label = line.strip().split(\"\\t\")\n",
    "\n",
    "            idx1 = find_index(c1, w1)[0]\n",
    "            idx2 = find_index(c2, w2)[0]\n",
    "\n",
    "            context1 = [punctuate(c1), idx1]\n",
    "            context2 = [punctuate(c2), idx2]\n",
    "\n",
    "            if label == '1':\n",
    "                positive_samples.append((context1, w1, context2, w2, label))\n",
    "            elif label == '0':\n",
    "                anchor = (w1, context1[0], context1[1])\n",
    "                negatives_by_anchor[anchor].append((context2, w2, label))\n",
    "\n",
    "    pos = []\n",
    "    neg = []\n",
    "    word_idx = defaultdict(list)\n",
    "\n",
    "    for cp1, wp1, cp2, wp2, lp in positive_samples:\n",
    "        # .get() avoids creating empty groups for positives without a partner.\n",
    "        for cn2, wn2, ln in negatives_by_anchor.get((wp1, cp1[0], cp1[1]), ()):\n",
    "            if wp2 == wn2:\n",
    "                continue\n",
    "            word_idx[wp1].append(len(pos))\n",
    "            pos.append([cp1, cp2, lp])\n",
    "            neg.append([cp1, cn2, ln])\n",
    "\n",
    "    return pos, neg, word_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def whic_set(dataset, layer = 0, model = bert):\n",
    "    \"\"\"Encode WHiC examples as probe features.\n",
    "\n",
    "    `dataset` holds `[[context1, idx1], [context2, idx2], label]` entries. For each\n",
    "    batch, `model.extract_representation` is queried at `layer` with the span\n",
    "    `[idx, idx + 1]` in each context, and the two results are concatenated along\n",
    "    dimension 1.\n",
    "\n",
    "    Returns `(vectors, labels)` as NumPy arrays; labels are converted with `int`.\n",
    "    \"\"\"\n",
    "    def as_queries(collated):\n",
    "        # A collated context unpacks into (texts, indices); pair them up again per\n",
    "        # example and expand each index i into the span [i, i + 1].\n",
    "        return [(text, [pos.item(), pos.item() + 1]) for text, pos in zip(*collated)]\n",
    "\n",
    "    loader = DataLoader(dataset, num_workers = 4, batch_size = 128)\n",
    "\n",
    "    features, targets = [], []\n",
    "    for context1, context2, label in loader:\n",
    "        queries1 = as_queries(context1)\n",
    "        queries2 = as_queries(context2)\n",
    "        batch_targets = [int(x) for x in label]\n",
    "\n",
    "        c1 = model.extract_representation(queries1, layer)\n",
    "        c2 = model.extract_representation(queries2, layer)\n",
    "\n",
    "        targets.extend(batch_targets)\n",
    "        # Extending with a 2-D tensor adds one row per example.\n",
    "        features.extend(torch.cat((c1, c2), 1))\n",
    "\n",
    "    labels = torch.tensor(targets).numpy()\n",
    "    vectors = torch.stack(features).numpy()\n",
    "\n",
    "    return vectors, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model layer whose representations feed the probe (passed to `whic_set` as `layer`).\n",
    "LAYER = 9\n",
    "\n",
    "# Parse the three WHiC splits into [[context, index], [context, index], label] rows.\n",
    "train = load_whic('train')\n",
    "val = load_whic('dev')\n",
    "test = load_whic('test')\n",
    "\n",
    "# Encode each split as (feature matrix, label vector) using layer LAYER.\n",
    "train_X, train_y = whic_set(train, layer = LAYER)\n",
    "val_X, val_y = whic_set(val, layer = LAYER)\n",
    "test_X, test_y = whic_set(test, layer = LAYER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack train and dev so the grid search can be handed a single (X, y) pair.\n",
    "X = np.concatenate([train_X, val_X], axis = 0)\n",
    "y = np.concatenate([train_y, val_y], axis = 0)\n",
    "\n",
    "# One fold id per stacked row: -1 for every training row, 0 for every dev row.\n",
    "cv_vals = [-1] * len(train_X) + [0] * (len(X) - len(train_X))\n",
    "\n",
    "pds = PredefinedSplit(test_fold=cv_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base MLP for the grid search; `hidden_layer_sizes` and `alpha` are supplied by\n",
    "# `param_grid`. The random seed is fixed.\n",
    "mlp_probe = MLPClassifier(\n",
    "    random_state=1234, \n",
    "    max_iter=100, \n",
    "    n_iter_no_change=5,\n",
    "    learning_rate_init = 0.001,\n",
    "    early_stopping=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search space: three hidden-layer widths crossed with four `alpha` values.\n",
    "param_grid = {\n",
    "  'hidden_layer_sizes': [(100,), (128,), (256,)],\n",
    "  'alpha': [1, 0.1, 0.01, 0.001]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search over `param_grid`, split according to `pds` and scored by negative log-loss.\n",
    "clf = GridSearchCV(mlp_probe, cv = pds, param_grid=param_grid, scoring = 'neg_log_loss', verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the search on the stacked train + dev arrays.\n",
    "clf.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameter combination selected by the search.\n",
    "clf.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score of the selected combination under the `neg_log_loss` scoring.\n",
    "clf.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Later cells call `best.predict(...)`, so `best` has to be bound on a fresh kernel.\n",
    "# GridSearchCV refits the winning configuration by default (refit=True); reuse that\n",
    "# fitted estimator so `best.predict` agrees with `clf.predict`.\n",
    "best = clf.best_estimator_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(classification_report(test_y, clf.predict(test_X)))\n",
    "\n",
    "# Reference output from an earlier run, kept as comments rather than a bare string\n",
    "# literal so that it is not echoed as this cell's output:\n",
    "#\n",
    "#               precision    recall  f1-score   support\n",
    "#\n",
    "#            0       0.83      0.91      0.87      4098\n",
    "#            1       0.57      0.39      0.46      1263\n",
    "#\n",
    "#     accuracy                           0.79      5361\n",
    "#    macro avg       0.70      0.65      0.66      5361\n",
    "# weighted avg       0.77      0.79      0.77      5361"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Direction probe on the test split: element 0 holds the positive rows, element 1\n",
    "# the same rows with their contexts swapped (label '0').\n",
    "test_directional = pairwise_direction('test')\n",
    "dir_p_X, dir_p_y = whic_set(test_directional[0], layer = LAYER)\n",
    "dir_n_X, dir_n_y = whic_set(test_directional[1], layer = LAYER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy on the original-direction rows (all labelled 1).\n",
    "print(accuracy_score(dir_p_y, clf.predict(dir_p_X)))\n",
    "# 12 = 31.27\n",
    "# NOTE(review): presumably a recorded result for layer 12 from an earlier run - confirm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of swapped (label-0) pairs that the probe predicts as 1.\n",
    "clf.predict(dir_n_X).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise direction accuracy: a pair counts as correct only when the original order\n",
    "# is predicted 1 and the swapped order is predicted 0.\n",
    "((clf.predict(dir_p_X) == 1) * (clf.predict(dir_n_X) == 0)).mean()\n",
    "# 12 - 0.2914\n",
    "# NOTE(review): presumably a recorded result for layer 12 from an earlier run - confirm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Context probe on the test split: element 0 = positive pairs, element 1 = matched\n",
    "# negative pairs, element 2 = mapping from first word to row positions.\n",
    "test_context = pairwise_context('test')\n",
    "context_p_X, context_p_y = whic_set(test_context[0], layer = LAYER)\n",
    "context_n_X, context_n_y = whic_set(test_context[1], layer = LAYER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-word pairwise accuracy: a pair is correct only when its positive example is\n",
    "# predicted 1 and its matched negative example is predicted 0.\n",
    "results = defaultdict(float)\n",
    "for word, idx in test_context[2].items():\n",
    "    positive_hit = best.predict(context_p_X[idx]) == 1\n",
    "    negative_hit = best.predict(context_n_X[idx]) == 0\n",
    "\n",
    "    results[word] = (positive_hit * negative_hit).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Position (in insertion order) of the word with the highest pairwise accuracy.\n",
    "np.argmax(list(results.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def context_sensitivity(classifier, layer = None):\n",
    "    \"\"\"Pairwise context-sensitivity accuracy of `classifier` on the WHiC test split.\n",
    "\n",
    "    A positive/negative pair from `pairwise_context('test')` counts as correct only\n",
    "    when the positive example is predicted 1 and its matched negative is predicted 0.\n",
    "\n",
    "    Parameters:\n",
    "        classifier: fitted estimator exposing `predict`.\n",
    "        layer: layer passed to `whic_set`; defaults to the notebook-level `LAYER`,\n",
    "            so the features match the ones used elsewhere in the notebook.\n",
    "\n",
    "    Returns `(overall, per_word)`: the mean accuracy over all pairs, and a dict\n",
    "    mapping each first word to the mean accuracy over its own pairs.\n",
    "    \"\"\"\n",
    "    if layer is None:\n",
    "        layer = LAYER\n",
    "\n",
    "    positive, negative, lexicon = pairwise_context('test')\n",
    "    context_p_X, _ = whic_set(positive, layer = layer)\n",
    "    context_n_X, _ = whic_set(negative, layer = layer)\n",
    "\n",
    "    correct = (classifier.predict(context_p_X) == 1) & (classifier.predict(context_n_X) == 0)\n",
    "\n",
    "    per_word = {word: correct[idx].mean() for word, idx in lexicon.items()}\n",
    "    return correct.mean(), per_word"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Word with the highest pairwise accuracy (the first one on ties), looked up from\n",
    "# `results` directly rather than through a hard-coded position such as 162.\n",
    "max(results, key=results.get)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of negative context pairs predicted as 0.\n",
    "(best.predict(context_n_X) == 0).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall pairwise accuracy on the context probe: positive predicted 1 AND its\n",
    "# matched negative predicted 0.\n",
    "((best.predict(context_p_X) == 1) * (best.predict(context_n_X) == 0)).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of positive context pairs predicted as 0.\n",
    "(best.predict(context_p_X) == 0).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of positive pairs in the context probe.\n",
    "len(test_context[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Spot-check the predictions for the first three negative pairs.\n",
    "best.predict(context_n_X[:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List every positive pair together with its row position.\n",
    "list(enumerate(test_context[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_context[0][87:95]\n",
    "# Pairwise accuracy restricted to rows 87-94 of the context probe.\n",
    "((best.predict(context_p_X[87:95]) == 1) * (best.predict(context_n_X[87:95]) == 0)).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the tokens the model's tokenizer produces for a sentence with a detached apostrophe.\n",
    "# NOTE(review): assumes `encode_text(...)[0][0]` is the id sequence of the sentence - confirm.\n",
    "bert.tokenizer.convert_ids_to_tokens(bert.encode_text(\"three quarters of the earth ' s surface is covered by water .\")[0][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline for comparison: F1 obtained when class 0 is predicted for every test pair.\n",
    "f1_score(test_y, np.zeros_like(test_y))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}