{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d437fe6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import inflect\n",
    "import json\n",
    "import nltk\n",
    "import pyinflect\n",
    "import spacy\n",
    "\n",
    "from collections import defaultdict\n",
    "from ordered_set import OrderedSet\n",
    "from nltk.corpus import wordnet as wn\n",
    "from nltk.corpus import wordnet2021 as wn2\n",
    "from wordfreq import word_frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6ee3cfd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# spaCy pipeline loaded from the 'en_core_web_sm' package.\n",
    "nlp = spacy.load('en_core_web_sm')\n",
    "# inflect engine; is_plural() relies on its singular_noun() method.\n",
    "inflector = inflect.engine()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "13f2b508",
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_plural(word):\n",
    "    \"\"\"Return True when inflect can derive a singular form of `word`.\n",
    "\n",
    "    The result of `inflector.singular_noun(word)` is compared against False by identity:\n",
    "    any other return value is taken to mean the word is a plural noun.\n",
    "    \"\"\"\n",
    "    singular_form = inflector.singular_noun(word)\n",
    "    return singular_form is not False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "47217515",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the GQA entity files. It carries no trailing slash because every path\n",
    "# derived from it is built as f\"{DIR}/<file>\"; a trailing slash would double the separator.\n",
    "DIR = \"../data/gqa_entities\"\n",
    "scene_nouns_path = f\"{DIR}/scene_nouns.csv\"\n",
    "question_nouns_path = f\"{DIR}/question_nouns.csv\"\n",
    "overlapping_nouns_path = f\"{DIR}/overlapping_nouns.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "308f6b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_csv(path):\n",
    "    \"\"\"Read a CSV file with a header row into a list of row dicts.\n",
    "\n",
    "    Args:\n",
    "        path: location of the CSV file.\n",
    "\n",
    "    Returns:\n",
    "        A list with one dict per data row, keyed by the header names.\n",
    "    \"\"\"\n",
    "    # newline=\"\" is what the csv module asks for: the reader then handles line endings\n",
    "    # itself, so newlines embedded in quoted fields are not rewritten by text-mode translation.\n",
    "    with open(path, \"r\", newline=\"\") as f:\n",
    "        return list(csv.DictReader(f))\n",
    "\n",
    "def unique_words(lst):\n",
    "    \"\"\"Collect the 'word' field of every entry into an OrderedSet (duplicates collapse).\"\"\"\n",
    "    words = OrderedSet()\n",
    "    for row in lst:\n",
    "        words.add(row['word'])\n",
    "    return words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "372d97d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load each noun list and reduce it to its distinct words.\n",
    "scene_nouns, question_nouns, overlapping_nouns = (\n",
    "    unique_words(read_csv(csv_path))\n",
    "    for csv_path in (scene_nouns_path, question_nouns_path, overlapping_nouns_path)\n",
    ")\n",
    "\n",
    "# Every noun that occurs in the scene list or in the question list.\n",
    "unique_nouns = scene_nouns.union(question_nouns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8e3c5aec",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2041"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(unique_nouns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8623251b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair every noun with the name of its first WordNet synset (\"NF\" when WordNet returns none).\n",
    "senses = []\n",
    "for word in unique_nouns:\n",
    "    # Multi-word nouns are looked up with underscores in place of spaces.\n",
    "    word = word.replace(\" \", \"_\")\n",
    "    # Test for an empty result explicitly instead of wrapping the lookup in a bare `except:`,\n",
    "    # so that real failures (e.g. the WordNet corpus not being installed) surface instead of\n",
    "    # silently turning every noun into \"NF\".\n",
    "    candidate_synsets = wn.synsets(word)\n",
    "    default_sense_name = candidate_synsets[0].name() if candidate_synsets else \"NF\"\n",
    "\n",
    "    senses.append((word, default_sense_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4a66d33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # save senses to dir\n",
    "# with open(f\"{DIR}/noun-senses.csv\", \"w\") as f: \n",
    "#     writer = csv.writer(f)\n",
    "#     writer.writerow([\"noun\", \"sense\"])\n",
    "#     writer.writerows(senses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 394,
   "id": "c7c6d0c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# read annotated csv from gdrive\n",
    "annotated_senses_raw = read_csv(f\"{DIR}/noun-senses-annotated.csv\")\n",
    "annotated_senses, manual, leftover = [], [], []\n",
    "\n",
    "for row in annotated_senses_raw:\n",
    "    noun, default_sense = row['noun'], row['sense']\n",
    "\n",
    "    # Only unmarked rows and rows marked '4' are processed; every other mark is set aside.\n",
    "    if row['marked'] not in ('', '4'):\n",
    "        leftover.append(row)\n",
    "        continue\n",
    "\n",
    "    if row['replacement'] != '' and row['marked'] != '4':\n",
    "        # A replacement sense was supplied; a \"wn2\" note overrides it with the name of the\n",
    "        # first wordnet2021 synset found for the noun.\n",
    "        if \"wn2\" in row['notes']:\n",
    "            chosen_sense = wn2.synsets(noun)[0].name()\n",
    "        else:\n",
    "            chosen_sense = row['replacement']\n",
    "        annotated_senses.append((noun, chosen_sense))\n",
    "    elif \"manual\" in row['notes'] or \"toy\" in row['notes']:\n",
    "        # Rows noted as \"manual\" / \"toy\" are collected separately, not given a sense here.\n",
    "        manual.append(row)\n",
    "    else:\n",
    "        annotated_senses.append((noun, default_sense))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 395,
   "id": "e0954c9b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2041, 2041)"
      ]
     },
     "execution_count": 395,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(leftover) + len(manual) + len(annotated_senses), len(senses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 396,
   "id": "bc5b66a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Current filtration pipeline\n",
    "get all hypernyms for each entity\n",
    "look at counts of hypernyms across samples as a pct of appearance in hypernym path\n",
    "remove common ones (after manual inspection)\n",
    "    e.g., entity, whole, physical entity, artifact, instrumentality, \n",
    "then sample based on freq. (maximum 4, including entity)\n",
    "'''\n",
    "def hypernym_path_lemmas(sense):\n",
    "    \"\"\"Return the lemmas along the first hypernym path of a synset.\n",
    "\n",
    "    Args:\n",
    "        sense: synset name, e.g. \"chair.n.01\".\n",
    "\n",
    "    Returns:\n",
    "        One lemma per synset on `hypernym_paths()[0]`, in the order that path lists them:\n",
    "        the first lemma name of each synset, with underscores replaced by spaces.\n",
    "    \"\"\"\n",
    "    # Look the sense up in wordnet2021 first and fall back to the `wn` corpus when that fails.\n",
    "    # `except Exception` rather than a bare `except:` so that KeyboardInterrupt / SystemExit\n",
    "    # are not swallowed by the fallback.\n",
    "    try:\n",
    "        synset = wn2.synset(sense)\n",
    "    except Exception:\n",
    "        synset = wn.synset(sense)\n",
    "    # hypernym_paths() returns a list of paths; only the first one is used.\n",
    "    first_path = synset.hypernym_paths()[0]\n",
    "    return [s.lemma_names()[0].replace(\"_\", \" \") for s in first_path]\n",
    "\n",
    "# # e.g.\n",
    "# hypernym_path_lemmas(\"chair.n.01\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 397,
   "id": "d14e08c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each hypernym lemma: the fraction of annotated senses whose hypernym path contains it.\n",
    "lemma_counts = defaultdict(float)\n",
    "n_senses = len(annotated_senses)\n",
    "for word, sense in annotated_senses:\n",
    "    try:\n",
    "        hypernym_lemmas = hypernym_path_lemmas(sense)\n",
    "    except Exception:\n",
    "        # Report the failing entry and skip it. Without the `continue` the loop below would\n",
    "        # re-count the previous word's lemmas (or raise NameError on the first iteration).\n",
    "        print(word, sense)\n",
    "        continue\n",
    "    for hl in hypernym_lemmas:\n",
    "        lemma_counts[hl] += 1 / n_senses\n",
    "lemma_counts = dict(lemma_counts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 399,
   "id": "4f1c2e88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('entity', 0.9975845410628059),\n",
       " ('physical entity', 0.9243156199677975),\n",
       " ('object', 0.5636070853462178),\n",
       " ('whole', 0.5346215780998409),\n",
       " ('artifact', 0.49355877616747357),\n",
       " ('instrumentality', 0.24879227053140165),\n",
       " ('food', 0.21175523349436445),\n",
       " ('causal agent', 0.1698872785829311),\n",
       " ('matter', 0.16908212560386507),\n",
       " ('organism', 0.166666666666667)]"
      ]
     },
     "execution_count": 399,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sorted(lemma_counts.items(), key=lambda x: x[1], reverse=True)[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 434,
   "id": "6116167a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypernym lemmas excluded from every sampled chain (the common ones removed after manual inspection).\n",
    "avoid = [\"entity\", \"physical entity\", \"whole\", \"artifact\", \"instrumentality\", \"organism\", \n",
    "         \"matter\", \"solid\", \"abstraction\", \"commodity\", \"plant part\", \"foodstuff\", \"plant organ\",\n",
    "         \"furnishing\", \"communication\", \"flavorer\", \"part\", \"act\", \"activity\", \"place of business\", \n",
    "         \"collection\", \"mercantile establishment\", \"unit\", \"visual communication\", \"binary compound\", \n",
    "         \"administrative district\", \"animal material\", \"causal agent\", \"phenomenon\", \"compound\", \n",
    "         \"natural phenomenon\", \"material\", \"natural object\", \"reproductive structure\", \"object\", \n",
    "         \"edible fruit\", \"course\", \"substance\", \"living thing\", \"nutriment\", \"consumer goods\",\n",
    "         \"chordate\", \"big cat\", \"self-propelled vehicle\", \"physical phenomenon\", \"process\",\n",
    "         \"obstruction\", \"angiosperm\", \"vascular plant\", \"way\", \"craft\", \"conveyance\"]\n",
    "\n",
    "# sample up to N_SAMPLES hypernyms by frequency, then re-sort them by rank in the hierarchy = final hierarchy\n",
    "\n",
    "N_SAMPLES = 3\n",
    "sampled_paths = {}\n",
    "\n",
    "for word, sense in annotated_senses:\n",
    "    word = word.replace(\"_\", \" \")\n",
    "    try:\n",
    "        hypernym_lemmas = list(OrderedSet(hypernym_path_lemmas(sense)))\n",
    "    except Exception:\n",
    "        # Report and skip: falling through would build this word's chain from the previous\n",
    "        # word's hypernyms (or raise NameError on the first iteration).\n",
    "        print(word, sense)\n",
    "        continue\n",
    "\n",
    "    # Candidates are (lemma, position in path, word frequency). The last lemma of the path\n",
    "    # belongs to the synset itself, so it is left out.\n",
    "    freq_ranks = [\n",
    "        (hl, i, word_frequency(hl, 'en'))\n",
    "        for i, hl in enumerate(hypernym_lemmas[:-1])\n",
    "        if hl not in avoid\n",
    "    ]\n",
    "\n",
    "    # sampled based on freq (slicing already copes with fewer than N_SAMPLES candidates)\n",
    "    sampled = sorted(freq_ranks, key=lambda x: x[-1], reverse=True)[:N_SAMPLES]\n",
    "\n",
    "    # reorder based on hierarchy\n",
    "    rank_sorted = sorted(sampled, key=lambda x: x[1])\n",
    "\n",
    "    # build hypernymy path/chain: the word first, then its hypernyms from most to least specific\n",
    "    sampled_paths[word] = [word] + [hypernym for hypernym, rank, freq in reversed(rank_sorted)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 435,
   "id": "d00465e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['boat', 'vessel', 'vehicle']"
      ]
     },
     "execution_count": 435,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sampled_paths['boat']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 515,
   "id": "f21f3996",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[Synset('entity.n.01'),\n",
       "  Synset('physical_entity.n.01'),\n",
       "  Synset('object.n.01'),\n",
       "  Synset('whole.n.02'),\n",
       "  Synset('artifact.n.01'),\n",
       "  Synset('decoration.n.01'),\n",
       "  Synset('adornment.n.01'),\n",
       "  Synset('jewelry.n.01'),\n",
       "  Synset('ring.n.08')]]"
      ]
     },
     "execution_count": 515,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wn2.synset('ring.n.08').hypernym_paths()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 875,
   "id": "1ec17f57",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['entity',\n",
       " 'physical entity',\n",
       " 'object',\n",
       " 'whole',\n",
       " 'artifact',\n",
       " 'instrumentality',\n",
       " 'equipment',\n",
       " 'game equipment',\n",
       " 'ball',\n",
       " 'basketball']"
      ]
     },
     "execution_count": 875,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hypernym_path_lemmas(\"basketball.n.02\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 561,
   "id": "e8683c2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "entity\n",
      "physical entity\n",
      "causal agent\n",
      "organism\n",
      "animal\n",
      "chordate\n",
      "vertebrate\n",
      "mammal\n",
      "placental\n",
      "carnivore\n",
      "canine\n",
      "wolf\n"
     ]
    }
   ],
   "source": [
    "# Inspect one hypernym path: print each lemma and record its word frequency.\n",
    "freqs = defaultdict(float)\n",
    "wolf_path = wn2.synset(\"wolf.n.03\").hypernym_paths()[0]\n",
    "for synset in wolf_path:\n",
    "    lemma = synset.lemma_names()[0].replace(\"_\", \" \")\n",
    "    freqs[lemma] = word_frequency(lemma, \"en\")\n",
    "    print(lemma)\n",
    "# sorted(freqs.items(), key=lambda x: x[1], reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 876,
   "id": "767fc641",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save to csv\n",
    "# with open(f\"{DIR}/noun-hypernymy-paths.csv\", \"w\") as f:\n",
    "#     writer = csv.writer(f)\n",
    "#     writer.writerow(['noun', 'hyp-1', 'hyp-2', 'hyp-3'])\n",
    "#     for word, chain in sampled_paths.items():\n",
    "#         if len(chain) > 1:\n",
    "#             writer.writerow(chain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8b80fc23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# read annotated csv\n",
    "annotated_paths_csv = f\"{DIR}/noun-hypernymy-paths-annotated.csv\"\n",
    "annotated_hypernym_paths_raw = read_csv(annotated_paths_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bbcb081c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each noun to its non-empty 'hyp*' column values, skipping rows whose flag is '1'.\n",
    "annotated_hypernym_paths = {}\n",
    "for row in annotated_hypernym_paths_raw:\n",
    "    chain = [value for column, value in row.items() if 'hyp' in column and value != '']\n",
    "    if row['flag'] == '1':\n",
    "        continue\n",
    "    annotated_hypernym_paths[row['noun']] = chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "90c7dc99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# annotated_hypernym_paths\n",
    "hypernyms_json_path = f\"{DIR}/noun-hypernyms.json\"\n",
    "with open(hypernyms_json_path, \"w\") as out_file:\n",
    "    json.dump(annotated_hypernym_paths, out_file, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d09e889a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
