{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2: imported modules are\n",
    "# reloaded before each cell executes, so edits to the local project packages\n",
    "# are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import json\n",
    "from tqdm.auto import tqdm\n",
    "import random\n",
    "import transformers\n",
    "\n",
    "import os\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from relations import estimate\n",
    "from util import model_utils\n",
    "from baukit import nethook\n",
    "from operator import itemgetter\n",
    "from relations.evaluate import evaluate\n",
    "from relations.corner import CornerEstimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Commented-out call, kept for reference only (presumably an alternative way\n",
    "# to load the dataset -- confirm). CounterFactDataset is not imported in this\n",
    "# notebook; ../data/counterfact.json is read directly with json.load in a\n",
    "# later cell instead.\n",
    "# counterfact = CounterFactDataset(\"../data/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = \"EleutherAI/gpt-j-6B\"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B\n",
    "\n",
    "# Build the model/tokenizer pair in float32 through the project helper.\n",
    "mt = model_utils.ModelAndTokenizer(\n",
    "    MODEL_NAME,\n",
    "    low_cpu_mem_usage=True,\n",
    "    torch_dtype=torch.float32,\n",
    ")\n",
    "model, tokenizer = mt.model, mt.tokenizer\n",
    "\n",
    "# Use the end-of-sequence token as the padding token.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "print(f\"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#################################################\n",
    "relation_id = \"P101\"\n",
    "precision_at = 3\n",
    "#################################################\n",
    "\n",
    "# Read the CounterFact records. The encoding is pinned so the result does not\n",
    "# depend on the platform's default locale.\n",
    "with open(\"../data/counterfact.json\", encoding=\"utf-8\") as f:\n",
    "    counterfact = json.load(f)\n",
    "\n",
    "# Unique true-target strings for the chosen relation, each prefixed with a\n",
    "# single space. Sorting the set gives a stable order, so the preview printed\n",
    "# below and anything computed from `objects` are identical on every kernel\n",
    "# restart (iteration order of a set of strings varies between Python processes).\n",
    "objects = sorted(\n",
    "    {\n",
    "        \" \" + record[\"requested_rewrite\"][\"target_true\"][\"str\"]\n",
    "        for record in counterfact\n",
    "        if record[\"requested_rewrite\"][\"relation_id\"] == relation_id\n",
    "    }\n",
    ")\n",
    "print(\"unique objects: \", len(objects), objects[0:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimator bound to the loaded model and tokenizer.\n",
    "corner_estimator = CornerEstimator(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corner vector estimated from the relation's object strings; print its norm\n",
    "# together with its vocabulary representation as a quick sanity check.\n",
    "simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)\n",
    "simple_corner_norm = simple_corner.norm().item()\n",
    "simple_corner_vocab = corner_estimator.get_vocab_representation(simple_corner)\n",
    "print(simple_corner_norm, simple_corner_vocab)\n",
    "\n",
    "# Relation operator with an identity weight matrix and the corner as its bias.\n",
    "identity_weight = torch.eye(model.config.n_embd).to(model.dtype).to(model.device)\n",
    "relation_operator = estimate.RelationOperator(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    relation='{} works in the field of',\n",
    "    layer=15,\n",
    "    weight=identity_weight,\n",
    "    bias=simple_corner,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the simple-corner operator. The cutoff is taken from the `precision_at`\n",
    "# setting defined alongside `relation_id`, so changing it there takes effect\n",
    "# here instead of being shadowed by a literal.\n",
    "precision, ret_dict = evaluate(\n",
    "    relation_id=relation_id,\n",
    "    relation_operator=relation_operator,\n",
    "    precision_at=precision_at,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the precision value returned by the preceding evaluate() call.\n",
    "precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second corner estimate for the same objects, via estimate_lin_inv_corner;\n",
    "# print its norm together with its vocabulary representation.\n",
    "lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)\n",
    "lin_inv_corner_norm = lin_inv_corner.norm().item()\n",
    "lin_inv_corner_vocab = corner_estimator.get_vocab_representation(lin_inv_corner)\n",
    "print(lin_inv_corner_norm, lin_inv_corner_vocab)\n",
    "\n",
    "# Relation operator with an identity weight matrix and this corner as its bias.\n",
    "lin_inv_weight = torch.eye(model.config.n_embd).to(model.dtype).to(model.device)\n",
    "relation_lin_inv = estimate.RelationOperator(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    relation='{} works in the field of',\n",
    "    layer=15,\n",
    "    weight=lin_inv_weight,\n",
    "    bias=lin_inv_corner,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the lin_inv-corner operator on the validation set returned by the\n",
    "# first evaluate() call, so both operators are measured on the same examples.\n",
    "# The relation id and cutoff come from the configuration variables: the\n",
    "# operator and the reused validation set both belong to `relation_id`, so a\n",
    "# different literal id here would be inconsistent with them.\n",
    "# Note: this rebinds `precision`, replacing the value from the first evaluation.\n",
    "precision, ret_dict_2 = evaluate(\n",
    "    relation_id=relation_id,\n",
    "    relation_operator=relation_lin_inv,\n",
    "    precision_at=precision_at,\n",
    "    validation_set=ret_dict[\"validation_set\"],\n",
    ")\n",
    "\n",
    "precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relation",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3439fe3f7dcaddaf51997811d25ada8e7c0985d2997d22a3ed461af94d2f9f43"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
