{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.corpus import wordnet as wn\n",
    "\n",
    "from whic_utils import pairwise_direction, edit_distance, find_index, load_whic\n",
    "\n",
    "from minicons.utils import argmin\n",
    "\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = load_whic('dev')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = list(zip(*train))[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({'0': 1421, '1': 283})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Counter(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = wn.synsets('chess')[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Synset('game.n.01')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wn.synsets('game')[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test.shortest_path_distance(wn.synsets('games')[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "p, n = pairwise_direction('train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3693"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shortest_distance(word1, word2):\n",
    "    synset1 = wn.synsets(word1)\n",
    "    synset2 = wn.synsets(word2)\n",
    "    \n",
    "    distances = []\n",
    "    \n",
    "    for s1 in synset1:\n",
    "        for s2 in synset2:\n",
    "            distance = s1.shortest_path_distance(s2)\n",
    "            if distance is not None and distance != 0:\n",
    "                distances.append(distance)\n",
    "    \n",
    "    sd = distances[argmin(distances)]\n",
    "            \n",
    "    return sd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "dists = []\n",
    "for i, entry in enumerate(p):\n",
    "    (c1, idx1), (c2, idx2), label = entry\n",
    "    word1 = c1.split()[idx1]\n",
    "    word2 = c2.split()[idx2]\n",
    "    \n",
    "    try:\n",
    "        d = shortest_distance(word1, word2)\n",
    "        if d == 0:\n",
    "            print(word1, word2, i, d)\n",
    "        dists.append(d)\n",
    "    except:\n",
    "        print(word1, word2, i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['they moved their arms and legs and bodies .', 0],\n",
       " ['they analyzed the river into three parts .', 6],\n",
       " '1']"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p[1104]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({1: 2053, 2: 1634, 3: 3})"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Counter(dists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['they had three children .', 1],\n",
       " ['she was the mother of many offspring .', 6],\n",
       " '1']"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p[126]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset = 'test'\n",
    "# whic = []\n",
    "# with open(f\"../data/whic/{dataset}.tsv\", \"r\") as f:\n",
    "#     for line in f:\n",
    "#         c1, w1, c2, w2, label = line.strip().split(\"\\t\")\n",
    "        \n",
    "#         if w1 == 'child' and 'children' in c1:\n",
    "#             w1 = 'children'\n",
    "#         if w2 == 'child' and 'children' in c2:\n",
    "#             w2 = 'children'\n",
    "#         if w1 == 'cry' and 'cries' in c1:\n",
    "#             w1 = 'cries'\n",
    "#         if w2 == 'cry' and 'cries' in c2:\n",
    "#             w2 = 'cries'\n",
    "#         if w1 == 'body' and 'bodies' in c1:\n",
    "#             w1 = 'bodies'\n",
    "#         if w2 == 'body' and 'bodies' in c2:\n",
    "#             w2 = 'bodies'    \n",
    "\n",
    "#         idx1, idx2 = [x[0] for x in (find_index(c1, w1), find_index(c2, w2))]\n",
    "        \n",
    "#         if c1.split()[idx1] != w1:\n",
    "#             print(f\"word1 error: {c1.split()[idx1], w1}\")\n",
    "#         if c2.split()[idx2] != w2:\n",
    "#             print(f\"word2 error: {c2.split()[idx2], w2}\")\n",
    "\n",
    "#         context1 = [punctuate(c1), idx1]\n",
    "#         context2 = [punctuate(c2), idx2]\n",
    "\n",
    "#         whic.append([context1, context2, label])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
