{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import glob\n",
    "import re\n",
    "from typing import Dict, List\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "import plotly.io as pio\n",
    "import chart_studio\n",
    "from chart_studio import plotly as py\n",
    "\n",
    "from transformer_lens import HookedTransformer\n",
    "\n",
    "from IPython.display import display, clear_output, HTML\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "from utils.data_processing import (\n",
    "    load_edge_scores_into_dictionary,\n",
    "    compute_weighted_jaccard_similarity,\n",
    "    compute_weighted_jaccard_similarity_to_reference,\n",
    "    compute_ewma_weighted_jaccard_similarity,\n",
    ")\n",
    "\n",
    "from utils.result_plotting import format_cspa_data_for_plots, plot_nmh_metrics, plot_head_circuit_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Ignore the following if not using chart_studio. If you do want to publish graphs, simply include upload=True in the plotly function\n",
    "\n",
    "# # load API key from local file chart_studio_api_key.txt - should be username and api key separated by a comma\n",
    "# with open(\"../auth/chart_studio_api_key.txt\") as f:\n",
    "#     username, api_key = f.read().strip().split(\",\")\n",
    "#     # strip surrounding whitespace\n",
    "#     username = username.strip()\n",
    "#     api_key = api_key.strip()\n",
    "\n",
    "# chart_studio.tools.set_credentials_file(username=username, api_key=api_key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task identifier; interpolated into the plot output path below.\n",
    "TASK = 'ioi'\n",
    "# Model names iterated over by the loading and plotting cells of this notebook.\n",
    "MODEL_NAMES = ['pythia-70m', 'pythia-160m', 'pythia-410m', 'pythia-1.4b', 'pythia-2.8b', 'pythia-6.9b']\n",
    "# Directory handed to the plotting helpers as `output_path`.\n",
    "OUTPUT_DIR = f\"../results/plots/component_metrics/{TASK}/\"\n",
    "\n",
    "# create output directory\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## In-Graph Component Scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved `components_over_time.pt` and `heads_over_time.pt` files for every model.\n",
    "component_scores_all_models = dict()\n",
    "heads_all_models = dict()\n",
    "\n",
    "for model_name in MODEL_NAMES:\n",
    "    # Results live one directory above the notebook, like every other\n",
    "    # `../results/...` path used in this notebook.\n",
    "    components_dir = f'../results/components/{model_name}'\n",
    "    component_scores_all_models[model_name] = torch.load(f'{components_dir}/components_over_time.pt')\n",
    "    heads_all_models[model_name] = torch.load(f'{components_dir}/heads_over_time.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call plot_nmh_metrics once per model with that model's loaded component scores.\n",
    "for model_name in MODEL_NAMES:\n",
    "    print(f\"Plotting {model_name}\")\n",
    "    plot_nmh_metrics(model_name, component_scores_all_models[model_name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Whole-Model Component Scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.component_evaluation import get_attention_to_ioi_token, get_cspa_for_head, compute_copy_score, convert_head_names_to_tuple\n",
    "from utils.cspa_main import prepare_data\n",
    "\n",
    "# def get_copy_score_for_heads(model, head_list, dataset, verbose=False, batch_size=None):\n",
    "    \n",
    "#     return compute_copy_score(model, head_list, dataset, verbose=verbose, neg=False, batch_size=batch_size)\n",
    "\n",
    "# def get_attention_to_ioi_tokens_for_heads(model, head_list, dataset, batch_size=70):\n",
    "#     s1_attn_scores, s2_attn_scores, io_attn_scores = get_attention_to_ioi_token(model, dataset, head_list, batch_size=batch_size)\n",
    "#     return s1_attn_scores, s2_attn_scores, io_attn_scores\n",
    "\n",
    "# def get_copy_suppression_score_for_heads(model, head_list):\n",
    "#     model.cfg.use_split_qkv_input = False\n",
    "#     DATA_TOKS, DATA_STR_TOKS_PARSED, cspa_semantic_dict, indices = prepare_data(model)\n",
    "\n",
    "#     copy_suppression_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))\n",
    "#     for layer, head in head_list:\n",
    "#         if layer > 1:\n",
    "#             copy_suppression_scores[layer, head] = get_cspa_for_head(model, DATA_TOKS, cspa_semantic_dict, layer, head, verbose=False)\n",
    "#     model.cfg.use_split_qkv_input = True\n",
    "#     return copy_suppression_scores\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One empty score dict per model.\n",
    "# NOTE(review): the Copy Scores cell below rebinds `all_head_scores` to a dict with a\n",
    "# single \"copy_score\" key, so this per-model structure does not survive that cell.\n",
    "all_head_scores = {k: dict() for k in MODEL_NAMES}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Copy Scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.backup_analysis import load_model\n",
    "from utils.data_utils import generate_data_and_caches\n",
    "from utils.data_processing import get_ckpts\n",
    "from utils.component_evaluation import compute_copy_score\n",
    "\n",
    "# Configuration for the copy-score sweep over the checkpoints of a single model.\n",
    "TASK = 'ioi'\n",
    "BASE_MODEL = \"pythia-2.8b\"\n",
    "VARIANT = None #\"EleutherAI/pythia-160m-attndropout\"\n",
    "MODEL_SHORTNAME = BASE_MODEL if not VARIANT else VARIANT[11:]\n",
    "DATASET_SIZE = 70\n",
    "CACHE = \"/mnt/hdd-0/circuits-over-time/model_cache/ct\"\n",
    "DEVICE = \"cuda:1\"\n",
    "checkpoints_to_check = get_ckpts(\"sparse\")\n",
    "print(checkpoints_to_check)\n",
    "\n",
    "# Forwarded to load_model as `large_model`; only set for the 6.9b model.\n",
    "large_model = BASE_MODEL == \"pythia-6.9b\"\n",
    "\n",
    "# The same score dictionary is written to both locations after every checkpoint.\n",
    "# The first (local) path is also the one used to resume from earlier results.\n",
    "copy_score_paths = [\n",
    "    f\"../results/components/{BASE_MODEL}/early_whole_model_copy_scores.pt\",\n",
    "    f\"/mnt/hdd-0/circuits-over-time/results/components/{BASE_MODEL}/early_whole_model_copy_scores.pt\",\n",
    "]\n",
    "\n",
    "# The checkpoint-143000 model is used here to build the datasets; the loop below\n",
    "# reloads the model for each checkpoint that still needs scoring.\n",
    "model = load_model(BASE_MODEL, VARIANT, 143000, CACHE, DEVICE, large_model=large_model)\n",
    "model.tokenizer.add_bos_token = False\n",
    "ioi_dataset, abc_dataset = generate_data_and_caches(model, DATASET_SIZE, verbose=True, prepend_bos=True)\n",
    "\n",
    "# Resume from previously saved scores (keyed by checkpoint) when the local file exists.\n",
    "all_head_scores = dict()\n",
    "if os.path.exists(copy_score_paths[0]):\n",
    "    all_head_scores[\"copy_score\"] = torch.load(copy_score_paths[0])\n",
    "else:\n",
    "    all_head_scores[\"copy_score\"] = dict()\n",
    "\n",
    "for ckpt in checkpoints_to_check:\n",
    "    print(f\"Checking {ckpt}\")\n",
    "    if ckpt in all_head_scores[\"copy_score\"]:\n",
    "        continue\n",
    "    model = load_model(BASE_MODEL, VARIANT, ckpt, CACHE, DEVICE, large_model=large_model)\n",
    "    head_list = [(i, j) for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]\n",
    "    all_head_scores[\"copy_score\"][ckpt] = compute_copy_score(model, head_list, ioi_dataset, verbose=False, batch_size=10)\n",
    "    # Save after every checkpoint so an interrupted run can resume. The target\n",
    "    # directory is created first because torch.save does not create it.\n",
    "    for path in copy_score_paths:\n",
    "        os.makedirs(os.path.dirname(path), exist_ok=True)\n",
    "        torch.save(all_head_scores[\"copy_score\"], path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the copy scores from the local file written by the sweep cell above, for inspection.\n",
    "copy_s = torch.load(f\"../results/components/{BASE_MODEL}/early_whole_model_copy_scores.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys of the saved score dict; the sweep cell stores one entry per checkpoint.\n",
    "copy_s.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot copy scores for every model that has a saved score file under /mnt/hdd-0;\n",
    "# models without a file are skipped silently.\n",
    "for model_name in MODEL_NAMES:\n",
    "    file_path = f\"/mnt/hdd-0/circuits-over-time/results/components/{model_name}/early_whole_model_copy_scores.pt\"\n",
    "    if os.path.exists(file_path):\n",
    "        #print(f\"{model_name}\")\n",
    "        copy_scores = torch.load(file_path)\n",
    "        # `df` is rebound on every iteration, so only the last model's return value remains.\n",
    "        df = plot_head_circuit_scores(\n",
    "            copy_scores, \n",
    "            title= f'Copy Score Across Checkpoints (all {model_name} heads)', \n",
    "            show_legend=False, \n",
    "            disable_title=False,\n",
    "            output_path = OUTPUT_DIR\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `copy_scores` is whatever the plotting loop above loaded last; show the shape of its\n",
    "# checkpoint-4000 entry.\n",
    "copy_scores[4000].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `checkpoints` is only assigned by the CSPA plotting loop in the Copy Suppression\n",
    "# section further down, so on a fresh top-to-bottom run it does not exist yet.\n",
    "# Guard the lookup so Restart & Run All does not stop here with a NameError;\n",
    "# re-run this cell after that loop to see the shape.\n",
    "checkpoints[4000].shape if \"checkpoints\" in globals() else None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Copy Suppression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.result_plotting import (\n",
    "    display_cspa_grids,\n",
    "    load_checkpoints\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_outliers(checkpoint_dict: Dict[int, torch.Tensor], min_value: float, max_value: float) -> Dict[int, torch.Tensor]:\n",
    "    \"\"\"Zero out every entry that lies outside ``[min_value, max_value]``.\n",
    "\n",
    "    The dictionary is updated in place (each value is replaced by a new tensor)\n",
    "    and is also returned, so callers may use either the argument or the result.\n",
    "\n",
    "    Args:\n",
    "        checkpoint_dict: Mapping from checkpoint to a tensor of scores.\n",
    "        min_value: Entries strictly below this value are set to 0.\n",
    "        max_value: Entries strictly above this value are set to 0.\n",
    "\n",
    "    Returns:\n",
    "        The same ``checkpoint_dict`` object, now holding the cleaned tensors.\n",
    "    \"\"\"\n",
    "    for checkpoint, tensor in checkpoint_dict.items():\n",
    "        out_of_range = (tensor < min_value) | (tensor > max_value)\n",
    "        # zeros_like keeps the replacement on the tensor's own device and dtype,\n",
    "        # unlike a freshly created CPU float32 `torch.tensor(0.0)`.\n",
    "        checkpoint_dict[checkpoint] = torch.where(out_of_range, torch.zeros_like(tensor), tensor)\n",
    "    return checkpoint_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-model results from the cspa directory and plot them, ordered by model name.\n",
    "target_directory = '../results/cspa'\n",
    "loaded_checkpoints = load_checkpoints(target_directory)\n",
    "loaded_checkpoints.sort(key=lambda x: x[0])\n",
    "for model_shortname, checkpoints in loaded_checkpoints:\n",
    "    # clean_outliers updates the dict in place, so the entries held in\n",
    "    # `loaded_checkpoints` are cleaned too, including models skipped just below.\n",
    "    checkpoints = clean_outliers(checkpoints, 0.0, 1.0)\n",
    "\n",
    "    if model_shortname not in MODEL_NAMES:\n",
    "        print(f\"{model_shortname} not in MODEL_NAMES\")\n",
    "        continue\n",
    "    print(f\"Subfolder: {model_shortname}\")\n",
    "\n",
    "    # `df` is rebound on every iteration, so only the last plotted model's return value remains.\n",
    "    df = plot_head_circuit_scores(\n",
    "        checkpoints, \n",
    "        title= f'CSPA Score Across Checkpoints (all {model_shortname} heads)',\n",
    "        range_y=[0, 1], \n",
    "        show_legend=False, \n",
    "        disable_title=False,\n",
    "        output_path = OUTPUT_DIR,\n",
    "        log_x=True\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index the (model name, checkpoints) pairs by model name for direct lookup.\n",
    "loaded_checkpoint_dict = dict(loaded_checkpoints)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate over pythia-160m's own checkpoints rather than the `checkpoints` variable\n",
    "# left over from the last iteration of the plotting loop, which belongs to whichever\n",
    "# model sorted last and may not share the same checkpoint keys.\n",
    "[(ckpt, scores[8, 9]) for ckpt, scores in loaded_checkpoint_dict['pythia-160m'].items()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one value: index [8, 9] of the pythia-160m tensor at checkpoint 32000\n",
    "# (presumably layer 8, head 9 - TODO confirm the axis order).\n",
    "loaded_checkpoint_dict['pythia-160m'][32000][8, 9]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tertiary Scores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.component_evaluation import load_induction_metrics\n",
    "# Load induction metrics for the first model in MODEL_NAMES only.\n",
    "# NOTE(review): `induction_metrics` is not referenced by any later cell of this notebook.\n",
    "induction_metrics = load_induction_metrics(MODEL_NAMES[0], \"../results/task_performance_metrics\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regroup the saved per-checkpoint results as model -> score type -> checkpoint -> scores.\n",
    "scores_by_checkpoint = dict()\n",
    "scores_by_type = dict()\n",
    "\n",
    "# The last entry of MODEL_NAMES is skipped here.\n",
    "for model_name in MODEL_NAMES[:-1]:\n",
    "    scores_by_checkpoint[model_name] = torch.load(f'../results/components/{model_name}/full_model_components_over_time.pt')\n",
    "    scores_by_type[model_name] = dict()\n",
    "    # The available score types are read from the checkpoint-4000 entry.\n",
    "    # The loop variable is `score_type` rather than `type` so the builtin `type`\n",
    "    # is not shadowed for the rest of the kernel session.\n",
    "    for score_type in scores_by_checkpoint[model_name][4000]['tertiary_head_scores'].keys():\n",
    "        print(f\"Processing {model_name} {score_type}\")\n",
    "\n",
    "        scores_by_type[model_name][score_type] = {\n",
    "            checkpoint: v['tertiary_head_scores'][score_type]\n",
    "            for checkpoint, v in scores_by_checkpoint[model_name].items()\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot only the models present in scores_by_type: the loading cell skips the last\n",
    "# entry of MODEL_NAMES, so iterating MODEL_NAMES here would raise KeyError for it.\n",
    "for model_name in scores_by_type:\n",
    "    df = plot_head_circuit_scores(\n",
    "        scores_by_type[model_name]['induction_scores'],\n",
    "        title= f'Induction Score Across Checkpoints (all {model_name} heads)',\n",
    "        show_legend=False,\n",
    "        disable_title=False,\n",
    "        output_path = OUTPUT_DIR,\n",
    "        log_x=True\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot only the models present in scores_by_type: the loading cell skips the last\n",
    "# entry of MODEL_NAMES, so iterating MODEL_NAMES here would raise KeyError for it.\n",
    "for model_name in scores_by_type:\n",
    "    df = plot_head_circuit_scores(\n",
    "        scores_by_type[model_name]['prev_token_scores'],\n",
    "        title= f'Previous Token Score Across Checkpoints (all {model_name} heads)',\n",
    "        show_legend=False,\n",
    "        disable_title=False,\n",
    "        output_path = OUTPUT_DIR,\n",
    "        log_x=True\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot only the models present in scores_by_type: the loading cell skips the last\n",
    "# entry of MODEL_NAMES, so iterating MODEL_NAMES here would raise KeyError for it.\n",
    "for model_name in scores_by_type:\n",
    "    df = plot_head_circuit_scores(\n",
    "        scores_by_type[model_name]['duplicate_token_scores'],\n",
    "        title= f'Duplicate Token Detection Score Across Checkpoints (all {model_name} heads)',\n",
    "        show_legend=False,\n",
    "        disable_title=False,\n",
    "        output_path = OUTPUT_DIR,\n",
    "        log_x=True\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the model whose saved CSPA scores are inspected. It has its own variable\n",
    "# so the `model` object loaded in the Copy Scores section is not rebound to a string.\n",
    "data_check_model = 'pythia-1.4b'\n",
    "#p = Path('/mnt/ssd-1/XXXX/circuits-over-time/results/cspa')\n",
    "p = Path('/mnt/hdd-0/circuits-over-time/results/components')\n",
    "model_path = p/data_check_model\n",
    "\n",
    "data = torch.load(model_path / 'whole_model_cspa.pt')\n",
    "# Values outside [0, 1] are zeroed before stacking.\n",
    "data = clean_outliers(data, 0.0, 1.0)\n",
    "steps = sorted(data)\n",
    "# One slice per checkpoint, in ascending step order.\n",
    "head_scores = torch.stack([data[step].cpu() for step in steps])\n",
    "\n",
    "# Keep positions whose maximum over checkpoints reaches 15% of the overall maximum.\n",
    "layers, heads = (x.tolist() for x in torch.where(head_scores.max(dim=0).values >= (head_scores.max() * 0.15)))\n",
    "\n",
    "all_heads = set(zip(layers, heads))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the checkpoint keys of the loaded file and the positions selected above.\n",
    "data.keys(), all_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value at index [10, 7] for every checkpoint, in the dict's own key order\n",
    "# (presumably layer 10, head 7 - TODO confirm the axis order).\n",
    "[data[k][10, 7] for k in data.keys()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy each `all_checkpoints.pt` found under the cspa results folder into the\n",
    "# components results folder, renamed to `whole_model_cspa.pt`.\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "# Define the paths for folders A and B\n",
    "folder_a = '/mnt/hdd-0/circuits-over-time/results/cspa'\n",
    "folder_b = '/mnt/hdd-0/circuits-over-time/results/components'\n",
    "file_name = 'all_checkpoints.pt'\n",
    "new_file_name = 'whole_model_cspa.pt'\n",
    "\n",
    "# Iterate through all subfolders in folder A\n",
    "# NOTE(review): os.walk also descends into nested subfolders, and the destination is\n",
    "# built from the subfolder's own name only, so a nested folder is mapped directly\n",
    "# under folder_b rather than keeping its relative path - confirm this is intended.\n",
    "for root, dirs, files in os.walk(folder_a):\n",
    "    for subdir in dirs:\n",
    "        subfolder_a = os.path.join(root, subdir)\n",
    "        subfolder_b = os.path.join(folder_b, subdir)\n",
    "\n",
    "        # Check if the specific file exists in the subfolder\n",
    "        source_file = os.path.join(subfolder_a, file_name)\n",
    "        if os.path.exists(source_file):\n",
    "            # Errors are reported and the loop moves on to the next subfolder.\n",
    "            try:\n",
    "                # Create the corresponding subfolder in folder B if it doesn't exist\n",
    "                os.makedirs(subfolder_b, exist_ok=True)\n",
    "\n",
    "                # Define the new destination file path\n",
    "                destination_file = os.path.join(subfolder_b, new_file_name)\n",
    "\n",
    "                # Copy the file to the same subfolder in folder B with the new name\n",
    "                shutil.copy2(source_file, destination_file)\n",
    "                print(f\"Copied {source_file} to {destination_file}\")\n",
    "            except PermissionError as e:\n",
    "                print(f\"PermissionError: {e}\")\n",
    "            except Exception as e:\n",
    "                print(f\"An error occurred: {e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
