{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import json\n",
    "import sys\n",
    "import numpy as np\n",
    "sys.path.append(\"..\")\n",
    "import copy\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models, data, operators, utils, functional, metrics, lens\n",
    "from src.utils import logging_utils, experiment_utils\n",
    "import logging\n",
    "import torch\n",
    "import baukit\n",
    "\n",
    "experiment_utils.set_seed(123456)\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.INFO,\n",
    "    format = logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.interpolation import save_order_1_approx, normalize_on_sphere"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Interpolation Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mt = models.load_model(name = \"gptj\", fp16 = True, device = \"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_name = \"name birthplace\"\n",
    "relation = data.load_dataset().filter(relation_names=[relation_name])[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_cached_values(trial_path, idx):\n",
    "    approx = np.load(f\"{trial_path}/approx_{idx+1}.npz\", allow_pickle=True)\n",
    "    approx_dict = {}\n",
    "    for key,value in approx.items():\n",
    "        if key in [\"h\", \"z\", \"weight\", \"bias\"]:\n",
    "            approx_dict[key] = torch.from_numpy(value).cuda()\n",
    "        else:\n",
    "            approx_dict[key] = value\n",
    "    weight = torch.Tensor(approx['weight']).cuda()\n",
    "    h = torch.Tensor(approx['h']).cuda()\n",
    "    z = torch.Tensor(approx['z']).cuda()\n",
    "\n",
    "    return weight, h, z, approx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.operators import JacobianIclMeanEstimator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def execute_line_integral_experiment(trial_path, interpolation_steps, lre_W = None, h_layer = None):\n",
    "    jacobians, hs, zs = [], [], []\n",
    "    for idx in tqdm(range(interpolation_steps)):\n",
    "        j, h, z, approx  = load_cached_values(trial_path, idx)\n",
    "        jacobians.append(j)\n",
    "        hs.append(h)\n",
    "        zs.append(z)\n",
    "\n",
    "        if h_layer is None:\n",
    "            h_layer = approx[\"h_layer\"].item()\n",
    "    \n",
    "    u = zs[-1] - zs[0]\n",
    "    v = hs[-1] - hs[0]\n",
    "    u /= u.norm()\n",
    "    v /= v.norm()\n",
    "\n",
    "    dh = (hs[-1] - hs[0])/interpolation_steps\n",
    "\n",
    "    scalar_values = []\n",
    "    Dz = zs[-1] - zs[0]\n",
    "    Dz_approx = 0\n",
    "    for dF in jacobians:\n",
    "        s = u[None] @ dF @ v[None].T\n",
    "        scalar_values.append(s.item())\n",
    "        Dz_approx += s * dh\n",
    "\n",
    "    if lre_W is None:\n",
    "        estimator = JacobianIclMeanEstimator(mt = mt, h_layer=h_layer)\n",
    "        train, test = relation.split(train_size=5)\n",
    "        operator = estimator(train)\n",
    "        lre_W = operator.weight.float()\n",
    "\n",
    "    Dz_lre = (u[None] @ lre_W @ v[None].T) * (hs[-1] - hs[0]).norm()\n",
    "\n",
    "    return Dz.norm().item(), Dz_approx.norm().item(), Dz_lre.norm().item(), scalar_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# s1, s2 = \"Brazil\", \"Spain\"\n",
    "\n",
    "\n",
    "experiment_utils.set_seed(123456)\n",
    "\n",
    "path = f\"../results/interpolation/{relation_name}/\"\n",
    "trials = os.listdir(path)\n",
    "\n",
    "actual_delta_zs = []\n",
    "approx_delta_zs = []\n",
    "lre_delta_zs = []\n",
    "scalar_values = []\n",
    "\n",
    "for trial in trials: #[f\"{s1}-{s2}\"]:\n",
    "    print(trial)\n",
    "    trial_path = os.path.join(path, trial)\n",
    "    interpolation_steps = len(os.listdir(trial_path))\n",
    "    Dz, Dz_approx, Dz_lre, values_i = execute_line_integral_experiment(trial_path, interpolation_steps)\n",
    "\n",
    "    actual_delta_zs.append(Dz)\n",
    "    approx_delta_zs.append(Dz_approx)\n",
    "    lre_delta_zs.append(Dz_lre)\n",
    "    scalar_values.append(values_i)\n",
    "\n",
    "    plt.plot(np.linspace(0, 1, interpolation_steps), values_i, label=trial)\n",
    "    print(f\"True Dz: {Dz:.4f} | Approx Dz: {Dz_approx:.4f} | LRE Dz: {Dz_lre:.4f}\")\n",
    "\n",
    "\n",
    "plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "plt.xlabel(\"Interpolation ratio, $\\\\alpha$\")\n",
    "plt.ylabel(r\"$u^T \\times F^'(s) \\times v$\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = torch.Tensor(approx_delta_zs)/torch.Tensor(lre_delta_zs)\n",
    "f\"beta: {beta.mean():.3f}+/-{beta.std():.3f}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_dir = \"figures/figs\"\n",
    "#####################################################################################\n",
    "plt.rcdefaults()\n",
    "color = \"purple\"\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 12\n",
    "MEDIUM_SIZE = 16\n",
    "BIGGER_SIZE = 18\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE+1)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=50)  # fontsize of the figure title\n",
    "#####################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scalar_values_mean = torch.Tensor(scalar_values).mean(dim=0)\n",
    "scalar_values_std = torch.Tensor(scalar_values).std(dim=0)\n",
    "\n",
    "lre_delta_mean = torch.Tensor(lre_delta_zs).mean()\n",
    "lre_delta_std = torch.Tensor(lre_delta_zs).std()\n",
    "\n",
    "plt.plot(np.linspace(0, 1, interpolation_steps), scalar_values_mean, color=\"purple\", label = r\"$u^T \\; F^'(s) \\; \\Delta s$\")\n",
    "plt.fill_between(np.linspace(0, 1, interpolation_steps), scalar_values_mean - scalar_values_std, scalar_values_mean + scalar_values_std, color=\"purple\", alpha=0.1)\n",
    "plt.axhline(y=lre_delta_mean/100, color='r', linestyle='--', label=r\"$u^T \\; W \\; \\Delta s$\")\n",
    "\n",
    "plt.legend()\n",
    "plt.xlabel(\"Interpolation ratio, $\\\\alpha$\")\n",
    "plt.xlim(0, 1)\n",
    "plt.ylim(bottom=0)\n",
    "\n",
    "plt.savefig(f\"{fig_dir}/line_integration.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.sweep_utils import read_sweep_results, relation_from_dict\n",
    "\n",
    "sweep_path = f\"../results/sweep-24-trials/gptj\"\n",
    "relation_result_raw = read_sweep_results(\n",
    "    sweep_path, relation_names=[relation_name], economy=True\n",
    ")[relation_name]\n",
    "relation_result = relation_from_dict(relation_result_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_result.best_by_efficacy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading Cached $\\mathcal{W}$ and $b$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# relation = data.load_dataset().filter(relation_names=[\"country capital city\"])[0].set(prompt_templates=[\" {}:\"])\n",
    "# train, test = relation.split(5)\n",
    "\n",
    "# #################################################\n",
    "# h_layer = 8\n",
    "# beta = 2.25\n",
    "# interpolation_steps = 100\n",
    "# #################################################\n",
    "\n",
    "# mt = models.load_model(name = \"gptj\", fp16 = True, device = \"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path = \"../results/interpolation/country capital city/\"\n",
    "# W_norms = []\n",
    "# B_norms = []\n",
    "# Jh_norms = []\n",
    "# J_delta_h = []\n",
    "\n",
    "# values = []\n",
    "# labels = []\n",
    "\n",
    "\n",
    "# for trial in os.listdir(path):\n",
    "#     # if(trial == \"Colombia-Venezuela\"):\n",
    "#     #     continue\n",
    "#     if(trial != f\"{s1}-{s2}\"):\n",
    "#         continue\n",
    "\n",
    "#     s1, s2 = trial.split(\"-\")\n",
    "#     hs_and_zs = functional.compute_hs_and_zs(\n",
    "#         mt = mt,\n",
    "#         prompt_template = train.prompt_templates[0],\n",
    "#         subjects = list(set([sample.subject for sample in relation.samples] + [s1, s2])),\n",
    "#         h_layer= h_layer,\n",
    "#         z_layer=-1,\n",
    "#         examples= train.samples\n",
    "#     )\n",
    "#     h1, h2 = [hs_and_zs.h_by_subj[s] for s in [s1, s2]]\n",
    "#     z1, z2 = [hs_and_zs.z_by_subj[s] for s in [s1, s2]]\n",
    "\n",
    "#     do = (z2 - z1).float()\n",
    "#     ds = (h2 - h1).float()\n",
    "\n",
    "#     trial_path = os.path.join(path, trial)\n",
    "#     approxes = []\n",
    "#     interpolation_steps = len(os.listdir(trial_path))\n",
    "#     for idx in tqdm(range(interpolation_steps)):\n",
    "#         approx = np.load(f\"{trial_path}/approx_{idx+1}.npz\", allow_pickle=True)\n",
    "#         approx_dict = {}\n",
    "#         for key,value in approx.items():\n",
    "#             if key in [\"h\", \"z\", \"weight\", \"bias\"]:\n",
    "#                 approx_dict[key] = torch.from_numpy(value).cuda()\n",
    "#             else:\n",
    "#                 approx_dict[key] = value\n",
    "#         approxes.append(approx_dict)\n",
    "\n",
    "#     for a in approxes:\n",
    "#         weight = torch.Tensor(a['weight']).float().cuda()\n",
    "\n",
    "#         # scalar_value = torch.cosine_similarity(weight @ ds, do, dim = -1)\n",
    "#         # scalar_value = torch.dot(weight @ ds, do)\n",
    "#         # G = torch.tensor([torch.dot(do, weight[i]) for i in range(models.determine_hidden_size(mt))]).cuda()\n",
    "#         # scalar_value = torch.dot(G, ds)\n",
    "#         scalar_value = do[None] @ (weight @ ds[None].T)\n",
    "#         values.append(scalar_value.item())\n",
    "\n",
    "\n",
    "\n",
    "#     # w_norms = [torch.Tensor(a['weight']).norm().item() for a in approxes]\n",
    "#     # b_norms = [torch.Tensor(a['bias']).norm().item() for a in approxes]\n",
    "#     # jh_norms = [torch.Tensor(a[\"metadata\"].item()[\"Jh\"]).norm().item() for a in approxes]\n",
    "\n",
    "#     # h1 = normalize_on_sphere(approxes[0]['h'], scale=60.0)\n",
    "#     # h2 = normalize_on_sphere(approxes[-1]['h'], scale=60.0)\n",
    "#     # delta_h = normalize_on_sphere(h2 - h1, scale=60.0)\n",
    "#     # # delta_h = hs_and_zs.h_by_subj[\"Russia\"]\n",
    "#     # j_delta_h = [(torch.Tensor(a['weight']) @ delta_h).norm().item() for a in approxes]\n",
    "\n",
    "#     # W_norms.append(w_norms)\n",
    "#     # B_norms.append(b_norms)\n",
    "#     # Jh_norms.append(jh_norms)\n",
    "#     # J_delta_h.append(j_delta_h)\n",
    "#     # labels.append(trial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.rcdefaults()\n",
    "\n",
    "# plt.plot(np.linspace(0, 1, interpolation_steps), values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.rcdefaults()\n",
    "\n",
    "# for trial, label in zip(W_norms, labels):\n",
    "#     plt.plot(np.linspace(0, 1, interpolation_steps), trial, label=label)\n",
    "# plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "# plt.ylabel(\"|| $W$ ||\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # filter_trials = [\"South Korea-Brazil\", \"Brazil-Nigeria\", \"Egypt-South Korea\", \"Chile-Egypt\"]\n",
    "# filter_trials = [\"Brazil-Turkey\", \"Turkey-France\", \"Brazil-France\", \"Spain-Egypt\"]\n",
    "# trial_results = []\n",
    "# for trial, label in zip(W_norms, labels):\n",
    "#     if label not in filter_trials:\n",
    "#         continue\n",
    "#     trial_results.append(trial)\n",
    "#     # plt.plot(np.linspace(0, 1, interpolation_steps), trial, label=label)\n",
    "# # plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "\n",
    "# trial_results = np.array(trial_results)\n",
    "# trial_results_mean = trial_results.mean(axis=0)\n",
    "# trial_results_std = trial_results.std(axis=0)\n",
    "\n",
    "\n",
    "# fig_dir = \"figures/figs\"\n",
    "# #####################################################################################\n",
    "# plt.rcdefaults()\n",
    "# color = \"purple\"\n",
    "# plt.rcParams[\"figure.dpi\"] = 200\n",
    "# plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "# SMALL_SIZE = 12\n",
    "# MEDIUM_SIZE = 16\n",
    "# BIGGER_SIZE = 18\n",
    "\n",
    "# plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "# plt.rc(\"axes\", labelsize=MEDIUM_SIZE+1)  # fontsize of the x and y labels\n",
    "# plt.rc(\"xtick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "# plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "# plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "# plt.rc(\"figure\", titlesize=50)  # fontsize of the figure title\n",
    "# #####################################################################################\n",
    "\n",
    "# x = np.linspace(0, 1, interpolation_steps)\n",
    "# plt.plot(x, trial_results_mean, color=color)\n",
    "# plt.fill_between(x, trial_results_mean - trial_results_std, trial_results_mean + trial_results_std, alpha=0.1, color = color)\n",
    "# plt.xlabel(\"Interpolation ratio, $\\\\alpha$\")\n",
    "# plt.ylabel(\"|| $W$ ||\")\n",
    "\n",
    "# plt.savefig(f\"{fig_dir}/j_underestimation.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # J(h2 - h1)\n",
    "# plt.rcdefaults()\n",
    "# for trial, label in zip(J_delta_h, labels):\n",
    "#     plt.plot(np.linspace(0, 1, interpolation_steps), trial, label=label)\n",
    "# plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n",
    "# plt.ylabel(\"|| $W @ \\Delta h$ ||\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for trial in Jh_norms:\n",
    "#     plt.plot(np.linspace(0, 1, interpolation_steps), trial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
