{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcbd8171",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77e32fe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models\n",
    "\n",
    "# Device string shared by the whole notebook: the model is loaded onto it here,\n",
    "# and later cells move tokenizer outputs and tensors to the same `device`.\n",
    "device = \"cuda:5\"\n",
    "\n",
    "# `mt` is the handle used below for `mt.model`, `mt.tokenizer` and `mt.lm_head`.\n",
    "mt = models.load_model(\n",
    "    \"gptj\",\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ead3b38",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import data\n",
    "\n",
    "# Load the dataset with its default arguments; the next cell takes `dataset[0]`\n",
    "# as the starting point for the relation passed to the estimator.\n",
    "dataset = data.load_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "680990d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import operators\n",
    "\n",
    "# Cut the first dataset entry down to its first three samples and its first\n",
    "# prompt template before handing it to the estimator.\n",
    "relation = dataset[0].set(\n",
    "    samples=dataset[0].samples[:3],\n",
    "    prompt_templates=[dataset[0].prompt_templates[0]],\n",
    ")\n",
    "\n",
    "# The layer indices 11 and 26 reappear as literals in a later cell (hidden states\n",
    "# are requested for layer 26 and the edit hook is attached to `transformer.h.11`).\n",
    "estimator = operators.JacobianIclMeanEstimator(\n",
    "    mt=mt,\n",
    "    h_layer=11,\n",
    "    z_layer=26,\n",
    ")\n",
    "operator = estimator(relation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91b6224a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the estimated operator on the input \"Italy\" and display its predictions.\n",
    "operator(\"Italy\").predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "598dc69f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import functional\n",
    "\n",
    "import baukit\n",
    "import torch\n",
    "\n",
    "def apply_lm_head(hiddens):\n",
    "    \"\"\"Decode the ten highest-scoring tokens for every row of `hiddens`.\n",
    "\n",
    "    `hiddens` is passed through `mt.lm_head[1:]` (the head with its first\n",
    "    element sliced off) and the top-10 indices along the last dimension are\n",
    "    decoded one at a time with `mt.tokenizer.decode`.\n",
    "\n",
    "    Returns a list with one inner list of decoded strings per row.\n",
    "\n",
    "    NOTE(review): a later cell defines `apply_lm_head` again with a different\n",
    "    return format, so this version is shadowed once that cell has run.\n",
    "    \"\"\"\n",
    "    logits = mt.lm_head[1:](hiddens)\n",
    "    top_ids = logits.topk(dim=-1, k=10).indices\n",
    "    decoded_rows = []\n",
    "    for row in top_ids:\n",
    "        decoded_rows.append([mt.tokenizer.decode(token_id) for token_id in row])\n",
    "    return decoded_rows\n",
    "\n",
    "\n",
    "# Request hidden states for layer 26 on two prompts that differ only in the country.\n",
    "outputs = functional.compute_hidden_states(\n",
    "    mt=mt,\n",
    "    layers=[26],\n",
    "    prompt=[\n",
    "        \"The capital of Norway is\",\n",
    "        \"The capital of France is\",\n",
    "    ],\n",
    ")\n",
    "# Last position of the first prompt (Norway) and of the last prompt (France).\n",
    "# NOTE(review): assumes position -1 is a real token for both prompts, i.e. that\n",
    "# batching did not right-pad either of them - TODO confirm.\n",
    "z = outputs.hiddens[0][0, -1]\n",
    "z_old = outputs.hiddens[0][-1, -1]\n",
    "print(z)\n",
    "print(z_old)\n",
    "\n",
    "# j_inv = torch.linalg.pinv(operator.weight.float()).half()\n",
    "# Identity stand-in for the pseudo-inverse in the commented-out line above.\n",
    "# Size, dtype and device are taken from `z` instead of being hard-coded to\n",
    "# 4096 / half / `device`, so the cell also works for a model with a different\n",
    "# hidden size or precision.\n",
    "j_inv = torch.eye(z.shape[-1], device=z.device).to(z.dtype)\n",
    "\n",
    "# `z[..., None]` turns the vector into a column so it can be multiplied by j_inv.\n",
    "delta = j_inv.mm(z[..., None] - operator.bias.t())\n",
    "print(delta.norm())\n",
    "\n",
    "delta_old = j_inv.mm(z_old[..., None] - operator.bias.t())\n",
    "print(delta_old.norm())\n",
    "\n",
    "def edit_output(output, layer):\n",
    "    \"\"\"Overwrite position 3 of the first sequence in `output[0]` with `delta`.\n",
    "\n",
    "    Used as the `edit_output` hook of `baukit.TraceDict` below. `layer` is\n",
    "    accepted but not used. Calls whose `output[0]` has size 1 along dimension 1\n",
    "    are returned untouched (presumably the one-token steps that follow the\n",
    "    initial pass during generation - TODO confirm). The norm of the edited\n",
    "    vector is printed before and after the write.\n",
    "\n",
    "    Returns the same `output` object; the edit is made in place.\n",
    "    \"\"\"\n",
    "    hidden = output[0]\n",
    "    if hidden.shape[1] != 1:\n",
    "        print(hidden[0, 3].norm())\n",
    "        # Disabled variant: -1 * delta_old.squeeze() + delta.squeeze()\n",
    "        hidden[0, 3] = delta.squeeze()\n",
    "        print(hidden[0, 3].norm())\n",
    "    return output\n",
    "\n",
    "\n",
    "inputs = mt.tokenizer(\n",
    "    \"The capital of France is the city of\",\n",
    "    return_tensors=\"pt\",\n",
    ").to(device)\n",
    "\n",
    "# Generate while the `edit_output` hook is attached to `transformer.h.11`.\n",
    "with baukit.TraceDict(\n",
    "    mt.model,\n",
    "    [f\"transformer.h.{layer_index}\" for layer_index in [11]],\n",
    "    edit_output=edit_output,\n",
    "):\n",
    "    outputs = mt.model.generate(\n",
    "        **inputs,\n",
    "        pad_token_id=mt.tokenizer.eos_token_id,\n",
    "        max_new_tokens=100,\n",
    "    )\n",
    "\n",
    "# Display the decoded text of the first generated sequence.\n",
    "mt.tokenizer.batch_decode(outputs)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14d3a933",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def apply_lm_head(hiddens):\n",
    "    \"\"\"Decode the arg-max token for each entry of `hiddens`.\n",
    "\n",
    "    Applies `mt.lm_head[1:]`, takes the arg-max over the last dimension, prints\n",
    "    the shape of the resulting index tensor, and returns a list of\n",
    "    `(index, decoded_token)` pairs.\n",
    "\n",
    "    NOTE(review): this replaces the `apply_lm_head` defined in an earlier cell,\n",
    "    which returned top-10 lists instead.\n",
    "    \"\"\"\n",
    "    token_ids = mt.lm_head[1:](hiddens).argmax(dim=-1)\n",
    "    print(token_ids.shape)\n",
    "    pairs = []\n",
    "    for position, token_id in enumerate(token_ids.tolist()):\n",
    "        pairs.append((position, mt.tokenizer.decode(token_id)))\n",
    "    return pairs\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def logit_lens(prompt):\n",
    "    \"\"\"Decode the last-position hidden state from each `hidden_states` entry.\n",
    "\n",
    "    The prompt is tokenized and run through `mt.model` with\n",
    "    `output_hidden_states=True`. Every entry of `hidden_states` except the\n",
    "    first is concatenated along dimension 0, the last position is selected,\n",
    "    and the result is handed to `apply_lm_head`, whose return value is returned.\n",
    "\n",
    "    NOTE(review): because the concatenation is along dimension 0, layers and\n",
    "    batch entries end up on one axis; this looks intended for a single prompt -\n",
    "    TODO confirm.\n",
    "    \"\"\"\n",
    "    encoded = mt.tokenizer(\n",
    "        prompt,\n",
    "        return_tensors=\"pt\",\n",
    "        padding=\"longest\",\n",
    "        truncation=True,\n",
    "    ).to(device)\n",
    "    # Drop `token_type_ids` if the tokenizer produced them.\n",
    "    encoded.pop(\"token_type_ids\", None)\n",
    "    model_out = mt.model(**encoded, return_dict=True, output_hidden_states=True)\n",
    "    stacked = torch.cat(model_out.hidden_states[1:], dim=0)\n",
    "    return apply_lm_head(stacked[:, -1])\n",
    "\n",
    "\n",
    "# Example call on a single prompt; the returned list is shown as the cell output.\n",
    "logit_lens(\"The capital of France is\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ebff147",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
