{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from src import models, data\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "import os\n",
    "import numpy as np\n",
    "import copy\n",
    "\n",
    "os.makedirs(\"layer_sweep/Jacobian_plots\", exist_ok=True)\n",
    "os.makedirs(\"layer_sweep/weights_and_biases\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda:0\"\n",
    "mt = models.load_model(\"gptj\", device=device)\n",
    "print(\n",
    "    f\"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset = data.load_dataset()\n",
    "# capital_cities =[d for d in dataset if d.name == \"capital city\"][0]\n",
    "# print(capital_cities)\n",
    "# # capital_cities.__dict__.keys()\n",
    "# len(capital_cities.samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# indices = np.random.choice(range(len(capital_cities.samples)), 3, replace=False)\n",
    "# samples = [capital_cities.samples[i] for i in indices]\n",
    "\n",
    "# training_samples = copy.deepcopy(capital_cities.__dict__)\n",
    "# training_samples[\"samples\"] = samples\n",
    "# training_samples = data.Relation(**training_samples)\n",
    "\n",
    "# training_samples.samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.operators import JacobianIclMeanEstimator\n",
    "\n",
    "# mean_estimator = JacobianIclMeanEstimator(\n",
    "#     mt=mt,\n",
    "#     h_layer=12,\n",
    "#     bias_scale_factor=0.2       # so that the bias doesn't knock out the prediction too much in the direction of training examples\n",
    "# ) \n",
    "\n",
    "# operator = mean_estimator(training_samples)\n",
    "# operator(\"United States\", k = 10).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.data import RelationDataset\n",
    "# from src.benchmarks import reconstruction, faithfulness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# faith_result = faithfulness(\n",
    "#     estimator=mean_estimator,\n",
    "#     dataset = RelationDataset([capital_cities]),\n",
    "#     n_train=3,\n",
    "#     n_trials=5,\n",
    "#     k=3\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for k, v in faith_result.metrics.__dict__.items():\n",
    "#     print(f\"{k}: {v}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# recos_result = reconstruction(\n",
    "#     estimator=mean_estimator,\n",
    "#     dataset = RelationDataset([capital_cities]),\n",
    "#     n_train=3,\n",
    "#     n_trials=5,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.lens import interpret_logits, logit_lens\n",
    "from src.functional import untuple\n",
    "\n",
    "prompt = \"The Space Needle is located in the city of\"\n",
    "tokenized = mt.tokenizer(prompt, return_tensors=\"pt\", padding=True).to(mt.model.device)\n",
    "\n",
    "import baukit\n",
    "\n",
    "with baukit.TraceDict(\n",
    "    mt.model,\n",
    "    models.determine_layer_paths(mt)\n",
    ") as traces:\n",
    "    output = mt.model(**tokenized)\n",
    "    \n",
    "interpret_logits(mt, output.logits[0][-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score a few hand-picked candidate answers against the last traced layer output\n",
    "# of the prompt run in the previous cell (this cell reuses its `traces`).\n",
    "interested_words = [\" Seattle\", \" Paris\", \" Dhaka\"]\n",
    "\n",
    "# Tokenize every word on its own and keep its first token. Batch-tokenizing with\n",
    "# padding=True and then reading position 0 picks up a pad token instead of the\n",
    "# word whenever the tokenizer pads on the left and the words differ in length.\n",
    "interested_tokens = [\n",
    "    mt.tokenizer(word, return_tensors=\"pt\").input_ids[0][0].to(mt.model.device)\n",
    "    for word in interested_words\n",
    "]\n",
    "\n",
    "z = untuple(traces[models.determine_layer_paths(mt)[-1]].output)[0][-1]\n",
    "print(z.shape)\n",
    "\n",
    "logit_lens(mt, z, interested_tokens, get_proba=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import predict_next_token\n",
    "\n",
    "predict_next_token(\n",
    "    mt = mt, prompt = \"The Space Needle is located in the city of\", k=10\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_by_model_knowledge(mt, relation_prompt, relation_samples):\n",
    "    \"\"\"Return the samples whose object the model already predicts.\n",
    "\n",
    "    For each sample, `relation_prompt` is formatted with the sample's subject and\n",
    "    the model's top-1 next token is compared with the expected object: the sample\n",
    "    is kept when the stripped object starts with the stripped token.\n",
    "\n",
    "    Args:\n",
    "        mt: model/tokenizer bundle forwarded to `predict_next_token`.\n",
    "        relation_prompt: format string with one `{}` placeholder for the subject.\n",
    "        relation_samples: iterable of samples exposing `.subject` and `.object`.\n",
    "\n",
    "    Returns:\n",
    "        List of the samples that passed the check, in their original order.\n",
    "    \"\"\"\n",
    "    model_knows = []\n",
    "    for sample in relation_samples:\n",
    "        prompt = relation_prompt.format(sample.subject)\n",
    "        top_prediction = predict_next_token(mt=mt, prompt=prompt)[0][0].token.strip()\n",
    "        # A whitespace-only token strips to \"\" and every string starts with \"\",\n",
    "        # so an empty prediction must not count as the model knowing the answer.\n",
    "        if top_prediction and sample.object.strip().startswith(top_prediction):\n",
    "            model_knows.append(sample)\n",
    "    return model_knows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = data.load_dataset()\n",
    "cur_relation = [d for d in dataset if d.name == \"capital city\"][0]\n",
    "print(cur_relation.name, \" ==> \", len(cur_relation.samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from src.functional import make_prompt\n",
    "\n",
    "icl_indices = np.random.choice(range(len(cur_relation.samples)), 3, replace=False)\n",
    "icl_samples = [cur_relation.samples[i] for i in icl_indices]\n",
    "\n",
    "icl_prompt = make_prompt(\n",
    "    prompt_template = cur_relation.prompt_templates[0],\n",
    "    subject=\"{}\",\n",
    "    examples=icl_samples,\n",
    ")\n",
    "\n",
    "print(icl_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_knows = filter_by_model_knowledge(mt, icl_prompt, cur_relation.samples)\n",
    "len(model_knows)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layer Richness based on logit lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.lens import layer_c_measure\n",
    "\n",
    "# relation_prompt = mt.tokenizer.eos_token + \" {} is located in the city of\"\n",
    "# subject = \"The Space Needle\"\n",
    "# layer_c_measure(mt, relation_prompt, subject, measure = \"contribution\", verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-layer values returned by `layer_c_measure` for every known sample.\n",
    "layer_paths = models.determine_layer_paths(mt)\n",
    "layer_c_info = {path: [] for path in layer_paths}\n",
    "\n",
    "for known_sample in tqdm(model_knows):\n",
    "    scores = layer_c_measure(mt, icl_prompt, known_sample.subject)\n",
    "    for path in layer_paths:\n",
    "        layer_c_info[path].append(scores[path])\n",
    "\n",
    "# The raw lists are written to disk first; they become arrays only afterwards.\n",
    "with open(\"layer_sweep/layer_completeness_info.json\", \"w\") as out_file:\n",
    "    json.dump(layer_c_info, out_file)\n",
    "\n",
    "layer_c_info.update((path, np.array(layer_c_info[path])) for path in layer_paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean curve over samples with a min/max band, one point per layer.\n",
    "per_layer_values = [layer_c_info[path] for path in models.determine_layer_paths(mt)]\n",
    "mean_richness = [values.mean() for values in per_layer_values]\n",
    "low_richness = [values.min() for values in per_layer_values]\n",
    "high_richness = [values.max() for values in per_layer_values]\n",
    "\n",
    "layer_axis = range(len(mean_richness))\n",
    "plt.plot(mean_richness, color=\"blue\")\n",
    "plt.fill_between(layer_axis, low_richness, high_richness, alpha=0.2)\n",
    "plt.axhline(0, color=\"red\", linestyle=\"--\")\n",
    "\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"completeness\")\n",
    "plt.xticks(layer_axis[::2])  # a tick on every second layer\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layer Richness based on `Jh_norm` and `J_norm`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "cur_relation_known = copy.deepcopy(cur_relation.__dict__)\n",
    "cur_relation_known[\"samples\"] = model_knows\n",
    "\n",
    "cur_relation_known = data.Relation(**cur_relation_known)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.operators import JacobianEstimator, JacobianIclMeanEstimator\n",
    "from src.data import RelationSample\n",
    "\n",
    "# indices = np.random.choice(range(len(capital_cities.samples)), 3, replace=False)\n",
    "# samples = [capital_cities.samples[i] for i in indices]\n",
    "\n",
    "# training_samples = copy.deepcopy(capital_cities.__dict__)\n",
    "# training_samples[\"samples\"] = samples\n",
    "# training_samples = data.Relation(**training_samples)\n",
    "\n",
    "# mean_estimator = JacobianIclMeanEstimator(\n",
    "#     mt=mt,\n",
    "#     h_layer=12,\n",
    "# )\n",
    "\n",
    "# operator = mean_estimator(training_samples)\n",
    "# operator(\"Russia\", k = 10).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator = JacobianEstimator(\n",
    "    mt=mt,\n",
    "    h_layer=12,\n",
    ")\n",
    "\n",
    "operator = estimator.estimate_for_subject(\n",
    "    subject = \"United States\",\n",
    "    prompt_template= icl_prompt\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "operator.metadata['Jh'].norm().item(), operator.weight.norm().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every known sample outside the ICL prompt, estimate the Jacobian operator\n",
    "# at each of the first 24 layers and record the norms of its weight, its `Jh`\n",
    "# metadata entry and its bias.\n",
    "layer_paths = models.determine_layer_paths(mt)\n",
    "layerwise_jh = {path: [] for path in layer_paths}\n",
    "\n",
    "for sample in tqdm(set(model_knows) - set(icl_samples)):\n",
    "    for h_layer in range(24):\n",
    "        layer_name = layer_paths[h_layer]\n",
    "        estimator = JacobianEstimator(mt=mt, h_layer=h_layer)\n",
    "        operator = estimator.estimate_for_subject(\n",
    "            subject=sample.subject,\n",
    "            prompt_template=icl_prompt,\n",
    "        )\n",
    "\n",
    "        norms = {\n",
    "            \"J\": operator.weight.norm().item(),\n",
    "            \"Jh\": operator.metadata[\"Jh\"].norm().item(),\n",
    "            \"bias\": operator.bias.norm().item(),\n",
    "        }\n",
    "        layerwise_jh[layer_name].append(norms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layers that were never swept have no records; leave them out of the saved file.\n",
    "layerwise_jh = {\n",
    "    layer_name: records for layer_name, records in layerwise_jh.items() if len(records) > 0\n",
    "}\n",
    "\n",
    "with open(\"layer_sweep/layer_jh_info.json\", \"w\") as out_file:\n",
    "    json.dump(layerwise_jh, out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = \"Jh\"  # which recorded norm to plot: \"J\", \"Jh\" or \"bias\"\n",
    "\n",
    "# One array of per-sample values for every layer present in `layerwise_jh`.\n",
    "info = {\n",
    "    layer: np.array([record[key] for record in records])\n",
    "    for layer, records in layerwise_jh.items()\n",
    "}\n",
    "\n",
    "mean = [values.mean() for values in info.values()]\n",
    "plt.plot(mean, color=\"blue\", linewidth=4)\n",
    "plt.xticks(range(0, len(mean), 2))\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(f\"{key}_norm\")\n",
    "\n",
    "# Take the number of per-sample curves from the recorded data itself. Recomputing\n",
    "# `set(model_knows) - set(icl_samples)` here goes wrong (IndexError or missing\n",
    "# curves) once `icl_samples` has been re-drawn by a later cell.\n",
    "n_curves = min((len(values) for values in info.values()), default=0)\n",
    "for i in range(n_curves):\n",
    "    plt.plot([values[i] for values in info.values()], alpha=0.2)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Causal Tracing on `subject_last`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.layer_selection import causal_tracing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "causal_tracing(\n",
    "    mt,\n",
    "    prompt_template = \"The location of {} is in the city of\",\n",
    "    subject_original = \"The Space Needle\",\n",
    "    subject_corruption = \"The Statue of Liberty\",\n",
    "    # verbose = True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "cur_relation_known = copy.deepcopy(cur_relation.__dict__)\n",
    "cur_relation_known[\"samples\"] = model_knows\n",
    "\n",
    "cur_relation_known = data.Relation(**cur_relation_known)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_icl = 3\n",
    "\n",
    "# The indices are drawn over `cur_relation_known.samples`, so they must index\n",
    "# that same list. Indexing the unfiltered `cur_relation.samples` with them picks\n",
    "# unrelated samples, which the model may not know and which are then not removed\n",
    "# from the test pool built from `cur_relation_known`.\n",
    "icl_indices = np.random.choice(range(len(cur_relation_known.samples)), num_icl, replace=False)\n",
    "icl_samples = [cur_relation_known.samples[i] for i in icl_indices]\n",
    "\n",
    "# One solved example per line, followed by the bare template for the query subject.\n",
    "icl_prompt = [\n",
    "    f\"{cur_relation.prompt_templates[0].format(sample.subject)} {sample.object}\"\n",
    "    for sample in icl_samples\n",
    "]\n",
    "icl_prompt = \"\\n\".join(icl_prompt) + \"\\n\" + cur_relation.prompt_templates[0]\n",
    "\n",
    "print(icl_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Known samples that are not part of the ICL prompt. Kept as an ordered,\n",
    "# de-duplicated list rather than a set: the indices drawn below then refer to a\n",
    "# fixed ordering, and the list is built once instead of on every run.\n",
    "icl_sample_set = set(icl_samples)\n",
    "test_samples = list(\n",
    "    dict.fromkeys(s for s in cur_relation_known.samples if s not in icl_sample_set)\n",
    ")\n",
    "layer_paths = models.determine_layer_paths(mt)\n",
    "causal_tracing_results = {layer: [] for layer in layer_paths}\n",
    "\n",
    "n_runs = 20\n",
    "for run in tqdm(range(n_runs)):\n",
    "    # Two distinct samples: the first supplies `subject_original`, the second\n",
    "    # supplies `subject_corruption`.\n",
    "    original_idx, corruption_idx = np.random.choice(len(test_samples), 2, replace=False)\n",
    "    sample_pair = [test_samples[original_idx], test_samples[corruption_idx]]\n",
    "    print(sample_pair)\n",
    "\n",
    "    cur_result = causal_tracing(\n",
    "        mt,\n",
    "        prompt_template=icl_prompt,\n",
    "        subject_original=sample_pair[0].subject,\n",
    "        subject_corruption=sample_pair[1].subject,\n",
    "        verbose=False,\n",
    "    )\n",
    "\n",
    "    for layer in layer_paths:\n",
    "        causal_tracing_results[layer].append(cur_result[layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"layer_sweep/causal_tracing_results.json\", \"w\") as out_file:\n",
    "    json.dump(causal_tracing_results, out_file)\n",
    "\n",
    "# After saving, turn every per-layer list into an array for the plotting cell.\n",
    "causal_tracing_results.update(\n",
    "    (layer, np.array(causal_tracing_results[layer]))\n",
    "    for layer in models.determine_layer_paths(mt)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models.determine_layers(mt)[::2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_paths = models.determine_layer_paths(mt)\n",
    "mean = [causal_tracing_results[path].mean() for path in layer_paths]\n",
    "\n",
    "# Bold mean curve with a dashed zero reference line.\n",
    "plt.plot(mean, color=\"blue\", linewidth=3)\n",
    "plt.axhline(0, color=\"red\", linestyle=\"--\")\n",
    "\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"layer_score\")\n",
    "plt.xticks(models.determine_layers(mt)[::2])\n",
    "\n",
    "# Faint curve for each individual run, drawn over the mean.\n",
    "for run in range(n_runs):\n",
    "    run_scores = [causal_tracing_results[path][run] for path in layer_paths]\n",
    "    plt.plot(run_scores, alpha=0.2)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layer sweep on mean ICL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "cur_relation_known = copy.deepcopy(cur_relation.__dict__)\n",
    "cur_relation_known[\"samples\"] = model_knows\n",
    "\n",
    "cur_relation_known = data.Relation(**cur_relation_known)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = np.random.choice(range(len(cur_relation_known.samples)), 2, replace=False)\n",
    "samples = [cur_relation_known.samples[i] for i in indices]\n",
    "\n",
    "capital_cities_subset = copy.deepcopy(cur_relation.__dict__)\n",
    "capital_cities_subset[\"samples\"] = samples\n",
    "capital_cities_subset = data.Relation(**capital_cities_subset)\n",
    "\n",
    "len(capital_cities_subset.samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.operators import JacobianIclMeanEstimator\n",
    "\n",
    "mean_estimator = JacobianIclMeanEstimator(\n",
    "    mt=mt,\n",
    "    h_layer=12,\n",
    ")\n",
    "\n",
    "operator = mean_estimator(capital_cities_subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "operator(\"Chile\", k = 10).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the operator on held-out samples only. The samples it was estimated\n",
    "# from (`capital_cities_subset`) are excluded so they cannot inflate the recall,\n",
    "# which is also how `get_layer_wise_recall` builds its evaluation pool.\n",
    "held_out_samples = set(cur_relation_known.samples) - set(capital_cities_subset.samples)\n",
    "\n",
    "predictions = []\n",
    "target = []\n",
    "\n",
    "for sample in tqdm(held_out_samples):\n",
    "    cur_predictions = operator(sample.subject, k=5).predictions\n",
    "    predictions.append([p.token for p in cur_predictions])\n",
    "    target.append(sample.object)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.metrics import recall\n",
    "\n",
    "recall(predictions, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.savez(\"layer_sweep/operator_weight.npz\", jacobian = operator.weight.detach().cpu().numpy(), allow_pickle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# j = np.load(\"layer_sweep/operator_weight.npz\", allow_pickle=True)[\"jacobian\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.dist(torch.tensor(j).to(device), operator.weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_layer_wise_recall(capital_cities_subset, verbose = True, save_weights = True, n_layers = 24):\n",
    "    \"\"\"Estimate a mean-ICL Jacobian operator at each layer and measure its recall.\n",
    "\n",
    "    Reads the notebook globals `mt` (model) and `cur_relation_known` (evaluation pool).\n",
    "\n",
    "    Args:\n",
    "        capital_cities_subset: relation passed to the estimator; its samples are\n",
    "            excluded from the evaluation pool.\n",
    "        verbose: print each layer's recall as soon as it is computed.\n",
    "        save_weights: write each layer's weight and bias to\n",
    "            `layer_sweep/weights_and_biases/<layer_name>.npz`.\n",
    "        n_layers: number of leading layers to sweep, capped at the number of\n",
    "            layer paths the model exposes.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping layer name to the value returned by `recall` for the top-5\n",
    "        predictions on the held-out samples.\n",
    "    \"\"\"\n",
    "    layer_wise_recall = {}\n",
    "    layer_names = models.determine_layer_paths(mt)\n",
    "\n",
    "    # The evaluation pool does not depend on the layer, so it is built once.\n",
    "    held_out_samples = list(\n",
    "        set(cur_relation_known.samples) - set(capital_cities_subset.samples)\n",
    "    )\n",
    "    target = [sample.object for sample in held_out_samples]\n",
    "\n",
    "    # Capping at len(layer_names) avoids an IndexError on models with fewer layers.\n",
    "    for h_layer in tqdm(range(min(n_layers, len(layer_names)))):\n",
    "        layer_name = layer_names[h_layer]\n",
    "        mean_estimator = JacobianIclMeanEstimator(\n",
    "            mt=mt,\n",
    "            h_layer=h_layer,\n",
    "        )\n",
    "        operator = mean_estimator(capital_cities_subset)\n",
    "        if save_weights:\n",
    "            # `allow_pickle` is deliberately not passed: np.savez stores keyword\n",
    "            # arguments as named arrays, so on NumPy versions without a dedicated\n",
    "            # `allow_pickle` parameter it ends up in the archive as a stray entry.\n",
    "            np.savez(\n",
    "                f\"layer_sweep/weights_and_biases/{layer_name}.npz\",\n",
    "                jacobian=operator.weight.detach().cpu().numpy(),\n",
    "                bias=operator.bias.detach().cpu().numpy(),\n",
    "            )\n",
    "\n",
    "        predictions = [\n",
    "            [p.token for p in operator(sample.subject, k=5).predictions]\n",
    "            for sample in held_out_samples\n",
    "        ]\n",
    "\n",
    "        layer_wise_recall[layer_name] = recall(predictions, target)\n",
    "\n",
    "        if verbose:\n",
    "            print(layer_name, layer_wise_recall[layer_name])\n",
    "\n",
    "    return layer_wise_recall\n",
    "\n",
    "layer_wise_recall = get_layer_wise_recall(capital_cities_subset, verbose = True, save_weights = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"layer_sweep/layer_wise_recall.json\", \"w\") as f:\n",
    "    json.dump(layer_wise_recall, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the layer sweep with a fresh random pair of training samples per run\n",
    "# and gather, for every layer, the list of per-run recall values.\n",
    "layer_wise_recall_collection = {}\n",
    "number_of_runs = 10\n",
    "\n",
    "for run in tqdm(range(number_of_runs)):\n",
    "    indices = np.random.choice(range(len(cur_relation_known.samples)), 2, replace=False)\n",
    "    samples = [cur_relation_known.samples[i] for i in indices]\n",
    "\n",
    "    relation_fields = copy.deepcopy(cur_relation.__dict__)\n",
    "    relation_fields[\"samples\"] = samples\n",
    "    capital_cities_subset = data.Relation(**relation_fields)\n",
    "\n",
    "    layer_wise_recall = get_layer_wise_recall(capital_cities_subset, verbose=False, save_weights=False)\n",
    "\n",
    "    for layer, layer_recall in layer_wise_recall.items():\n",
    "        layer_wise_recall_collection.setdefault(layer, []).append(layer_recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"layer_sweep/layer_wise_recall_collection.json\", \"w\") as f:\n",
    "    json.dump(layer_wise_recall_collection, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"layer_sweep/layer_wise_recall_collection.json\") as f:\n",
    "    layer_wise_recall_collection = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "layer_names = list(layer_wise_recall_collection.keys())\n",
    "\n",
    "# One array per recall rank, each with one row per layer and one column per run.\n",
    "top_1, top_2, top_3 = (\n",
    "    np.array([np.array(layer_wise_recall_collection[name])[:, rank] for name in layer_names])\n",
    "    for rank in range(3)\n",
    ")\n",
    "\n",
    "# (values, color, line width, legend label, band opacity)\n",
    "curve_styles = [\n",
    "    (top_1, \"green\", 3, \"recall@1\", 0.1),\n",
    "    (top_2, \"blue\", 2, \"recall@2\", 0.05),\n",
    "    (top_3, \"red\", 1, \"recall@3\", 0.03),\n",
    "]\n",
    "\n",
    "# Mean over runs first, then the min/max bands, in the same drawing order as before.\n",
    "for values, color, width, label, _ in curve_styles:\n",
    "    plt.plot(values.mean(axis=1), color=color, linewidth=width, label=label)\n",
    "\n",
    "layer_axis = range(len(layer_names))\n",
    "for values, color, _, _, band_alpha in curve_styles:\n",
    "    plt.fill_between(\n",
    "        layer_axis,\n",
    "        values.min(axis=1), values.max(axis=1),\n",
    "        color=color, alpha=band_alpha\n",
    "    )\n",
    "\n",
    "plt.xticks(range(0, len(top_1), 2))\n",
    "plt.xlabel(\"layer\")\n",
    "plt.ylabel(\"recall\")\n",
    "\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.viz_utils import matrix_heatmap\n",
    "\n",
    "# Load the saved Jacobian of each of the first 24 layer paths and pass it to\n",
    "# `matrix_heatmap` together with the target image path.\n",
    "for layer_name in models.determine_layer_paths(mt)[:24]:\n",
    "    archive = np.load(f\"layer_sweep/weights_and_biases/{layer_name}.npz\", allow_pickle=True)\n",
    "    j = torch.tensor(archive[\"jacobian\"]).to(device)\n",
    "    print(layer_name, j.shape)\n",
    "    matrix_heatmap(\n",
    "        j,\n",
    "        title=layer_name,\n",
    "        save_path=f\"layer_sweep/Jacobian_plots/{layer_name}.png\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
