{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "import sys\n",
    "from operator import itemgetter\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import transformers\n",
    "from baukit import nethook\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# The parent directory must be on sys.path before the `src` imports below.\n",
    "sys.path.append('..')\n",
    "\n",
    "from src.relations import estimate\n",
    "from src.relations.corner import CornerEstimator\n",
    "from src.util import model_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint to load: gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B\n",
    "MODEL_NAME = \"EleutherAI/gpt-j-6B\"\n",
    "mt = model_utils.ModelAndTokenizer(\n",
    "    MODEL_NAME,\n",
    "    low_cpu_mem_usage=True,\n",
    "    torch_dtype=torch.float32,\n",
    ")\n",
    "\n",
    "# Unpack the wrapper; the tokenizer reuses its EOS token as the padding token.\n",
    "model, tokenizer = mt.model, mt.tokenizer\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "print(f\"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#################################################\n",
    "relation_id = \"P17\"\n",
    "precision_at = 3\n",
    "#################################################\n",
    "\n",
    "# Read explicitly as UTF-8 so the load does not depend on the locale encoding.\n",
    "with open(\"../data/counterfact.json\", encoding=\"utf-8\") as f:\n",
    "    counterfact = json.load(f)\n",
    "\n",
    "# Distinct true targets of the selected relation, each prefixed with a space.\n",
    "# Sorted so the order is reproducible across kernel restarts: iterating a set\n",
    "# of strings depends on per-process hash randomization.\n",
    "objects = sorted({\n",
    "    \" \" + c[\"requested_rewrite\"][\"target_true\"][\"str\"]\n",
    "    for c in counterfact\n",
    "    if c[\"requested_rewrite\"][\"relation_id\"] == relation_id\n",
    "})\n",
    "print(\"unique objects: \", len(objects), objects[0:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimator used by the cells below; it shares the loaded model and tokenizer.\n",
    "corner_estimator = CornerEstimator(model=model, tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corner from the estimator's simple method; print its norm and vocab representation.\n",
    "simple_corner = corner_estimator.estimate_simple_corner(\n",
    "    objects,\n",
    "    scale_up=70,\n",
    ")\n",
    "print(\n",
    "    simple_corner.norm().item(),\n",
    "    corner_estimator.get_vocab_representation(simple_corner),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corner from the estimator's lin-inv method; print its norm and vocab representation.\n",
    "lin_inv_corner = corner_estimator.estimate_lin_inv_corner(\n",
    "    objects,\n",
    "    target_logit_value=50,\n",
    ")\n",
    "print(\n",
    "    lin_inv_corner.norm().item(),\n",
    "    corner_estimator.get_vocab_representation(lin_inv_corner),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corner from gradient descent (verbose run); only the norm is printed here.\n",
    "grad_dsc_corner = corner_estimator.estimate_corner_with_gradient_descent(\n",
    "    objects,\n",
    "    target_logit_value=50,\n",
    "    verbose=True,\n",
    ")\n",
    "print(grad_dsc_corner.norm().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Averaged gradient-descent corner (average_on=5); print its norm and vocab representation.\n",
    "avg_corner = corner_estimator.estimate_average_corner_with_gradient_descent(\n",
    "    objects,\n",
    "    average_on=5,\n",
    "    target_logit_value=50,\n",
    "    verbose=False,\n",
    ")\n",
    "print(\n",
    "    avg_corner.norm().item(),\n",
    "    corner_estimator.get_vocab_representation(avg_corner),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def P17__check_with_test_cases(relation_operator, test_cases=None, return_top_k=5):\n",
    "    \"\"\"Print a relation operator's top predictions for a set of spot-check subjects.\n",
    "\n",
    "    Args:\n",
    "        relation_operator: Callable invoked as ``relation_operator(subject,\n",
    "            subject_token_index=..., device=..., return_top_k=...)``.\n",
    "        test_cases: Optional iterable of ``(subject, subject_token_index, target)``\n",
    "            tuples. When omitted, the built-in list below is used.\n",
    "        return_top_k: Number of predictions requested per subject.\n",
    "\n",
    "    Returns:\n",
    "        None. One line is printed per test case.\n",
    "\n",
    "    Note:\n",
    "        The device is taken from the notebook-level ``model``.\n",
    "    \"\"\"\n",
    "    if test_cases is None:\n",
    "        # (subject, subject_token_index, expected target)\n",
    "        test_cases = [\n",
    "            (\"The Great Wall\", -1, \"China\"),\n",
    "            (\"Niagara Falls\", -2, \"Canada\"),\n",
    "            (\"Valdemarsvik\", -1, \"Sweden\"),\n",
    "            (\"Kyoto University\", -2, \"Japan\"),\n",
    "            (\"Hattfjelldal\", -1, \"Norway\"),\n",
    "            (\"Ginza\", -1, \"Japan\"),\n",
    "            (\"Sydney Hospital\", -2, \"Australia\"),\n",
    "            (\"Mahalangur Himal\", -1, \"Nepal\"),\n",
    "            (\"Higashikagawa\", -1, \"Japan\"),\n",
    "            (\"Trento\", -1, \"Italy\"),\n",
    "            (\"Taj Mahal\", -1, \"India\"),\n",
    "            (\"Hagia Sophia\", -1, \"Turkey\"),\n",
    "            (\"Colosseum\", -1, \"Italy\"),\n",
    "            (\"Mount Everest\", -1, \"Nepal\"),\n",
    "            (\"Valencia\", -1, \"Spain\"),\n",
    "            (\"Lake Baikal\", -1, \"Russia\"),\n",
    "            (\"Merlion Park\", -1, \"Singapore\"),\n",
    "            (\"Cologne Cathedral\", -1, \"Germany\"),\n",
    "            (\"Buda Castle\", -1, \"Hungary\"),\n",
    "        ]\n",
    "\n",
    "    for subject, subject_token_index, target in test_cases:\n",
    "        # Named `predictions` so it is not confused with the notebook-level `objects` list.\n",
    "        predictions = relation_operator(\n",
    "            subject,\n",
    "            subject_token_index=subject_token_index,\n",
    "            device=model.device,\n",
    "            return_top_k=return_top_k,\n",
    "        )\n",
    "        print(f\"{subject}, target: {target}   ==>   predicted: {predictions}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identity weight with the simple corner as bias. The identity is created directly\n",
    "# in the model's dtype and on its device rather than built with defaults and converted.\n",
    "relation = estimate.RelationOperator(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    relation='{} is located in the country of',\n",
    "    layer=15,\n",
    "    weight=torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device),\n",
    "    bias=simple_corner,\n",
    ")\n",
    "P17__check_with_test_cases(relation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identity weight with the lin-inv corner as bias. The identity is created directly\n",
    "# in the model's dtype and on its device rather than built with defaults and converted.\n",
    "relation = estimate.RelationOperator(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    relation='{} is located in the country of',\n",
    "    layer=15,\n",
    "    weight=torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device),\n",
    "    bias=lin_inv_corner,\n",
    ")\n",
    "P17__check_with_test_cases(relation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identity weight with the gradient-descent corner as bias. The identity is created\n",
    "# directly in the model's dtype and on its device rather than built with defaults and converted.\n",
    "relation = estimate.RelationOperator(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    relation='{} is located in the country of',\n",
    "    layer=15,\n",
    "    weight=torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device),\n",
    "    bias=grad_dsc_corner,\n",
    ")\n",
    "P17__check_with_test_cases(relation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identity weight with the averaged corner as bias. The identity is created directly\n",
    "# in the model's dtype and on its device rather than built with defaults and converted.\n",
    "relation = estimate.RelationOperator(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    relation='{} is located in the country of',\n",
    "    layer=15,\n",
    "    weight=torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device),\n",
    "    bias=avg_corner,\n",
    ")\n",
    "P17__check_with_test_cases(relation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relation",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3439fe3f7dcaddaf51997811d25ada8e7c0985d2997d22a3ed461af94d2f9f43"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
