{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pathlib\n",
    "import inflect\n",
    "\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "\n",
    "from collections import defaultdict\n",
    "from semantic_memory import taxonomy\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's say we are interested in PaliGemma\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/paligemma-3b-mix-224\")\n",
    "\n",
    "def check_in_vocab(word):\n",
    "    with_space = tokenizer.tokenize(f\" {word}\", add_special_tokens=False)\n",
    "    no_space = tokenizer.tokenize(word, add_special_tokens=False)\n",
    "    if len(with_space) == 1 or len(no_space) == 1:\n",
    "        return True\n",
    "    else:\n",
    "        return False\n",
    "    \n",
    "p = inflect.engine()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_json(filename):\n",
    "    with open(filename, 'r') as f:\n",
    "        return json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "noun_hypernyms = read_json(\"../data/gqa_entities/noun-hypernyms.json\")\n",
    "final_entities = read_json(\"../data/gqa_entities/entity_set.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "hypernym_paths = defaultdict(set)\n",
    "\n",
    "for noun, hypernyms in noun_hypernyms.items():\n",
    "    hypernym_paths[noun].add(tuple(hypernyms))\n",
    "    # each hypernym is the child of the next one\n",
    "    for i in range(len(hypernyms) - 1):\n",
    "        hypernym_paths[hypernyms[i]].add(tuple(hypernyms[i + 1:]))\n",
    "\n",
    "# store only the longest paths\n",
    "longest_paths = {}\n",
    "for noun, paths in hypernym_paths.items():\n",
    "    longest_paths[noun] = max(paths, key=len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make sure they are in the model\n",
    "# longest_paths_model = {}\n",
    "# for noun, path in longest_paths.items():\n",
    "#     final_path = []\n",
    "#     if check_in_vocab(noun):\n",
    "#         for concept in path:\n",
    "#             if check_in_vocab(concept):\n",
    "#                 final_path.append(concept)\n",
    "#         # if len(final_path) >= 1:\n",
    "#         longest_paths_model[noun] = final_path\n",
    "\n",
    "\n",
    "# now store the unique hypernym pairs\n",
    "hypernym_pairs = {}\n",
    "for noun, path in longest_paths.items():\n",
    "    try:\n",
    "        hypernym_pairs[noun] = path[0]\n",
    "    except:\n",
    "        print(noun, path)\n",
    "\n",
    "def get_hypernym_model(word):\n",
    "    try:\n",
    "        hypernym = hypernym_pairs[word]\n",
    "    except:\n",
    "        hypernym = \"entity\"\n",
    "    if not check_in_vocab(hypernym):\n",
    "        hypernym = get_hypernym_model(hypernym)\n",
    "    return hypernym\n",
    "\n",
    "final_pairs = {}\n",
    "for noun in set(hypernym_pairs.keys()).union(set(hypernym_pairs.values())):\n",
    "    if check_in_vocab(noun):\n",
    "        final_pairs[noun] = get_hypernym_model(noun)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'device'"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_pairs['monitor']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a tree with \"entity\" as its root node.\n",
    "Tree = taxonomy.Nodeset(taxonomy.Node)\n",
    "root = Tree['entity']\n",
    "\n",
    "# # populate the tree\n",
    "\n",
    "for concept, hypernym in final_pairs.items():\n",
    "    if concept in final_entities and hypernym in final_entities:\n",
    "        node = Tree[concept]\n",
    "        parent = Tree[hypernym]\n",
    "        node.add_parent(parent)\n",
    "        parent.add_child(node)\n",
    "\n",
    "# make sure root is added as a parent to all top level nodes\n",
    "for value, node in Tree.items():\n",
    "    if value == \"entity\":\n",
    "        continue\n",
    "    elif node.parent is None:\n",
    "        node.add_parent(root)\n",
    "        root.add_child(node)\n",
    "\n",
    "Tree.default_factory = None # to make sure we dont accidentally add more nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save data in a way that park et al. did for their paper, but for paligemma:\n",
    "G = nx.DiGraph()\n",
    "\n",
    "for entry, node in Tree.items():\n",
    "    path = node.path()\n",
    "    if len(path) > 1:\n",
    "        G.add_edge(node.parent.value, entry)\n",
    "\n",
    "# I am going to skip the merging since our taxonomy is already pretty smol\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# taken directly from the park et al. code.\n",
    "vocab = tokenizer.get_vocab()\n",
    "vocab_set = set(vocab.keys())\n",
    "\n",
    "def _noun_to_gemma_vocab_elements(word):\n",
    "    word = word.lower()\n",
    "    plural = p.plural(word)\n",
    "    add_cap_and_plural = [word, word.capitalize(), plural, plural.capitalize()]\n",
    "    add_space = [\"▁\" + w for w in add_cap_and_plural]\n",
    "    return vocab_set.intersection(add_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# saving convention: data/taxonomies/custom/<model_family>\n",
    "\n",
    "path = \"../data/taxonomies/custom/gemma/\"\n",
    "pathlib.Path(path).mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "with open(f'{path}/items.json', 'w') as f:\n",
    "    for key, node in Tree.items():\n",
    "        words = []\n",
    "        for w in node.descendants():\n",
    "            words.extend(_noun_to_gemma_vocab_elements(w.value))\n",
    "\n",
    "        f.write(json.dumps({key : words}) + \"\\n\")\n",
    "    \n",
    "nx.write_adjlist(G, f\"{path}/hypernym_graph.adjlist\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
