{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "import networkx as nx\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import hierarchical as hrc\n",
    "import einops\n",
    "\n",
    "# Use the first CUDA GPU when torch can see one; otherwise fall back to the CPU\n",
    "# so the notebook can still be run end-to-end on a machine without a GPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "MODEL_NAME = \"google/gemma-2b\"\n",
    "\n",
    "# hrc.get_g returns three values; the middle one is not needed in this notebook.\n",
    "# g is later indexed with token ids taken from vocab_dict (see get_logit_diff).\n",
    "g, _, sqrt_Cov_gamma = hrc.get_g(MODEL_NAME, device)\n",
    "vocab_dict, vocab_list = hrc.get_vocab(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the serialized vector representations read by torch.load below.\n",
    "VEC_REPS_PATH = 'FILE_PATH'\n",
    "\n",
    "# cats is indexed by category name and yields tokens that get_logit_diff looks up\n",
    "# in vocab_dict; G and sorted_keys are returned alongside it.\n",
    "cats, G, sorted_keys = hrc.get_categories('noun', 'gemma')\n",
    "\n",
    "# vec_reps is read as vec_reps['original']['non_split'][<category>]['lda'] in\n",
    "# show_logit_diff.\n",
    "vec_reps = torch.load(VEC_REPS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logit_diff(features, vector):\n",
    "    \"\"\"Project all pairwise row differences between two categories onto `vector`.\n",
    "\n",
    "    For every name in `features`, the category's tokens (`cats[name]`) are mapped\n",
    "    to ids through `vocab_dict` and the matching rows of `g` are gathered.  Each\n",
    "    row of the second category then has each row of the first category\n",
    "    subtracted from it, and every difference is multiplied with `vector`.\n",
    "\n",
    "    Args:\n",
    "        features: sequence of category names; only entries 0 and 1 enter the\n",
    "            difference, although rows are gathered for every entry.\n",
    "        vector: tensor that each difference is matrix-multiplied with\n",
    "            (presumably 1-D with the same width as a row of `g` -- confirm).\n",
    "\n",
    "    Returns:\n",
    "        The flattened products; with a 1-D `vector` that is one value per\n",
    "        (second-category row, first-category row) pair.\n",
    "    \"\"\"\n",
    "    # First resolve every category to its token ids ...\n",
    "    token_ids = {}\n",
    "    for name in features:\n",
    "        token_ids[name] = [vocab_dict[tok] for tok in cats[name]]\n",
    "\n",
    "    # ... then gather the corresponding rows of g for each category.\n",
    "    rows = {name: g[ids] for name, ids in token_ids.items()}\n",
    "\n",
    "    minuend = rows[features[1]]\n",
    "    subtrahend = rows[features[0]]\n",
    "\n",
    "    # Insert singleton axes so the subtraction broadcasts over all row pairs.\n",
    "    pairwise = minuend[:, None] - subtrahend[None]\n",
    "    return (pairwise @ vector).flatten()\n",
    "\n",
    "\n",
    "def show_logit_diff(parents, children):\n",
    "    \"\"\"Print mean and std of logit differences for a child pair and a parent pair.\n",
    "\n",
    "    The probing direction is the 'lda' entry of `parents[1]` minus that of\n",
    "    `parents[0]`, both read from `vec_reps['original']['non_split']`, divided\n",
    "    by its norm.  `get_logit_diff` is then evaluated along that direction for\n",
    "    `children` and for `parents`, and each result is reported on its own line\n",
    "    as mean, a literal backslash-pm separator, and standard deviation.\n",
    "\n",
    "    Args:\n",
    "        parents: two category names; they define the direction and are also\n",
    "            scored along it.\n",
    "        children: two category names scored along the parents' direction.\n",
    "    \"\"\"\n",
    "    lda_table = vec_reps['original']['non_split']\n",
    "    upper = lda_table[parents[1]]['lda']\n",
    "    lower = lda_table[parents[0]]['lda']\n",
    "\n",
    "    # Unit-length direction pointing from parents[0] towards parents[1].\n",
    "    gap = upper - lower\n",
    "    vector = gap / torch.norm(gap)\n",
    "\n",
    "    # Score both pairs before anything is printed.\n",
    "    child_diff = get_logit_diff(children, vector)\n",
    "    parent_diff = get_logit_diff(parents, vector)\n",
    "\n",
    "    print(f\"child: {child_diff.mean().item():.4f} \\\\pm {child_diff.std().item():.4f}\")\n",
    "    print(f\"parent: {parent_diff.mean().item():.4f} \\\\pm {parent_diff.std().item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "child: -0.0600 \\pm 1.2190\n",
      "parent: 5.1265 \\pm 1.1731\n"
     ]
    }
   ],
   "source": [
    "# Direction: animal.n.01 minus plant.n.02 (show_logit_diff subtracts parents[0]\n",
    "# from parents[1]); the child comparison is reptile.n.01 minus mammal.n.01.\n",
    "parents = ['plant.n.02', 'animal.n.01']\n",
    "children = ['mammal.n.01', 'reptile.n.01']\n",
    "\n",
    "show_logit_diff(parents=parents, children=children)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "child: 0.3770 \\pm 1.5410\n",
      "parent: 9.8296 \\pm 1.1099\n"
     ]
    }
   ],
   "source": [
    "# Direction: solid.n.01 minus fluid.n.02 (show_logit_diff subtracts parents[0]\n",
    "# from parents[1]); the child comparison is food.n.02 minus crystal.n.01.\n",
    "parents = ['fluid.n.02', 'solid.n.01']\n",
    "children = ['crystal.n.01', 'food.n.02']\n",
    "\n",
    "show_logit_diff(parents=parents, children=children)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "child: -0.1545 \\pm 1.1426\n",
      "parent: 14.4222 \\pm 0.9458\n"
     ]
    }
   ],
   "source": [
    "# Direction: contestant.n.01 minus scientist.n.01 (show_logit_diff subtracts\n",
    "# parents[0] from parents[1]); the child comparison is player.n.01 minus athlete.n.01.\n",
    "parents = ['scientist.n.01', 'contestant.n.01']\n",
    "children = ['athlete.n.01', 'player.n.01']\n",
    "\n",
    "show_logit_diff(parents=parents, children=children)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "add_tokens",
   "language": "python",
   "name": "add_tokens"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
