{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import glob\n",
    "import re\n",
    "import sys\n",
    "from typing import Dict, List\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "import plotly.io as pio\n",
    "import chart_studio\n",
    "from chart_studio import plotly as py\n",
    "\n",
    "from IPython.display import display, clear_output, HTML\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "from utils.data_processing import (\n",
    "    load_edge_scores_into_dictionary,\n",
    "    compute_weighted_jaccard_similarity,\n",
    "    compute_weighted_jaccard_similarity_to_reference,\n",
    "    compute_ewma_weighted_jaccard_similarity,\n",
    "    generate_in_circuit_df_files\n",
    ")\n",
    "from utils.result_plotting import plot_head_circuit_scores, plot_graph_metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ignore the following if not using chart_studio. If you do want to publish graphs, simply include upload=True in the plotly function\n",
    "\n",
    "# load API key from local file chart_studio_api_key.txt - should be username and api key separated by a comma\n",
    "# with open(\"../auth/chart_studio_api_key.txt\") as f:\n",
    "#     username, api_key = f.read().strip().split(\",\")\n",
    "#     # strip surrounding whitespace\n",
    "#     username = username.strip()\n",
    "#     api_key = api_key.strip()\n",
    "\n",
    "# chart_studio.tools.set_credentials_file(username=username, api_key=api_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: task, metric and model identifiers used throughout the notebook.\n",
    "TASK = \"ioi\"\n",
    "PERFORMANCE_METRIC = \"logit_diff\"\n",
    "MODEL_NAME = \"pythia-160m\"\n",
    "\n",
    "# Task-specific folder handed to the plotting helper as output_path;\n",
    "# created up front, with no error if it already exists.\n",
    "OUTPUT_DIR = f\"../results/plots/component_metrics/{TASK}/\"\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Circuit Components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the precomputed per-checkpoint component scores and head lists for MODEL_NAME.\n",
    "# NOTE: depending on the installed torch version's weights_only default, torch.load can unpickle\n",
    "# arbitrary Python objects - only point this at result files you generated or otherwise trust.\n",
    "# map_location='cpu' remaps any tensors stored in these files onto the CPU, so the notebook also\n",
    "# runs on machines that lack the device the results were saved from.\n",
    "components_over_time = torch.load(\n",
    "    f'../results/components/{MODEL_NAME}/components_over_time.pt',\n",
    "    map_location='cpu',\n",
    ")\n",
    "heads_over_time = torch.load(\n",
    "    f'../results/components/{MODEL_NAME}/heads_over_time.pt',\n",
    "    map_location='cpu',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint keys of components_over_time in ascending order; the extraction loops below iterate over this list.\n",
    "ckpts = sorted(components_over_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-checkpoint 'direct_effect_scores' entries into one dict per score type.\n",
    "# Checkpoints whose 'direct_effect_scores' is None are skipped, so they are absent from every dict.\n",
    "copy_scores = {}\n",
    "filtered_copy_scores = {}\n",
    "io_attns = {}\n",
    "io_s1_attn_ratio = {}\n",
    "copy_suppression_scores = {}\n",
    "\n",
    "for ckpt in ckpts:\n",
    "    direct_effect = components_over_time[ckpt]['direct_effect_scores']\n",
    "    if direct_effect is None:\n",
    "        continue\n",
    "\n",
    "    copy_scores[ckpt] = direct_effect['copy_scores']\n",
    "    # NOTE(review): this stores the same 'copy_scores' entry as copy_scores above, so nothing is\n",
    "    # actually filtered here - confirm whether a different key was intended.\n",
    "    filtered_copy_scores[ckpt] = direct_effect['copy_scores']\n",
    "    io_attns[ckpt] = direct_effect['io_attn_scores']\n",
    "    # IO attention divided by S1 attention.\n",
    "    # NOTE(review): presumably element-wise on tensors/arrays, in which case zero S1 attention yields inf/nan - TODO confirm.\n",
    "    io_s1_attn_ratio[ckpt] = direct_effect['io_attn_scores'] / direct_effect['s1_attn_scores']\n",
    "    copy_suppression_scores[ckpt] = direct_effect['copy_suppression_scores']"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### NMH Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Union of every entry listed under 'nmh' in heads_over_time at any checkpoint.\n",
    "all_nmh = {\n",
    "    head\n",
    "    for ckpt in ckpts\n",
    "    for head in heads_over_time[ckpt]['nmh']\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the collected copy scores (log-scaled x axis, legend and title disabled);\n",
    "# whatever the helper returns is kept for later inspection.\n",
    "all_heads_copy_score = plot_head_circuit_scores(\n",
    "    copy_scores,\n",
    "    title=f'Copy Score Across Checkpoints ({MODEL_NAME})',\n",
    "    disable_title=True,\n",
    "    show_legend=False,\n",
    "    log_x=True,\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the IO:S1 attention ratio, restricted via limit_to_list to the entries in all_nmh.\n",
    "# range_y=[0, 20] is passed through to the helper (presumably the visible y-axis range - TODO confirm).\n",
    "all_heads_io_attn = plot_head_circuit_scores(\n",
    "    io_s1_attn_ratio,\n",
    "    title=f'IO:S1 Attn Ratio Across Checkpoints ({MODEL_NAME})',\n",
    "    limit_to_list=all_nmh,\n",
    "    range_y=[0, 20],\n",
    "    log_x=True,\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Copy Suppression Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the collected copy suppression scores (log-scaled x axis, legend and title disabled).\n",
    "copy_suppression_scores_df = plot_head_circuit_scores(\n",
    "    copy_suppression_scores,\n",
    "    title=f'Copy Suppression Scores Across Checkpoints ({MODEL_NAME})',\n",
    "    disable_title=True,\n",
    "    show_legend=False,\n",
    "    log_x=True,\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### S2I Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-checkpoint 's2i_ablated_logit_diff_deltas' entry for the 'token_same_pos_oppo' condition;\n",
    "# checkpoints whose 's2i_scores' is None are left out.\n",
    "pos_signal_importance = {\n",
    "    ckpt: components_over_time[ckpt]['s2i_scores']['s2i_ablated_logit_diff_deltas']['token_same_pos_oppo']\n",
    "    for ckpt in ckpts\n",
    "    if components_over_time[ckpt]['s2i_scores'] is not None\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the S2I position-signal ablation deltas (log-scaled x axis, legend and title disabled).\n",
    "pos_signal_df = plot_head_circuit_scores(\n",
    "    pos_signal_importance,\n",
    "    title=f'S2I Pos Signal Ablation Logit Diff Change % Across Checkpoints ({MODEL_NAME})',\n",
    "    disable_title=True,\n",
    "    show_legend=False,\n",
    "    log_x=True,\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-checkpoint 's2i_io_attention_deltas' entry for the 'token_same_pos_oppo' condition;\n",
    "# checkpoints whose 's2i_scores' is None are left out.\n",
    "pos_signal_io_attn_change = {\n",
    "    ckpt: components_over_time[ckpt]['s2i_scores']['s2i_io_attention_deltas']['token_same_pos_oppo']\n",
    "    for ckpt in ckpts\n",
    "    if components_over_time[ckpt]['s2i_scores'] is not None\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the IO-attention deltas under S2I position-signal ablation (log-scaled x axis, legend and title disabled).\n",
    "pos_signal_io_attn_df = plot_head_circuit_scores(\n",
    "    pos_signal_io_attn_change,\n",
    "    title=f'Effect of S2I Pos Signal Ablation On NMH IO Attn ({MODEL_NAME})',\n",
    "    disable_title=True,\n",
    "    show_legend=False,\n",
    "    log_x=True,\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tertiary Component Scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show which score types are stored under 'tertiary_head_scores'.\n",
    "# Checkpoint 137000 is inspected when it is available (the original choice); otherwise fall back to\n",
    "# the latest checkpoint that has tertiary scores, so the cell also works for runs with a different\n",
    "# checkpoint schedule or with missing (None) scores.\n",
    "preferred_ckpt = 137000\n",
    "tertiary_ckpts = [\n",
    "    ckpt for ckpt in ckpts\n",
    "    if components_over_time[ckpt]['tertiary_head_scores'] is not None\n",
    "]\n",
    "if not tertiary_ckpts:\n",
    "    raise ValueError('No checkpoint in components_over_time has tertiary_head_scores')\n",
    "\n",
    "# ckpts is sorted ascending, so the last entry is the latest checkpoint with scores.\n",
    "inspect_ckpt = preferred_ckpt if preferred_ckpt in tertiary_ckpts else tertiary_ckpts[-1]\n",
    "components_over_time[inspect_ckpt]['tertiary_head_scores'].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-checkpoint 'induction_scores'; checkpoints whose 'tertiary_head_scores' is None are left out.\n",
    "induction_scores = {\n",
    "    ckpt: components_over_time[ckpt]['tertiary_head_scores']['induction_scores']\n",
    "    for ckpt in ckpts\n",
    "    if components_over_time[ckpt]['tertiary_head_scores'] is not None\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the collected induction scores (log-scaled x axis, legend and title disabled).\n",
    "induction_df = plot_head_circuit_scores(\n",
    "    induction_scores,\n",
    "    title=f'Induction Scores Across Checkpoints ({MODEL_NAME})',\n",
    "    disable_title=True,\n",
    "    show_legend=False,\n",
    "    log_x=True,\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-checkpoint 'prev_token_scores'; checkpoints whose 'tertiary_head_scores' is None are left out.\n",
    "prev_token_scores = {\n",
    "    ckpt: components_over_time[ckpt]['tertiary_head_scores']['prev_token_scores']\n",
    "    for ckpt in ckpts\n",
    "    if components_over_time[ckpt]['tertiary_head_scores'] is not None\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the collected previous-token scores (log-scaled x axis, legend and title disabled).\n",
    "prev_token_df = plot_head_circuit_scores(\n",
    "    prev_token_scores,\n",
    "    title=f'Prev Token Scores Across Checkpoints ({MODEL_NAME})',\n",
    "    disable_title=True,\n",
    "    show_legend=False,\n",
    "    log_x=True,\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-checkpoint 'duplicate_token_scores'; checkpoints whose 'tertiary_head_scores' is None are left out.\n",
    "duplicate_token_scores = {\n",
    "    ckpt: components_over_time[ckpt]['tertiary_head_scores']['duplicate_token_scores']\n",
    "    for ckpt in ckpts\n",
    "    if components_over_time[ckpt]['tertiary_head_scores'] is not None\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the collected duplicate-token scores (log-scaled x axis, legend and title disabled).\n",
    "duplicate_token_df = plot_head_circuit_scores(\n",
    "    duplicate_token_scores,\n",
    "    title=f'Duplicate Token Scores Across Checkpoints ({MODEL_NAME})',\n",
    "    disable_title=True,\n",
    "    show_legend=False,\n",
    "    log_x=True,\n",
    "    output_path=OUTPUT_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
