{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "LOMkspTehAgv",
    "outputId": "0fdd0f73-9fc5-44bd-bf86-a16ddf8bba6b"
   },
   "outputs": [],
   "source": [
    "import nltk\n",
    "nltk.download('wordnet')\n",
    "\n",
    "import json\n",
    "import networkx as nx\n",
    "from nltk.corpus import wordnet as wn\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "import inflect"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "MODEL_NAME = \"meta-llama/Meta-Llama-3-8B\"\n",
    "# MODEL_NAME = \"google/gemma-2b\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "vocab = tokenizer.get_vocab()\n",
    "vocab_set = set(vocab.keys())\n",
    "\n",
    "p = inflect.engine()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_hyponym_lemmas(synset):\n",
    "    hyponyms = synset.hyponyms()\n",
    "    lemmas = set()\n",
    "    for hyponym in hyponyms:\n",
    "        lemmas.update(lemma.name() for lemma in hyponym.lemmas())\n",
    "        lemmas.update(get_all_hyponym_lemmas(hyponym))  # Recursively get lemmas from hyponyms,\n",
    "    \n",
    "    return lemmas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_noun_synsets = list(wn.all_synsets(pos=wn.NOUN))\n",
    "noun_lemmas = {}\n",
    "for s in all_noun_synsets:\n",
    "    lemmas = get_all_hyponym_lemmas(s)\n",
    "    # add and remove space bc of how gemma vocab works\n",
    "    lemmas = vocab_set.intersection({\"Ġ\" + l for l in lemmas})\n",
    "    noun_lemmas[s.name()] = {l[1:] for l in lemmas}\n",
    "\n",
    "large_nouns = {k: v for k, v in noun_lemmas.items() if len(v) > 5}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "82115\n",
      "2016\n"
     ]
    }
   ],
   "source": [
    "print(len(all_noun_synsets))\n",
    "print(len(large_nouns))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "no ancestors for entity.n.01\n"
     ]
    }
   ],
   "source": [
    "# Construct the hypernym inclusion graph among large categories\n",
    "G_noun = nx.DiGraph()\n",
    "\n",
    "nodes = list(large_nouns.keys())\n",
    "for key in nodes:\n",
    "    for path in wn.synset(key).hypernym_paths():\n",
    "        # ancestors included in the cleaned set\n",
    "        ancestors = [s.name() for s in path if s.name() in nodes]\n",
    "        if len(ancestors) > 1:\n",
    "            G_noun.add_edge(ancestors[-2],key) # first entry is itself\n",
    "        else:\n",
    "            print(f\"no ancestors for {key}\")\n",
    "\n",
    "G_noun = nx.DiGraph(G_noun.subgraph(nodes))\n",
    "\n",
    "# list(G.successors('reptile.n.01'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if a node has only one child, and that child has only one parent, merge the two nodes\n",
    "def merge_nodes(G, lemma_dict):\n",
    "    topological_sorted_nodes = list(reversed(list(nx.topological_sort(G))))\n",
    "    for node in topological_sorted_nodes:\n",
    "        children = list(G.successors(node))\n",
    "        if len(children) == 1:\n",
    "            child = children[0]\n",
    "            parent_lemmas_not_in_child = lemma_dict[node] - lemma_dict[child]\n",
    "            if len(list(G.predecessors(child))) == 1 or len(parent_lemmas_not_in_child) <6:\n",
    "                grandchildren = list(G.successors(child))\n",
    "                \n",
    "                if len(parent_lemmas_not_in_child) > 1:\n",
    "                    if len(grandchildren) > 0:\n",
    "                        lemma_dict[node + '.other'] = parent_lemmas_not_in_child\n",
    "                        G.add_edge(node, node + '.other')\n",
    "\n",
    "                # del synset_lemmas[child]\n",
    "                for grandchild in grandchildren:\n",
    "                    G.add_edge(node, grandchild)\n",
    "                G.remove_node(child)\n",
    "                print(f\"merged {node} and {child}\")\n",
    "\n",
    "merge_nodes(G_noun, large_nouns)\n",
    "large_nouns = {k: v for k, v in large_nouns.items() if k in G_noun.nodes()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nx.is_weakly_connected(G_noun)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make a gemma specific version\n",
    "def _noun_to_gemma_vocab_elements(word):\n",
    "    word = word.lower()\n",
    "    plural = p.plural(word)\n",
    "    add_cap_and_plural = [word, word.capitalize(), plural, plural.capitalize()]\n",
    "    add_space = [\"Ġ\" + w for w in add_cap_and_plural]\n",
    "    return vocab_set.intersection(add_space)\n",
    "\n",
    "with open('data/noun_synsets_wordnet_llama.json', 'w') as f:\n",
    "    for synset, lemmas in large_nouns.items():\n",
    "        llama_words = []\n",
    "        for w in lemmas:\n",
    "            llama_words.extend(_noun_to_gemma_vocab_elements(w))\n",
    "\n",
    "        f.write(json.dumps({synset: llama_words}) + \"\\n\")\n",
    "        \n",
    "nx.write_adjlist(G_noun, \"data/noun_synsets_wordnet_hypernym_graph_llama.adjlist\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "cats = {}\n",
    "with open('data/noun_synsets_wordnet_llama.json', 'r') as f:\n",
    "    for line in f:\n",
    "        cats.update(json.loads(line))\n",
    "G = nx.read_adjlist(\"data/noun_synsets_wordnet_hypernym_graph_llama.adjlist\", create_using=nx.DiGraph())\n",
    "\n",
    "cats = {k: list(set(v)) for k, v in cats.items() if len(set(v)) > 50}\n",
    "G = nx.DiGraph(G.subgraph(cats.keys()))\n",
    "\n",
    "reversed_nodes = list(reversed(list(nx.topological_sort(G))))\n",
    "for node in reversed_nodes:\n",
    "    children = list(G.successors(node))\n",
    "    if len(children) == 1:\n",
    "        child = children[0]\n",
    "        parent_lemmas_not_in_child = set(cats[node]) - set(cats[child])\n",
    "        if len(list(G.predecessors(child))) == 1 or len(parent_lemmas_not_in_child) <5:\n",
    "            grandchildren = list(G.successors(child))\n",
    "            for grandchild in grandchildren:\n",
    "                G.add_edge(node, grandchild)\n",
    "            G.remove_node(child)\n",
    "\n",
    "G = nx.DiGraph(G.subgraph(cats.keys()))\n",
    "sorted_keys = list(nx.topological_sort(G))\n",
    "cats = {k: cats[k] for k in sorted_keys}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## verbs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_verb_synsets = list(wn.all_synsets(pos=wn.VERB))\n",
    "\n",
    "G_verb = nx.DiGraph()\n",
    "verb_leaves = [s.name() for s in all_verb_synsets if len(s.hyponyms()) == 0]\n",
    "for leaf in verb_leaves:\n",
    "    for path in wn.synset(leaf).hypernym_paths():\n",
    "        for i in range(len(path) - 1):\n",
    "            G_verb.add_edge(path[i].name(), path[i+1].name())\n",
    "\n",
    "verb_lemmas = {}\n",
    "sorted_verb_synsets = list(reversed(list(nx.topological_sort(G_verb))))\n",
    "for s in sorted_verb_synsets:\n",
    "    # add and remove \"▁\" bc of how gemma vocab works\n",
    "    verb_lemmas[s] = vocab_set.intersection([\"Ġ\" + lemma.name() for lemma in wn.synset(s).lemmas()])\n",
    "    verb_lemmas[s] = {lemma[1:] for lemma in verb_lemmas[s]}\n",
    "    for child in G_verb.successors(s):\n",
    "        verb_lemmas[s].update(verb_lemmas[child])\n",
    "\n",
    "large_verbs = {k: v for k, v in verb_lemmas.items() if len(v) > 5}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "G_verb = nx.DiGraph(G_verb.subgraph(list(large_verbs.keys())))\n",
    "merge_nodes(G_verb, verb_lemmas)\n",
    "large_verbs = {k: v for k, v in large_verbs.items() if k in G_verb.nodes()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _verb_to_gemma_vocab_elements(verb):\n",
    "    #  add in regular verb conjugations to expand vocab\n",
    " \n",
    "    verb = verb.lower()\n",
    "\n",
    "    # some weird wordnet bugs, filter it\n",
    "    if len(verb) < 3 and verb not in [\"be\", \"do\", \"go\"]:\n",
    "        return set()\n",
    "\n",
    "    # Conjugating to past tense\n",
    "    if verb.endswith('e'):\n",
    "        past = verb + 'd'\n",
    "    else:\n",
    "        past = verb + 'ed'\n",
    "    \n",
    "    # Conjugating to present participle\n",
    "    if verb.endswith('e'):\n",
    "        present_participle = verb[:-1] + 'ing'\n",
    "    else:\n",
    "        present_participle = verb + 'ing'\n",
    "    \n",
    "    # 3psg \n",
    "    if verb.endswith((\"o\", \"ch\", \"s\", \"sh\", \"x\", \"z\")):\n",
    "        third_person = verb + 'es'\n",
    "    elif verb.endswith(\"y\") and verb[-2] not in \"aeiou\":\n",
    "        third_person = verb[:-1] + 'ies'\n",
    "    else:\n",
    "        third_person = verb + 's'\n",
    "\n",
    "\n",
    "    tenses = [verb, past, present_participle, third_person]\n",
    "\n",
    "    caps = [w.capitalize() for w in tenses]\n",
    "    \n",
    "    # Add underscore prefix to each word to match vocab tokens\n",
    "    add_space = [\"Ġ\" + w for w in (caps + tenses)]\n",
    "\n",
    "    # Return intersection with a hypothetical vocabulary set\n",
    "    return vocab_set.intersection(add_space)\n",
    "\n",
    "with open('data/verb_synsets_wordnet_llama.json', 'w') as f:\n",
    "    for synset, lemmas in large_verbs.items():\n",
    "        gemma_words = []\n",
    "        for w in lemmas:\n",
    "            gemma_words.extend(_verb_to_gemma_vocab_elements(w))\n",
    "\n",
    "        f.write(json.dumps({synset: gemma_words}) + \"\\n\")\n",
    "        \n",
    "nx.write_adjlist(G_verb, \"data/verb_synsets_wordnet_hypernym_graph_llama.adjlist\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "cats = {}\n",
    "with open('data/verb_synsets_wordnet_llama.json', 'r') as f:\n",
    "    for line in f:\n",
    "        cats.update(json.loads(line))\n",
    "G = nx.read_adjlist(\"data/verb_synsets_wordnet_hypernym_graph_llama.adjlist\", create_using=nx.DiGraph())\n",
    "\n",
    "cats = {k: list(set(v)) for k, v in cats.items() if len(set(v)) > 50}\n",
    "G = nx.DiGraph(G.subgraph(cats.keys()))\n",
    "\n",
    "reversed_nodes = list(reversed(list(nx.topological_sort(G))))\n",
    "for node in reversed_nodes:\n",
    "    children = list(G.successors(node))\n",
    "    if len(children) == 1:\n",
    "        child = children[0]\n",
    "        parent_lemmas_not_in_child = set(cats[node]) - set(cats[child])\n",
    "        if len(list(G.predecessors(child))) == 1 or len(parent_lemmas_not_in_child) <5:\n",
    "            grandchildren = list(G.successors(child))\n",
    "            for grandchild in grandchildren:\n",
    "                G.add_edge(node, grandchild)\n",
    "            G.remove_node(child)\n",
    "\n",
    "G = nx.DiGraph(G.subgraph(cats.keys()))\n",
    "sorted_keys = list(nx.topological_sort(G))\n",
    "cats = {k: cats[k] for k in sorted_keys}"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "include_colab_link": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "add_tokens",
   "language": "python",
   "name": "add_tokens"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
