{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import glob\n",
    "import torch\n",
    "import re\n",
    "import einops\n",
    "import pandas as pd\n",
    "from functools import partial\n",
    "from torch import Tensor\n",
    "from torchtyping import TensorType as TT\n",
    "\n",
    "\n",
    "import plotly.express as px\n",
    "\n",
    "from utils.data_utils import generate_data_and_caches\n",
    "from utils.data_processing import (\n",
    "    load_edge_scores_into_dictionary,\n",
    ")\n",
    "from utils.visualization import plot_attention_heads, imshow_p\n",
    "from utils.backup_analysis import (\n",
    "    load_model,\n",
    "    run_iteration,\n",
    "    process_backup_results,\n",
    "    get_past_nmhs_for_checkpoints,\n",
    "    plot_top_heads\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Analysis only: turn autograd off for the rest of the session.\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
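    "### Functions\n",
    "\n",
    "Analysis helpers are imported from the project's `utils` package (see Imports above) rather than defined in this notebook."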
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiments"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experiment Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK = 'ioi'\n",
    "PERFORMANCE_METRIC = 'logit_diff'\n",
    "BASE_MODEL = \"pythia-70m\"\n",
    "# Set to a variant name such as \"EleutherAI/pythia-70m-weight-seed3\" to analyse it instead of BASE_MODEL.\n",
    "VARIANT = None\n",
    "# Short name used in results paths: the part of VARIANT after its org prefix, or BASE_MODEL\n",
    "# when no variant is selected. Splitting on \"/\" (rather than slicing off 11 characters)\n",
    "# works for any org prefix and for names without one.\n",
    "MODEL_SHORTNAME = BASE_MODEL if not VARIANT else VARIANT.split(\"/\")[-1]\n",
    "CACHE = \"model_cache\"  # passed to load_model; presumably the local model cache directory\n",
    "IOI_DATASET_SIZE = 70  # IOI dataset size (number of prompts)\n",
    "# Passed to run_iteration as `threshold`; presumably a copy-score cutoff (in %) for\n",
    "# selecting name-mover heads -- confirm in utils.backup_analysis.\n",
    "COPY_SCORE_THRESHOLD = 75.0"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Circuit Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_path = f'results/graphs/{MODEL_SHORTNAME}/{TASK}'\n",
    "df = load_edge_scores_into_dictionary(folder_path)\n",
    "\n",
    "# filter everything before 1000 steps; take an explicit copy so the columns added below\n",
    "# go into an independent frame rather than a filtered slice (no chained-assignment ambiguity)\n",
    "df = df[df['checkpoint'] >= 1000].copy()\n",
    "\n",
    "# edge labels have the form 'source->target'\n",
    "df[['source', 'target']] = df['edge'].str.split('->', expand=True)\n",
    "\n",
    "# number of distinct target nodes (dropna=False matches len(unique()))\n",
    "df['target'].nunique(dropna=False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load checkpoint 143000 (the last step iterated by the experiment loop below); in this\n",
    "# cell the model is only used to generate the IOI / ABC datasets.\n",
    "initial_model = load_model(BASE_MODEL, VARIANT, 143000, CACHE, device)\n",
    "size = IOI_DATASET_SIZE  # reuse the configured size instead of repeating the literal\n",
    "ioi_dataset, abc_dataset = generate_data_and_caches(initial_model, size, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): disabled plot. `per_head_ablated_logit_diffs` is not defined anywhere in\n",
    "# this notebook, so uncommenting this as-is raises NameError; the headwise plot under\n",
    "# \"View Results\" uses experiment_metrics[...]['per_head_logit_diff_delta'] instead.\n",
    "# imshow_p(\n",
    "#     per_head_ablated_logit_diffs,\n",
    "#     title=\"Headwise logit diff contribution, post NMH KO\",\n",
    "#     labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff attribution\"},\n",
    "#     #coloraxis=dict(colorbar_ticksuffix = \"%\"),\n",
    "#     border=True,\n",
    "#     width=600,\n",
    "#     margin={\"r\": 100, \"l\": 100}\n",
    "# )"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep checkpoints 4000..143000 (step 1000), accumulating metrics per checkpoint.\n",
    "results_dir = f'results/backup/{MODEL_SHORTNAME}'\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "metrics_path = f'results/backup/{MODEL_SHORTNAME}/nmh_backup_metrics.pt'\n",
    "\n",
    "experiment_metrics = {}\n",
    "for checkpoint in range(4000, 144000, 1000):\n",
    "    experiment_metrics = run_iteration(\n",
    "        BASE_MODEL,\n",
    "        VARIANT,\n",
    "        df,\n",
    "        checkpoint=checkpoint,\n",
    "        dataset=ioi_dataset,\n",
    "        experiment_metrics=experiment_metrics,\n",
    "        threshold=COPY_SCORE_THRESHOLD,\n",
    "    )\n",
    "    experiment_metrics = process_backup_results(df, checkpoint, experiment_metrics)\n",
    "\n",
    "    # Re-save (torch format) after every checkpoint so an interrupted run keeps its progress.\n",
    "    torch.save(experiment_metrics, metrics_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: which checkpoints are present in the collected metrics.\n",
    "experiment_metrics.keys()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View Results"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Pythia 160m (alldropout variant)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder under results/backup/ whose saved metrics are viewed below; may differ from MODEL_SHORTNAME.\n",
    "MODEL_TO_VIEW = \"pythia-160m-alldropout\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# map_location='cpu' lets results saved from GPU tensors load on a CPU-only machine and\n",
    "# keeps everything on the CPU for the plotting cells below.\n",
    "experiment_metrics = torch.load(f'results/backup/{MODEL_TO_VIEW}/nmh_backup_metrics.pt', map_location=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metric names, read from the earliest checkpoint present (4000 for a full sweep). This avoids\n",
    "# a KeyError on partially-saved results, since the sweep re-saves after every checkpoint.\n",
    "experiment_metrics[min(experiment_metrics)].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _relative_to_logit_diff(metric_of):\n",
    "    \"\"\"Map each checkpoint to metric_of(its metrics) divided by its original logit diff.\"\"\"\n",
    "    return {\n",
    "        checkpoint: metric_of(metrics) / metrics[\"logit_diff\"]\n",
    "        for checkpoint, metrics in experiment_metrics.items()\n",
    "    }\n",
    "\n",
    "\n",
    "# Every delta below is expressed as a fraction of the original (pre-ablation) logit diff.\n",
    "summed_in_circuit_head_deltas = _relative_to_logit_diff(lambda m: m[\"summed_in_circuit_head_delta\"])\n",
    "summed_outside_circuit_head_deltas = _relative_to_logit_diff(lambda m: m[\"summed_outside_circuit_head_delta\"])\n",
    "summed_total_head_deltas = _relative_to_logit_diff(lambda m: m[\"summed_total_head_delta\"])\n",
    "per_head_logit_diff_deltas = _relative_to_logit_diff(lambda m: m[\"per_head_logit_diff_delta\"])\n",
    "total_logit_diff_deltas = _relative_to_logit_diff(lambda m: m['ablated_logit_diff'] - m['logit_diff'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot summed_in_circuit_head_deltas with plotly express\n",
    "# The values are fractions of the original logit diff; scale by 100 so the y-axis\n",
    "# actually matches its '%' label.\n",
    "fig = px.line(\n",
    "    x=list(summed_in_circuit_head_deltas.keys()),\n",
    "    y=[100 * delta for delta in summed_in_circuit_head_deltas.values()],\n",
    "    title=f\"Summed Post-NMH-Ablation In-Circuit Head Logit Diff Change Over Time ({MODEL_TO_VIEW})\",\n",
    "    labels={'x': 'Checkpoint', 'y': 'Change as % of original logit diff'}\n",
    ")\n",
    "fig.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot summed_outside_circuit_head_deltas\n",
    "# The values are fractions of the original logit diff; scale by 100 so the y-axis\n",
    "# actually matches its '%' label.\n",
    "fig = px.line(\n",
    "    x=list(summed_outside_circuit_head_deltas.keys()),\n",
    "    y=[100 * delta for delta in summed_outside_circuit_head_deltas.values()],\n",
    "    title=f\"Summed Post-NMH-Ablation Outside-Circuit Head Attribution Change ({MODEL_TO_VIEW})\",\n",
    "    labels={'x': 'Checkpoint', 'y': 'Change as % of original logit diff'}\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot total_head_deltas\n",
    "# The values are fractions of the original logit diff; scale by 100 so the y-axis\n",
    "# actually matches its '%' label.\n",
    "fig = px.line(\n",
    "    x=list(summed_total_head_deltas.keys()),\n",
    "    y=[100 * delta for delta in summed_total_head_deltas.values()],\n",
    "    title=f\"Summed Total Post-NMH-Ablation Head Attribution Change ({MODEL_TO_VIEW})\",\n",
    "    labels={'x': 'Checkpoint', 'y': 'Change as % of original logit diff'}\n",
    ")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoint_nmhs: checkpoint -> collection of name-mover heads (its lengths are plotted below).\n",
    "# cumulative_nmhs: presumably heads that were NMHs at any earlier checkpoint -- confirm in utils.backup_analysis.\n",
    "cumulative_nmhs, checkpoint_nmhs = get_past_nmhs_for_checkpoints(experiment_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-head deltas (normalized by the original logit diff) plus the NMH history, top 10 per checkpoint.\n",
    "top_backup_heads = plot_top_heads(\n",
    "    model_name=MODEL_TO_VIEW,\n",
    "    checkpoint_dict=per_head_logit_diff_deltas,\n",
    "    cumulative_nmhs=cumulative_nmhs,\n",
    "    top_k_per_checkpoint=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw (un-normalized) per-head logit diff change at the latest checkpoint present\n",
    "# (143000 for a complete sweep); unlike a hard-coded key this also works on partial results.\n",
    "latest_checkpoint = max(experiment_metrics)\n",
    "\n",
    "imshow_p(\n",
    "    experiment_metrics[latest_checkpoint]['per_head_logit_diff_delta'],\n",
    "    title=f\"Headwise logit diff contribution, post NMH KO (checkpoint {latest_checkpoint})\",\n",
    "    labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff attribution\"},\n",
    "    #coloraxis=dict(colorbar_ticksuffix = \"%\"),\n",
    "    border=True,\n",
    "    width=600,\n",
    "    margin={\"r\": 100, \"l\": 100}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metric names stored for the latest checkpoint present (143000 for a complete sweep).\n",
    "experiment_metrics[max(experiment_metrics)].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First 50 rows that plot_top_heads flagged as 'Previous NMH'.\n",
    "previous_nmh_mask = top_backup_heads['Previous NMH'] == True\n",
    "top_backup_heads.loc[previous_nmh_mask].head(50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name-mover heads recorded at each checkpoint (their counts are plotted below).\n",
    "checkpoint_nmhs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NMH history passed to plot_top_heads above; presumably every head that was an NMH at an\n",
    "# earlier checkpoint -- confirm in utils.backup_analysis.\n",
    "cumulative_nmhs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot number of nmhs over time\n",
    "nmh_counts = [len(heads) for heads in checkpoint_nmhs.values()]\n",
    "fig = px.line(\n",
    "    x=list(checkpoint_nmhs.keys()),\n",
    "    y=nmh_counts,\n",
    "    title=f\"Number of NMHs Over Time ({MODEL_TO_VIEW})\",\n",
    "    labels={'x': 'Checkpoint', 'y': 'Number of NMHs'},\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
