{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited modules (e.g. the local `src` package) automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "from src import models, data\n",
    "from src.attributelens.attributelens import Attribute_Lens\n",
    "import src.attributelens.utils as lens_utils\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dtype: torch.float16, device: cuda:0, memory: 12219206136\n"
     ]
    }
   ],
   "source": [
    "# The cached LREs (../results/LRE_cached) were computed for GPT-J, so load that model.\n",
    "device = \"cuda:0\"\n",
    "mt = models.load_model(\"gptj\", fp16=True, device=device)\n",
    "print(\"dtype: {}, device: {}, memory: {}\".format(mt.model.dtype, mt.model.device, mt.model.get_memory_footprint()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<|endoftext|> The United States of America (U.S.A. or USA), commonly known as the United States'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Alternative prompt: \"present-day Turkey was home to important Neolithic sites like\"\n",
    "# The EOS token plus a space is prepended; the lens cells below skip it by starting subject_range at 1.\n",
    "prompt = f\"{mt.tokenizer.eos_token} The United States of America (U.S.A. or USA), commonly known as the United States\"\n",
    "prompt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attribute Lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.operators import LinearRelationOperator\n",
    "\n",
    "def load_cached_lre(relation_name, path = \"../results/LRE_cached\", device = None):\n",
    "    \"\"\"Rebuild a cached linear relation operator (LRE) for `relation_name`.\n",
    "\n",
    "    Reads `<path>/<relation_name with spaces replaced by underscores>.npz`, moves the\n",
    "    \"h\", \"z\", \"weight\" and \"bias\" arrays to `device` as torch tensors, unwraps every\n",
    "    other entry with `.item()`, and builds a `LinearRelationOperator` bound to the\n",
    "    global `mt`.\n",
    "\n",
    "    Args:\n",
    "        relation_name: relation name, e.g. \"country currency\".\n",
    "        path: directory containing the cached `.npz` files.\n",
    "        device: target device for the tensor entries; defaults to `mt.model.device`.\n",
    "\n",
    "    Returns:\n",
    "        LinearRelationOperator built from the cached weight, bias, h_layer, z_layer,\n",
    "        prompt_template and beta.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = mt.model.device\n",
    "    file_path = os.path.join(path, relation_name.replace(\" \", \"_\") + \".npz\")\n",
    "    tensor_keys = {\"h\", \"z\", \"weight\", \"bias\"}\n",
    "    approx_dict = {}\n",
    "    # np.load on an .npz returns an NpzFile that keeps the archive open; `with` closes it.\n",
    "    # NOTE(review): allow_pickle=True can execute arbitrary code - only load trusted cache files.\n",
    "    with np.load(file_path, allow_pickle=True) as approx:\n",
    "        for key, value in approx.items():\n",
    "            if key in tensor_keys:\n",
    "                approx_dict[key] = torch.from_numpy(value).to(device)\n",
    "            else:\n",
    "                approx_dict[key] = value.item()\n",
    "    return LinearRelationOperator(\n",
    "        mt = mt,\n",
    "        weight = approx_dict[\"weight\"],\n",
    "        bias = approx_dict[\"bias\"],\n",
    "        h_layer = approx_dict[\"h_layer\"],\n",
    "        z_layer = approx_dict[\"z_layer\"],\n",
    "        prompt_template = approx_dict[\"prompt_template\"],\n",
    "        beta = approx_dict[\"beta\"]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To list every relation in the dataset, uncomment the lines below and run this cell.\n",
    "# A relation can only be loaded by `load_cached_lre` if its cached .npz file exists.\n",
    "# dataset = data.load_dataset()\n",
    "# relation_names = [r.name for r in dataset.relations]\n",
    "# relation_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relations to visualise; each needs a cached LRE at\n",
    "# ../results/LRE_cached/<name with spaces as underscores>.npz (see load_cached_lre).\n",
    "relation_names = [\n",
    "    \"country capital city\",\n",
    "    \"country largest city\",\n",
    "    \"country currency\",\n",
    "    \"country language\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One cached LRE per relation, keyed by relation name (same order as relation_names).\n",
    "lres = dict(\n",
    "    (name, load_cached_lre(relation_name=name))\n",
    "    for name in relation_names\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------\n",
      "country capital city  --  oranges\n",
      "----------------------------------------\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<iframe\n",
       "    scrolling=\"no\"\n",
       "    width=\"1900px\"\n",
       "    height=\"595\"\n",
       "    src=\"iframe_figures/figure_9.html\"\n",
       "    frameborder=\"0\"\n",
       "    allowfullscreen\n",
       "></iframe>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------\n",
      "country largest city  --  purples\n",
      "----------------------------------------\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<iframe\n",
       "    scrolling=\"no\"\n",
       "    width=\"1900px\"\n",
       "    height=\"595\"\n",
       "    src=\"iframe_figures/figure_9.html\"\n",
       "    frameborder=\"0\"\n",
       "    allowfullscreen\n",
       "></iframe>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------\n",
      "country currency  --  greens\n",
      "----------------------------------------\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<iframe\n",
       "    scrolling=\"no\"\n",
       "    width=\"1900px\"\n",
       "    height=\"595\"\n",
       "    src=\"iframe_figures/figure_9.html\"\n",
       "    frameborder=\"0\"\n",
       "    allowfullscreen\n",
       "></iframe>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------\n",
      "country language  --  reds\n",
      "----------------------------------------\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<iframe\n",
       "    scrolling=\"no\"\n",
       "    width=\"1900px\"\n",
       "    height=\"595\"\n",
       "    src=\"iframe_figures/figure_9.html\"\n",
       "    frameborder=\"0\"\n",
       "    allowfullscreen\n",
       "></iframe>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lens = Attribute_Lens(mt=mt, top_k=10)\n",
    "\n",
    "# One colorscale per relation: zip() below would silently skip any relation without one.\n",
    "colorscales = [\"oranges\", \"purples\", \"greens\", \"reds\"]\n",
    "assert len(colorscales) >= len(relation_names), \"add a colorscale for every relation\"\n",
    "\n",
    "# p.show(renderer='iframe') names its HTML file after the cell number, so every figure from\n",
    "# this loop was written to the same iframe_figures/figure_<n>.html and all iframes ended up\n",
    "# showing the last relation. Write one HTML file per relation and embed it instead.\n",
    "FIGURE_DIR = \"iframe_figures\"\n",
    "IFRAME_PADDING = 20  # extra px so the iframe does not need its own scrollbar\n",
    "os.makedirs(FIGURE_DIR, exist_ok=True)\n",
    "\n",
    "for relation_name, colorscale in zip(relation_names, colorscales):\n",
    "    print(\"----------------------------------------\")\n",
    "    print(relation_name, \" -- \", colorscale)\n",
    "    print(\"----------------------------------------\")\n",
    "    att_info = lens.apply_attribute_lens(\n",
    "        prompt=prompt,\n",
    "        relation_operator=lres[relation_name]\n",
    "    )\n",
    "    att_info['subject_range'] = (1, att_info['subject_range'][-1])  # ignore the first EOS token\n",
    "    p = lens_utils.visualize_attribute_lens(\n",
    "        att_info, layer_skip=2, must_have_layers=[],\n",
    "        colorscale=colorscale\n",
    "    )\n",
    "    p.layout.margin = dict(l=0, r=0, t=0, b=0)\n",
    "\n",
    "    # Forward slashes so the same string works as a file path and as the iframe's relative URL.\n",
    "    fig_path = f\"{FIGURE_DIR}/attribute_lens_{relation_name.replace(' ', '_')}.html\"\n",
    "    p.write_html(fig_path)\n",
    "    iframe_width = f\"{p.layout.width + IFRAME_PADDING}px\" if p.layout.width else \"100%\"\n",
    "    iframe_height = (p.layout.height or 525) + IFRAME_PADDING\n",
    "    # `display` is provided by the IPython kernel; raw=True treats the dict as a MIME bundle.\n",
    "    display(\n",
    "        {\"text/html\": f'<iframe scrolling=\"no\" width=\"{iframe_width}\" height=\"{iframe_height}\" '\n",
    "                      f'src=\"{fig_path}\" frameborder=\"0\" allowfullscreen></iframe>'},\n",
    "        raw=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logit Lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe\n",
       "    scrolling=\"no\"\n",
       "    width=\"1900px\"\n",
       "    height=\"595\"\n",
       "    src=\"iframe_figures/figure_10.html\"\n",
       "    frameborder=\"0\"\n",
       "    allowfullscreen\n",
       "></iframe>\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# relation_operator=None makes the lens fall back to the identity map, i.e. a plain Logit Lens.\n",
    "lens = Attribute_Lens(mt=mt, top_k=10)\n",
    "att_info = lens.apply_attribute_lens(prompt=prompt, relation_operator=None)\n",
    "\n",
    "# Start the subject range at token 1 so the leading EOS token is ignored.\n",
    "att_info[\"subject_range\"] = (1, att_info[\"subject_range\"][-1])\n",
    "p = lens_utils.visualize_attribute_lens(att_info, layer_skip=2, must_have_layers=[])\n",
    "p.show(renderer=\"iframe\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
