{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import json\n",
    "import sys\n",
    "import numpy as np\n",
    "sys.path.append(\"..\")\n",
    "import copy\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models, data, operators, utils, functional, metrics, lens\n",
    "from src.utils import logging_utils, experiment_utils\n",
    "import logging\n",
    "import torch\n",
    "import baukit\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "# Seed the experiment with a fixed value before anything stochastic runs.\n",
    "experiment_utils.set_seed(123456)\n",
    "\n",
    "# Route INFO-level log records to stdout, using the format/datefmt defaults\n",
    "# exposed by logging_utils.\n",
    "logging.basicConfig(\n",
    "    stream=sys.stdout,\n",
    "    level=logging.INFO,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    ")\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---------------- experiment hyperparameters ----------------\n",
    "h_layer = 7       # forwarded as `h_layer` to the estimators\n",
    "beta = 2.5        # forwarded as `beta` to JacobianIclMeanEstimator\n",
    "n_training = 10   # forwarded to relation.split(...)\n",
    "# -------------------------------------------------------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model once; `mt` is passed to every estimator and prompt helper in this notebook.\n",
    "mt = models.load_model(\n",
    "    name=\"gptj\",\n",
    "    fp16=True,\n",
    "    device=\"cuda\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first relation returned by the name filter and set its prompt templates.\n",
    "relation = (\n",
    "    data.load_dataset()\n",
    "    .filter(relation_names=[\"country capital city\"])[0]\n",
    "    .set(prompt_templates=[\" {}:\"])\n",
    ")\n",
    "\n",
    "# Split with `n_training` from the config cell.\n",
    "train, test = relation.split(n_training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the prompt from the training samples, passing a literal `{}` as the subject\n",
    "# so the result can be used as a template later.\n",
    "icl_prompt = functional.make_prompt(\n",
    "    mt=mt,\n",
    "    prompt_template=train.prompt_templates[0],\n",
    "    examples=train.samples,\n",
    "    subject=\"{}\",\n",
    ")\n",
    "print(icl_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter the test relation using the few-shot prompt built above.\n",
    "# NOTE: this rebinds `test`, so re-running the cell feeds the already-filtered\n",
    "# relation back into the filter.\n",
    "test = functional.filter_relation_samples_based_on_provided_fewshots(\n",
    "    mt=mt,\n",
    "    test_relation=test,\n",
    "    prompt_template=icl_prompt,\n",
    "    batch_size=4,\n",
    ")\n",
    "len(test.samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the estimator from the config-cell hyperparameters, then apply it\n",
    "# to the training split to obtain `operator`.\n",
    "estimator = operators.JacobianIclMeanEstimator(\n",
    "    mt=mt,\n",
    "    h_layer=h_layer,\n",
    "    beta=beta,\n",
    ")\n",
    "operator = estimator(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild a LinearRelationOperator from the attributes of `operator`,\n",
    "# with `beta` overridden to 1.0.\n",
    "operator_dict = {**vars(operator), \"beta\": 1.0}\n",
    "no_beta = operators.LinearRelationOperator(**operator_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every test sample, show the first three predictions of `operator` and `no_beta`.\n",
    "for sample in test.samples:\n",
    "    pred = operator(sample.subject).predictions[:3]\n",
    "    no_beta_pred = no_beta(sample.subject).predictions[:3]\n",
    "    pred_strs = [f'{p.token} ({p.prob:.2f})' for p in pred]\n",
    "    no_beta_strs = [f'{p.token} ({p.prob:.2f})' for p in no_beta_pred]\n",
    "    print(f\"{sample} | pred: {pred_strs} | no_beta: {no_beta_strs}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the norms of h, weight and bias for every entry stored under the\n",
    "# `approxes` key of `operator.metadata`.\n",
    "for approx in operator.metadata[\"approxes\"]:\n",
    "    h, weight, bias = approx.h, approx.weight, approx.bias\n",
    "    print(f\"{h.norm()=:.3f} | {weight.norm()=:.3f} | {bias.norm()=:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same layer as the real estimator, but beta fixed to 1.0 and magnitude_h set to 65.0.\n",
    "mythical_estimator = operators.JacobianIclMeanEstimator_Imaginary(\n",
    "    mt=mt,\n",
    "    h_layer=h_layer,\n",
    "    beta=1.0,\n",
    "    magnitude_h=65.0,\n",
    ")\n",
    "mythical_operator = mythical_estimator(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the norms of h, weight and bias for every entry stored under the\n",
    "# `approxes` key of `mythical_operator.metadata`.\n",
    "for approx in mythical_operator.metadata[\"approxes\"]:\n",
    "    h, weight, bias = approx.h, approx.weight, approx.bias\n",
    "    print(f\"{h.norm()=:.3f} | {weight.norm()=:.3f} | {bias.norm()=:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `operator` vs. `mythical_operator`: weight norms on the first line, bias norms on the second.\n",
    "print(\n",
    "    f\"{operator.weight.norm()=:.3f} | {mythical_operator.weight.norm()=:.3f}\",\n",
    "    f\"{operator.bias.norm()=:.3f} | {mythical_operator.bias.norm()=:.3f}\",\n",
    "    sep=\"\\n\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the h vectors of the first two `approxes` entries and push their\n",
    "# difference through the `weight` of each operator.\n",
    "approxes = operator.metadata[\"approxes\"]\n",
    "s1, s2 = approxes[0].h, approxes[1].h\n",
    "delta_h = s1 - s2\n",
    "\n",
    "j_delta_h = operator.weight @ delta_h\n",
    "myth_j_delta_h = mythical_operator.weight @ delta_h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity between the two mapped differences, along the last dimension.\n",
    "torch.nn.functional.cosine_similarity(j_delta_h, myth_j_delta_h, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity between the two operators' biases, along the last dimension.\n",
    "torch.nn.functional.cosine_similarity(operator.bias, mythical_operator.bias, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fixing the hparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep interpolate_on over 2..7 while holding every other hyperparameter fixed,\n",
    "# collecting one imaginary operator per setting.\n",
    "# `h_layer` is taken from the config cell (instead of a hard-coded 7) so these\n",
    "# operators always use the same layer as the real `operator` they are compared\n",
    "# against in the plots below.\n",
    "imaginary_operators = []\n",
    "for interpolate_on in tqdm(range(2, 8)):\n",
    "    estimator_i = operators.JacobianIclMeanEstimator_Imaginary(\n",
    "        mt=mt,\n",
    "        h_layer=h_layer,\n",
    "        beta=1,\n",
    "        interpolate_on=interpolate_on,\n",
    "        n_trials=8,\n",
    "        magnitude_h=65.0,\n",
    "    )\n",
    "    operator_i = estimator_i(train)\n",
    "    imaginary_operators.append(operator_i)\n",
    "    # Visual separator between the output of consecutive sweep iterations.\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (weight norm, bias norm) of `operator`, converted with .item().\n",
    "tuple(t.norm().item() for t in (operator.weight, operator.bias))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weight norm of each imaginary operator, with the norm of `operator.weight`\n",
    "# drawn as a red horizontal reference line.\n",
    "w_norms = [op.weight.norm().item() for op in imaginary_operators]\n",
    "\n",
    "ax = plt.gca()\n",
    "ax.plot(range(2, 8), w_norms, label=\"|| J_imaginary ||\")\n",
    "ax.hlines(operator.weight.norm().item(), 2, 7, color=\"red\", label=\"|| J_real ||\")\n",
    "ax.set_ylim(bottom=0)\n",
    "ax.legend()\n",
    "ax.set_ylabel(\"|| J ||\")\n",
    "ax.set_xlabel(\"n_points\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild a LinearRelationOperator from the attributes of imaginary_operators[2]\n",
    "# (index 2 of the range(2, 8) sweep, i.e. interpolate_on=4), with beta set to 1.\n",
    "operator_dict = {**vars(imaginary_operators[2]), \"beta\": 1}\n",
    "img_operator = operators.LinearRelationOperator(**operator_dict)\n",
    "\n",
    "# Weight norm and bias norm of the rebuilt operator.\n",
    "print(*(t.norm().item() for t in (img_operator.weight, img_operator.bias)))\n",
    "\n",
    "# First three predictions of the rebuilt operator for every test sample.\n",
    "for sample in test.samples:\n",
    "    pred = img_operator(sample.subject).predictions[:3]\n",
    "    pred_strs = [f'{p.token} ({p.prob:.2f})' for p in pred]\n",
    "    print(f\"{sample} | pred: {pred_strs}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bias norm of each imaginary operator, with the norm of `operator.bias`\n",
    "# drawn as a red horizontal reference line. The y-axis starts at 200.\n",
    "b_norms = [op.bias.norm().item() for op in imaginary_operators]\n",
    "\n",
    "ax = plt.gca()\n",
    "ax.plot(range(2, 8), b_norms, label=\"|| bias_imaginary ||\")\n",
    "ax.hlines(operator.bias.norm().item(), 2, 7, color=\"red\", label=\"|| bias_real ||\")\n",
    "ax.set_ylim(bottom=200)\n",
    "ax.legend()\n",
    "ax.set_ylabel(\"|| bias ||\")\n",
    "ax.set_xlabel(\"n_points\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the last operator produced by the interpolate_on sweep.\n",
    "# It is looked up in `imaginary_operators` explicitly rather than through the\n",
    "# `operator_i` loop variable that leaks out of the sweep cell, so this cell does\n",
    "# not depend on whatever that variable last happened to hold.\n",
    "last_imaginary_operator = imaginary_operators[-1]\n",
    "for sample in test.samples:\n",
    "    pred = last_imaginary_operator(sample.subject).predictions[:3]\n",
    "    print(f\"{sample} | pred: {[f'{p.token} ({p.prob:.2f})' for p in pred]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
