{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0fccc94",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcefd854",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af057892",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models\n",
    "\n",
    "device = \"cuda:0\"\n",
    "mt = models.load_model(\"gptj\", device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5856dc82",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6c0dce0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import data\n",
    "\n",
    "dataset = data.load_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a470cda",
   "metadata": {},
   "outputs": [],
   "source": [
    "for d in dataset:\n",
    "    print(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9da567e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import baukit\n",
    "import torch\n",
    "from src.functional import Order1ApproxOutput\n",
    "from src.utils.misc import visualize_matrix\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "@torch.inference_mode(mode=False)\n",
    "def order_1_approx(\n",
    "    *,\n",
    "    mt: models.ModelAndTokenizer,\n",
    "    prompt: str,\n",
    "    h_layer: int,\n",
    "    h_index: int,\n",
    "    z_token_indices: list[int],\n",
    "    z_layer: int | None = None,\n",
    "    z_index: int | None = None,\n",
    "    inputs=None,\n",
    "):\n",
    "    if z_layer is None:\n",
    "        z_layer = mt.model.config.n_layer - 1\n",
    "    if z_index is None:\n",
    "        z_index = -1\n",
    "    if inputs is None:\n",
    "        inputs = mt.tokenizer(prompt, return_tensors=\"pt\").to(mt.model.device)\n",
    "\n",
    "    # Precompute everything up to the subject, if there is anything before it.\n",
    "    past_key_values = None\n",
    "    input_ids = inputs.input_ids\n",
    "    _h_index = h_index\n",
    "    if _h_index > 0:\n",
    "        outputs = mt.model(input_ids=input_ids[:, :_h_index], use_cache=True)\n",
    "        past_key_values = outputs.past_key_values\n",
    "        input_ids = input_ids[:, _h_index:]\n",
    "        _h_index = 0\n",
    "    use_cache = past_key_values is not None\n",
    "\n",
    "    # Precompute initial h and z.\n",
    "    [h_layer_name, z_layer_name] = models.determine_layer_paths(mt, [h_layer, z_layer])\n",
    "    with baukit.TraceDict(mt.model, (h_layer_name, z_layer_name)) as ret:\n",
    "        outputs = mt.model(\n",
    "            input_ids=input_ids,\n",
    "            use_cache=use_cache,\n",
    "            past_key_values=past_key_values,\n",
    "        )\n",
    "    h = ret[h_layer_name].output[0][0, _h_index]\n",
    "    z = ret[z_layer_name].output[0][0, z_index]\n",
    "\n",
    "    # Now compute J and b.\n",
    "    def compute_z_from_h(h: torch.Tensor) -> torch.Tensor:\n",
    "        def insert_h(output: tuple, layer: str) -> tuple:\n",
    "            if layer != h_layer_name:\n",
    "                return output\n",
    "            output[0][0, _h_index] = h\n",
    "            return output\n",
    "\n",
    "        with baukit.TraceDict(\n",
    "            mt.model, (h_layer_name, z_layer_name), edit_output=insert_h\n",
    "        ) as ret:\n",
    "            mt.model(\n",
    "                input_ids=input_ids,\n",
    "                past_key_values=past_key_values,\n",
    "                use_cache=use_cache,\n",
    "            )\n",
    "        z = ret[z_layer_name].output[0][0, -1]\n",
    "        z = mt.model.transformer.ln_f(z)\n",
    "\n",
    "        hidden_size = mt.model.config.hidden_size\n",
    "\n",
    "        # proj = z.new_zeros(hidden_size, hidden_size)\n",
    "        # for z_token_index in z_token_indices:\n",
    "        #     y = mt.model.transformer.wte.weight.data[z_token_index, ..., None]\n",
    "        #     proj += y @ y.t() / y.norm().pow(2)\n",
    "        Y = []\n",
    "        for z_token_index in z_token_indices:\n",
    "            y = mt.model.transformer.wte.weight.data[z_token_index, ..., None]\n",
    "            Y.append(y.T[0])\n",
    "        Y = torch.stack(Y, dim=1).to(torch.float32)\n",
    "        # proj = Y @ (Y.T @ Y).to(torch.float32).pinverse().to(Y.dtype) @ Y.T\n",
    "        proj = Y @ (Y.T @ Y).to(torch.float32).pinverse() @ Y.T\n",
    "        print(torch.linalg.matrix_rank(proj))\n",
    "        proj = proj.to(mt.model.dtype)\n",
    "        # print(proj)\n",
    "        # visualize_matrix(proj)\n",
    "        result = proj @ z[..., None]\n",
    "        print(Y.norm().item(), proj.norm().item(), result.norm().item())\n",
    "\n",
    "        # raise AssertionError()\n",
    "        return result.squeeze()\n",
    "\n",
    "    weight = torch.autograd.functional.jacobian(compute_z_from_h, h, vectorize=True)\n",
    "    bias = z[None] - h[None].mm(weight.t())\n",
    "    approx = Order1ApproxOutput(\n",
    "        h=h,\n",
    "        h_layer=h_layer,\n",
    "        h_index=h_index,\n",
    "        z=z,\n",
    "        z_layer=z_layer,\n",
    "        z_index=z_index,\n",
    "        weight=weight,\n",
    "        bias=bias,\n",
    "        inputs=inputs.to(\"cpu\"),\n",
    "        logits=outputs.logits.cpu(),\n",
    "    )\n",
    "    return approx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a54169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prompt=\"Eiffle Tower is located in the city of\"\n",
    "# tokenized = mt.tokenizer(prompt, return_tensors=\"pt\").to(mt.model.device)\n",
    "# print([(t.item(), mt.tokenizer.decode(t)) for t in tokenized.input_ids[0]])\n",
    "\n",
    "# output = order_1_approx(\n",
    "#     mt=mt,\n",
    "#     prompt=\"Eiffle Tower is located in the city of\",    \n",
    "#     h_layer=15,\n",
    "#     h_index=3,\n",
    "#     z_layer=27,\n",
    "# )"
   ]
  },
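  {
   "cell_type": "code",
   "execution_count": null,
   "id": "proj-sanity-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sanity check (not part of the original analysis): the projection\n",
    "# P = Y (Y^T Y)^+ Y^T used above should be an orthogonal projector onto span(Y),\n",
    "# i.e. idempotent and symmetric, with rank = number of independent candidates.\n",
    "Y_toy = torch.randn(8, 3)  # hypothetical: 8-dim \"embeddings\" for 3 candidates\n",
    "P_toy = Y_toy @ (Y_toy.T @ Y_toy).pinverse() @ Y_toy.T\n",
    "assert torch.allclose(P_toy @ P_toy, P_toy, atol=1e-4)  # idempotent\n",
    "assert torch.allclose(P_toy, P_toy.T, atol=1e-4)  # symmetric\n",
    "print(torch.linalg.matrix_rank(P_toy))  # expect 3"
   ]
  },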
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a39cfa83",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "from src import data, functional, operators\n",
    "from src.utils import tokenizer_utils\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "def get_icl_prompt(samples, sample, prompt_template):\n",
    "    others = list(set(samples) - {sample})\n",
    "    prompt = \"\\n\".join(\n",
    "        prompt_template.format(x.subject) + f\" {x.object}.\"\n",
    "        for x in others\n",
    "    )\n",
    "    prompt += \"\\n\" + prompt_template.format(sample.subject)\n",
    "    return prompt\n",
    "\n",
    "\n",
    "class NewJEstimator(operators.LinearRelationEstimator):\n",
    "\n",
    "    mt: models.ModelAndTokenizer\n",
    "    h_layer: int = 9\n",
    "    z_layer: int = 27\n",
    "\n",
    "    def __call__(self, relation):\n",
    "        prompt_templates = relation.prompt_templates[:1]\n",
    "        samples = relation.samples[:3]\n",
    "\n",
    "        targets = [x for x in relation.range]\n",
    "        z_token_indices = self.mt.tokenizer(\n",
    "            targets,\n",
    "            return_tensors=\"pt\",\n",
    "            padding=True,\n",
    "        ).input_ids[:, 0].tolist()\n",
    "    \n",
    "        weights = []\n",
    "        biases = []\n",
    "        zs = []\n",
    "        hs = []\n",
    "        for prompt_template in prompt_templates:\n",
    "            for sample in tqdm(samples):\n",
    "                subject = sample.subject\n",
    "#                 prompt = prompt_template.format(subject)\n",
    "                prompt = get_icl_prompt(samples, sample, prompt_template)\n",
    "                print(prompt, \"\\n\", \"---\")\n",
    "                _, h_index = tokenizer_utils.find_token_range(prompt, subject, tokenizer=mt.tokenizer)\n",
    "                h_index -= 1\n",
    "\n",
    "                # output = functional.order_1_approx(\n",
    "                #     mt=self.mt,\n",
    "                #     prompt=prompt,\n",
    "                #     h_layer=self.h_layer,\n",
    "                #     h_index=h_index,\n",
    "                #     z_layer=self.z_layer,\n",
    "                # )\n",
    "                output = order_1_approx(\n",
    "                    mt=self.mt,\n",
    "                    prompt=prompt,\n",
    "                    h_layer=self.h_layer,\n",
    "                    h_index=h_index,\n",
    "                    z_token_indices=z_token_indices,\n",
    "                    z_layer=self.z_layer,\n",
    "                )\n",
    "                weights.append(output.weight)\n",
    "                biases.append(output.bias)\n",
    "        \n",
    "                hs.append(output.h)\n",
    "                zs.append(output.z)\n",
    "\n",
    "#         weight = weights[0]\n",
    "        weight = torch.stack(weights).mean(dim=0)\n",
    "#         weight = torch.eye(len(weight)).to(weight.device, weight.dtype)\n",
    "\n",
    "        bias = torch.stack(biases).mean(dim=0)\n",
    "        print(bias.norm())\n",
    "        bias = bias * .5\n",
    "        print(bias.norm())\n",
    "\n",
    "        print(\"h norm\", torch.stack(hs).norm(dim=-1).squeeze().mean())\n",
    "        print(\"Jh norm\", torch.stack([weight @ h for h in hs]).norm(dim=-1).squeeze().mean())\n",
    "        print(\"Jh + b norm\", torch.stack([weight @ h + bias * 2 for h in hs]).norm(dim=-1).squeeze().mean())\n",
    "        print(\"Jh + b/2 norm\", torch.stack([weight @ h + bias for h in hs]).norm(dim=-1).squeeze().mean())\n",
    "        print(\"z norm\", torch.stack(zs).norm(dim=-1).squeeze().mean())\n",
    "\n",
    "#         bias = mt.model.transformer.wte.weight.data[z_token_indices].mean(dim=0)\n",
    "\n",
    "#         hidden_size = weight.shape[0]\n",
    "#         proj = bias.new_zeros(hidden_size, hidden_size)\n",
    "#         for z_token_index in z_token_indices:\n",
    "#             y = mt.model.transformer.wte.weight.data[z_token_index, ..., None]\n",
    "#             proj += y @ y.t() / y.norm().pow(2)\n",
    "#         weight = weight @ proj\n",
    "\n",
    "#         bias = torch.zeros_like(bias)\n",
    "\n",
    "        return operators.LinearRelationOperator(\n",
    "            mt=self.mt,\n",
    "            weight=weight,\n",
    "            bias=bias,\n",
    "            h_layer=self.h_layer,\n",
    "            z_layer=self.z_layer,\n",
    "            prompt_template=relation.prompt_templates[0],\n",
    "        )\n",
    "\n",
    "\n",
    "estimator = NewJEstimator(mt=mt)\n",
    "\n",
    "relation = dataset[0].set(prompt_templates=[dataset[0].prompt_templates[0]])\n",
    "# relation = dataset[1].set(prompt_templates=[\"People in {} speak the language of\"])\n",
    "\n",
    "# relation = data.Relation(\n",
    "#     name=\"workplaces\",\n",
    "#     prompt_templates=[\"{} typically work inside of a\"],\n",
    "#     samples=[\n",
    "#         data.RelationSample(\"Nurses\", \"hospital\"),\n",
    "#         data.RelationSample(\"Judges\", \"courtroom\"),\n",
    "#         data.RelationSample(\"Farmers\", \"field\"),\n",
    "#         data.RelationSample(\"Car mechanics\", \"garage\"),\n",
    "#         data.RelationSample(\"Teachers\", \"classroom\"),\n",
    "#     ],\n",
    "# )\n",
    "\n",
    "# relation = data.Relation(\n",
    "#     name=\"color\",\n",
    "#     prompt_templates=[\"{} are typically associated with the color\"],\n",
    "#     samples=[\n",
    "#         data.RelationSample(\"Bananas\", \"yellow\"),\n",
    "#         data.RelationSample(\"Kiwis\", \"green\"),\n",
    "#         data.RelationSample(\"Potatoes\", \"brown\"),\n",
    "#     ],\n",
    "#     _range=[\n",
    "#         \"pink\",\n",
    "#         \"yellow\",\n",
    "#         \"red\",\n",
    "#         \"green\",\n",
    "#         \"blue\",\n",
    "#         \"orange\",\n",
    "#         \"violet\",\n",
    "#         \"magenta\",\n",
    "#         \"brown\",\n",
    "#         \"black\",\n",
    "#         \"white\",\n",
    "#         \"purple\",\n",
    "#         \"grey\",\n",
    "#         \"gray\",\n",
    "#         \"maroon\",\n",
    "#     ]\n",
    "# )\n",
    "# relation = data.Relation.from_dict({\n",
    "#     \"name\": \"president elected 1900s\",\n",
    "#     \"prompt_templates\": [\n",
    "#         \"{} was elected president in the year\"\n",
    "#     ],\n",
    "#     \"samples\": [\n",
    "#         {\n",
    "#             \"subject\": \"John F. Kennedy\",\n",
    "#             \"object\": \"1960\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"Lyndon B. Johnson\",\n",
    "#             \"object\": \"1963\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"Richard Nixon\",\n",
    "#             \"object\": \"1968\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"James Carter\",\n",
    "#             \"object\": \"1977\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"Ronald Reagan\",\n",
    "#             \"object\": \"1980\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"George H. W. Bush\",\n",
    "#             \"object\": \"1988\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"Bill Clinton\",\n",
    "#             \"object\": \"1992\"\n",
    "#         }\n",
    "#     ]\n",
    "# })\n",
    "\n",
    "# relation = data.Relation.from_dict({\n",
    "#     \"name\": \"president born 1900s\",\n",
    "#     \"prompt_templates\": [\n",
    "#         \"{} was born in the year\"\n",
    "#     ],\n",
    "#     \"samples\": [\n",
    "#         {\n",
    "#             \"subject\": \"John F. Kennedy\",\n",
    "#             \"object\": \"1917\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"Lyndon B. Johnson\",\n",
    "#             \"object\": \"1908\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"Richard Nixon\",\n",
    "#             \"object\": \"1913\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"James Carter\",\n",
    "#             \"object\": \"1924\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"Ronald Reagan\",\n",
    "#             \"object\": \"1911\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"George H. W. Bush\",\n",
    "#             \"object\": \"1924\"\n",
    "#         },\n",
    "#         {\n",
    "#             \"subject\": \"Bill Clinton\",\n",
    "#             \"object\": \"1946\"\n",
    "#         }\n",
    "#     ]\n",
    "# })\n",
    "\n",
    "with torch.device(device):\n",
    "    operator = estimator(relation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a75cadee",
   "metadata": {},
   "outputs": [],
   "source": [
    "operator(\"India\", k=20).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0d418b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.corner import CornerEstimator\n",
    "\n",
    "corner_estimator = CornerEstimator(mt.model, mt.tokenizer)\n",
    "corner = corner_estimator.estimate_corner_with_gradient_descent(\n",
    "    target_words = list(relation.range),\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ca586d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "corner_operator = operators.LinearRelationOperator(\n",
    "    mt = operator.mt,\n",
    "    weight = operator.weight,\n",
    "    bias = corner/5, ## setting bias = corner\n",
    "    h_layer = operator.h_layer,\n",
    "    z_layer = operator.z_layer,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subject_token_offset = operator.subject_token_offset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aab7572",
   "metadata": {},
   "outputs": [],
   "source": [
    "corner_operator(\"United States\", k=20).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "882d50e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated = mt.model.generate(\n",
    "    mt.tokenizer(\"The capital of Pakistan is\", padding=True, return_tensors=\"pt\").input_ids.to(device),\n",
    ")\n",
    "\n",
    "mt.tokenizer.decode(generated[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d26bc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "sns.heatmap(operator.weight[300:375, 300:375].cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe086740",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.inference_mode()\n",
    "def complete(prompt):\n",
    "    inputs = mt.tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "    outputs = mt.model(**inputs)\n",
    "    top5 = torch.log_softmax(outputs.logits, dim=-1)[:, -1].topk(k=20, dim=-1).indices.squeeze().tolist()\n",
    "    return [mt.tokenizer.decode(x) for x in top5]\n",
    "\n",
    "complete(\"Bill Clinton was born in the year\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f315ed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[0].prompt_templates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd4d70c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator = operators.JacobianIclEstimator(\n",
    "    mt=mt,\n",
    "    h_layer=5,\n",
    "    z_layer=27,\n",
    ")\n",
    "relation = dataset[0]\n",
    "with torch.device(device):\n",
    "    operator = estimator(relation.set(samples=relation.samples[10:15]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "771d3c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "operator(\"Spain\", k=10).predictions"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
