{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "\n",
    "# Hack to avoid some import problem due to the library being a subfolder.\n",
    "# The vendored copy is registered only once, so re-running this cell does not keep growing sys.path.\n",
    "TLENS_DIR = \"third_party/TransformerLens\"\n",
    "if TLENS_DIR not in sys.path:\n",
    "    sys.path.append(TLENS_DIR)\n",
    "\n",
    "try:\n",
    "    import transformer_lens as lens  # Some python problem causes this to throw on the first import\n",
    "except Exception:\n",
    "    # Deliberate best-effort: the retry right below surfaces any real import failure.\n",
    "    # Catching Exception (instead of a bare except) keeps Ctrl-C / SystemExit working.\n",
    "    pass\n",
    "\n",
    "import transformer_lens as lens  # Import TLens from the local copy shipped with this project - It includes various bugfixes as well as implementations for Vision-Language (VL) model hooking\n",
    "import torch\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import plotly.express as px\n",
    "from visualization_utils import imshow, line, scatter, multiple_lines\n",
    "from general_utils import get_tokens, topk_2d\n",
    "from analysis_utils import load_model, load_dataset, SUPPORTED_TASKS\n",
    "from modality_alignment_utils import get_image_positions, get_text_sequence_positions\n",
    "import plotly.io as pio\n",
    "from collections import defaultdict, OrderedDict\n",
    "import plotly.graph_objects as go\n",
    "from pprint import pprint\n",
    "from plotly.subplots import make_subplots\n",
    "\n",
    "# Autograd is switched off globally for the whole notebook\n",
    "torch.set_grad_enabled(False)\n",
    "device = \"cuda\"\n",
    "COLORBLIND_COLORS = ['#0173b2', '#de8f05', '#029e73','#d55e00', '#cc78bc', '#ca9161', '#fbafe4', '#949494', '#ece133', '#56b4e9']\n",
    "\n",
    "\n",
    "# Metric name embedded in the result file names that the cells below load\n",
    "DEFAULT_METRIC = \"LD\"\n",
    "MODEL_NAMES = [\"qwen2-7b-vl-instruct\", \"pixtral-12b\", \"gemma-3-12b-it\"]\n",
    "# Snapshot directory per model; replace the /PATH_TO_MODELS prefix with your local location\n",
    "MODEL_PATHS = {\n",
    "    MODEL_NAMES[0]: \"/PATH_TO_MODELS/models--Qwen--Qwen2-VL-7B-Instruct/snapshots/a7a06a1cc11b4514ce9edcde0e3ca1d16e5ff2fc\",\n",
    "    MODEL_NAMES[1]: \"/PATH_TO_MODELS/models--mistral-community--pixtral-12b/snapshots/c2756cbbb9422eba9f6c5c439a214b0392dfc998/\",\n",
    "    MODEL_NAMES[2]: \"/PATH_TO_MODELS/models--google--gemma-3-12b-it/snapshots/96b6f1eccf38110c56df3a15bffe176da04bfd80\"\n",
    "}\n",
    "# Display names used as subplot titles in the figures\n",
    "VISUALIZED_MODEL_NAMES = {\n",
    "    MODEL_NAMES[0]: \"Qwen2-VL-7B\",\n",
    "    MODEL_NAMES[1]: \"Pixtral-12B\",\n",
    "    MODEL_NAMES[2]: \"Gemma-3-12B\"\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Faithfulness graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show node faithfulness graphs\n",
    "\n",
    "HIGH_FAITH_THRESHOLD = 0.80\n",
    "\n",
    "def get_first_over_threshold(faiths, threshold=None):\n",
    "    \"\"\"Return the index of the first entry of `faiths` that exceeds the threshold.\n",
    "\n",
    "    Args:\n",
    "        faiths: Tensor of faithfulness values (the callers in this cell pass a matrix diagonal).\n",
    "        threshold: Exclusive lower bound. When omitted, the current value of\n",
    "            HIGH_FAITH_THRESHOLD is used.\n",
    "\n",
    "    Returns:\n",
    "        The position (int) of the first value above the threshold, or None when\n",
    "        no value exceeds it.\n",
    "    \"\"\"\n",
    "    if threshold is None:\n",
    "        threshold = HIGH_FAITH_THRESHOLD\n",
    "    over_threshold_indices = (faiths > threshold).nonzero().view(-1)\n",
    "    if len(over_threshold_indices) > 0:\n",
    "        return over_threshold_indices[0].item()\n",
    "    return None\n",
    "\n",
    "\n",
    "def record_first_high_faith_percentage(label, percentages, faiths, collected):\n",
    "    \"\"\"Append to `collected` the first entry of `percentages` whose faithfulness is high.\n",
    "\n",
    "    When `faiths` never exceeds the threshold nothing is appended and a notice is\n",
    "    printed instead, so the rest of the (model, task) report still runs and the\n",
    "    averages printed at the end of the cell only contain real crossings.\n",
    "    \"\"\"\n",
    "    first_idx = get_first_over_threshold(faiths)\n",
    "    if first_idx is None:\n",
    "        print(f'{label} circuit never exceeds the high-faith threshold ({HIGH_FAITH_THRESHOLD}); left out of the average')\n",
    "        return\n",
    "    collected.append(percentages[first_idx])\n",
    "    print(f'{label} percentage with high faith: {collected[-1] :.3f}')\n",
    "\n",
    "\n",
    "# Toggles selecting which result files are loaded / reported\n",
    "show_l_faiths = True\n",
    "show_vl_faiths = True\n",
    "show_cross_modality_faiths = False\n",
    "show_interchange_faiths = True\n",
    "\n",
    "\n",
    "l_percentages, vl_percentages = [], []\n",
    "for model_name in MODEL_NAMES:\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        try:\n",
    "            print(model_name, task_name)\n",
    "            faiths_cf = []\n",
    "            line_titles_cf = []\n",
    "\n",
    "            if show_l_faiths:\n",
    "                results_path = f\"./data/{task_name}/results/{model_name}/faithfulness_{DEFAULT_METRIC}_l_node_circuit.pt\"\n",
    "                percentages, faiths_l_cf, faiths_l_mask = torch.load(results_path, weights_only=True)\n",
    "                record_first_high_faith_percentage('L', percentages, faiths_l_cf.diag(), l_percentages)\n",
    "                faiths_cf.append(faiths_l_cf.diag())\n",
    "                line_titles_cf.append(f\"{DEFAULT_METRIC} L-Discover L-Eval CF\")\n",
    "\n",
    "            if show_vl_faiths:\n",
    "                results_path = f\"./data/{task_name}/results/{model_name}/faithfulness_{DEFAULT_METRIC}_vl_node_circuit.pt\"\n",
    "                percentages, faiths_vl_cf, faiths_vl_mask = torch.load(results_path, weights_only=True)\n",
    "                record_first_high_faith_percentage('VL', percentages, faiths_vl_cf.diag(), vl_percentages)\n",
    "                faiths_cf.append(faiths_vl_cf.diag())\n",
    "                line_titles_cf.append(f\"{DEFAULT_METRIC} VL-Discover VL-Eval CF\")\n",
    "\n",
    "            if show_cross_modality_faiths:\n",
    "                results_path = f\"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_modal_{DEFAULT_METRIC}.pt\"\n",
    "                percentages_cross, faith_discover_l_eval_vl, faith_discover_vl_eval_l = torch.load(results_path, weights_only=True)\n",
    "                # NOTE(review): relies on `percentages` from one of the branches above, and `==` is only a\n",
    "                # valid assert condition if both are plain lists (not tensors) - confirm before enabling.\n",
    "                assert percentages == percentages_cross\n",
    "                faiths_cf.append(faith_discover_l_eval_vl)\n",
    "                line_titles_cf.append(f\"{DEFAULT_METRIC} L-Discover VL-Eval CF\")\n",
    "                faiths_cf.append(faith_discover_vl_eval_l)\n",
    "                line_titles_cf.append(f\"{DEFAULT_METRIC} VL-Discover L-Eval CF\")\n",
    "\n",
    "            if show_interchange_faiths:\n",
    "                results_path = f\"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_interchanges_{DEFAULT_METRIC}.pt\"\n",
    "                results_dict = torch.load(results_path, weights_only=True)\n",
    "                print(\"Interchange results: \")\n",
    "                # Mean of two entries of the interchange results\n",
    "                avg = lambda k1, k2: (results_dict[k1] + results_dict[k2]) / 2\n",
    "                print(f\"D Interchange | Random Baseline | Clean result: {avg('DL_QV_LV', 'DV_QL_LL') :.3f} | {avg('DR_QV_LV', 'DR_QL_LL') :.3f} | {avg('DV_QV_LV', 'DL_QL_LL') :.3f}\")\n",
    "                print(f\"Q Interchange | Random Baseline | Clean result: {avg('DV_QL_LV', 'DL_QV_LL') :.3f} | {avg('DV_QR_LV', 'DL_QR_LL') :.3f} | {avg('DV_QV_LV', 'DL_QL_LL') :.3f}\")\n",
    "                print(f\"L Interchange | Random Baseline | Clean result: {avg('DV_QV_LL', 'DL_QL_LV') :.3f} | {avg('DV_QV_LR', 'DL_QL_LR') :.3f} | {avg('DV_QV_LV', 'DL_QL_LL') :.3f}\")\n",
    "\n",
    "            fig = multiple_lines(\n",
    "                x=percentages,\n",
    "                y=faiths_cf,\n",
    "                line_titles=line_titles_cf,\n",
    "                title=f\"Faithfulness vs Ablation percentage<br>({task_name}, {model_name})\",\n",
    "                width=500,\n",
    "                show_fig=False\n",
    "            )\n",
    "            fig.update_xaxes(title_text=\"Nodes included in circuit (percent)\")\n",
    "            fig.update_yaxes(title_text=\"Faithfulness\")\n",
    "            # The figure is created with show_fig=False so the axis titles can be set first;\n",
    "            # it has to be displayed explicitly afterwards, otherwise it is silently discarded.\n",
    "            fig.show()\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f\"\\n\\nError in {task_name} - {model_name}\\n\\n\")\n",
    "            print(e)\n",
    "\n",
    "print(\"Average L circuit percentage: \", np.mean(l_percentages))\n",
    "print(\"Average VL circuit percentage: \", np.mean(vl_percentages))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Intersection results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unify_l_vl(result_dict, key_suffix):\n",
    "    \"\"\"Average the `l_<key_suffix>` and `vl_<key_suffix>` entries of `result_dict`, rounded to 3 decimals.\"\"\"\n",
    "    return round((result_dict[f\"l_{key_suffix}\"] + result_dict[f\"vl_{key_suffix}\"]) / 2, 3)\n",
    "\n",
    "for task_name in SUPPORTED_TASKS:\n",
    "    for model_name in MODEL_NAMES:\n",
    "        results_path = f\"./data/{task_name}/results/{model_name}/intersection_results.pt\"\n",
    "        try:\n",
    "            result_dict = torch.load(results_path, weights_only=True)\n",
    "            pprint(f\"Model: {model_name}; Task: {task_name}\")\n",
    "\n",
    "            # Each printed pair is (MLP / neurons value, attention-head value)\n",
    "            print(\"Averaged between L and VL SIDES\")\n",
    "            print(f\"Unified by pos (MLP,ATTN): {unify_l_vl(result_dict, 'mlp_iou'), unify_l_vl(result_dict, 'head_iou')}; Random Baseline (MLP,ATTN): {unify_l_vl(result_dict, 'mlp_baseline'), unify_l_vl(result_dict, 'head_baseline')}\")\n",
    "            print(f\"D pos: {unify_l_vl(result_dict, 'D_neurons_iou'), unify_l_vl(result_dict, 'D_head_iou')}\")\n",
    "            print(f\"Q pos: {unify_l_vl(result_dict, 'Q_neurons_iou'), unify_l_vl(result_dict, 'Q_head_iou')}\")\n",
    "            print(f\"L pos: {unify_l_vl(result_dict, 'G_neurons_iou'), unify_l_vl(result_dict, 'G_head_iou')}\")\n",
    "\n",
    "        except FileNotFoundError:\n",
    "            print(f\"Model: {model_name}; Task: {task_name}, No intersection results found\")\n",
    "        except Exception as e:\n",
    "            # Anything else (e.g. a results file lacking one of the keys read above) is still\n",
    "            # non-fatal, but is reported as what it is instead of as a missing file.\n",
    "            print(f\"Model: {model_name}; Task: {task_name}, Could not report intersection results ({type(e).__name__}: {e})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Backpatching results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Present full back-patching heatmaps (one figure per plotted configuration)\n",
    "\n",
    "for model_name in MODEL_NAMES:\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        # NOTE(review): weights_only=False unpickles arbitrary objects - only load result files you trust.\n",
    "        results_file = f\"./data/{task_name}/results/{model_name}/backpatching_results.pt\"\n",
    "        backpatching_results, src_layer_range, dst_layer_range = torch.load(results_file, weights_only=False)\n",
    "\n",
    "        for cfg, cfg_results in backpatching_results.items():\n",
    "            # Only two-element configurations whose first flag is True are plotted\n",
    "            is_plotted_cfg = len(cfg) == 2 and cfg[0] is True\n",
    "            if not is_plotted_cfg:\n",
    "                continue\n",
    "\n",
    "            print(model_name, task_name, cfg)\n",
    "            acc_matrix = cfg_results[0]\n",
    "            n_src_layers, n_dst_layers = acc_matrix.shape[0], acc_matrix.shape[1]\n",
    "\n",
    "            # Difference to the first clean accuracy, clipped to [-0.1, 0.1] for the colour scale\n",
    "            results = (acc_matrix - backpatching_results['clean_accs'][0]).clamp_(min=-0.1, max=0.1)\n",
    "            fig = px.imshow(\n",
    "                results,\n",
    "                width=600,\n",
    "                color_continuous_midpoint=0,\n",
    "                color_continuous_scale=\"RdBu\",\n",
    "                title=f\"Back-patching results<br>({task_name}, {model_name})\",\n",
    "            )\n",
    "\n",
    "            # Label the columns / rows with the destination / source layer ranges\n",
    "            fig.update_xaxes(\n",
    "                tickvals=list(range(n_dst_layers)),\n",
    "                ticktext=[dst_layer_range[i] for i in range(n_dst_layers)],\n",
    "            )\n",
    "            fig.update_yaxes(\n",
    "                tickvals=list(range(n_src_layers)),\n",
    "                ticktext=[src_layer_range[i] for i in range(n_src_layers)],\n",
    "            )\n",
    "            fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observe full backpatching results for every (model, task) pair\n",
    "\n",
    "control_results = {}\n",
    "for model_name in MODEL_NAMES:\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        print(model_name, task_name)\n",
    "\n",
    "        # NOTE(review): weights_only=False unpickles arbitrary objects - only load result files you trust.\n",
    "        backpatching_results, src_layer_range, dst_layer_range  = torch.load(f\"./data/{task_name}/results/{model_name}/backpatching_results.pt\", weights_only=False)\n",
    "        clean_accs = backpatching_results['clean_accs']\n",
    "        print(clean_accs)\n",
    "\n",
    "        # Only two-element configurations whose first flag is True are analysed\n",
    "        # (i.e. ignore the setting where data positions are processed post back-patching)\n",
    "        analysed_cfgs = [cfg for cfg in backpatching_results.keys() if len(cfg) == 2 and cfg[0] is True]\n",
    "\n",
    "        # Get the top-10 backpatching results across all settings and src->dst options\n",
    "        k = 10\n",
    "        top_results = []\n",
    "        for cfg in analysed_cfgs:\n",
    "            (top_h_indices, top_w_indices), top_accs = topk_2d(backpatching_results[cfg][0], k)\n",
    "            for top_h_index, top_w_index, top_acc in zip(top_h_indices, top_w_indices, top_accs):\n",
    "                top_results.append(cfg + (src_layer_range[top_h_index], dst_layer_range[top_w_index], top_acc))\n",
    "\n",
    "        top_results = sorted(top_results, key=lambda x: x[-1], reverse=True)[:k]\n",
    "        pprint([f\"Repeat Processing={r[0]}; Layer window size={r[1]}; Layers={r[2]}->{r[3]}; Acc={r[-1].item() :.3f}\" for r in top_results])\n",
    "\n",
    "\n",
    "        # Comparing to control (L->L backpatching; Should (hopefully) lead to a smaller improvement)\n",
    "        for cfg in analysed_cfgs:\n",
    "            bp_accs = backpatching_results[cfg][0].view(-1)\n",
    "            control_accs = backpatching_results[cfg][1].view(-1)\n",
    "\n",
    "            # Remove non-valid settings (i.e. dst >= src; presumably stored as 0 in both matrices - TODO confirm)\n",
    "            # and subtract clean accuracies. ONE shared mask is used for both vectors so that the element-wise\n",
    "            # comparison below always pairs the same (src, dst) entry; two independent `> 0` masks can\n",
    "            # differ in length or alignment as soon as a valid entry is exactly 0 in only one matrix.\n",
    "            valid_mask = (bp_accs > 0) | (control_accs > 0)\n",
    "            backpatching_diffs = bp_accs[valid_mask] - clean_accs[0]\n",
    "            control_backpatching_diffs = control_accs[valid_mask] - clean_accs[1]\n",
    "\n",
    "            # Fraction (0..1) of settings where back-patching helps at least as much as the control\n",
    "            bp_better_than_control_percent = (backpatching_diffs >= control_backpatching_diffs).float().mean()\n",
    "            control_results[(model_name, task_name, cfg[1])] = bp_better_than_control_percent\n",
    "            # The value is a fraction, so let the format spec convert it to a percentage\n",
    "            print(f\"{cfg}: V Backpatching gets stronger boost in {bp_better_than_control_percent.item() :.1%} of the cases\")\n",
    "\n",
    "            best_backpatching_increase = backpatching_diffs.max()\n",
    "            best_control_increase = control_backpatching_diffs.max()\n",
    "            print(task_name, model_name, cfg, best_backpatching_increase, best_control_increase)\n",
    "\n",
    "print(control_results)\n",
    "print(f'BP is higher than control (without maxing for best model-task setting) in {np.mean(list(control_results.values())) :.3f} of the cases')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Present BP best results across models and tasks\n",
    "\n",
    "print('Model\\t\\tTask name\\t\\tV Acc\\t\\tL Acc\\t\\tBackpatching-induced Acc')\n",
    "relative_diffs = []\n",
    "for model_name in MODEL_NAMES:\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        # NOTE(review): weights_only=False unpickles arbitrary objects - only load result files you trust.\n",
    "        backpatching_results, src_layer_range, dst_layer_range  = torch.load(f\"./data/{task_name}/results/{model_name}/backpatching_results.pt\", weights_only=False)\n",
    "\n",
    "        # Best single (src, dst) entry of every two-element configuration, then the best of those\n",
    "        k = 1\n",
    "        top_results = []\n",
    "        for cfg in backpatching_results.keys():\n",
    "            if len(cfg) != 2:\n",
    "                continue\n",
    "            (top_h_indices, top_w_indices), top_accs = topk_2d(backpatching_results[cfg][0], k)\n",
    "            for top_h_index, top_w_index, top_acc in zip(top_h_indices, top_w_indices, top_accs):\n",
    "                top_results.append(cfg + (src_layer_range[top_h_index], dst_layer_range[top_w_index], top_acc))\n",
    "        top_results = sorted(top_results, key=lambda x: x[-1], reverse=True)[:k]\n",
    "        bp_best_acc = top_results[0][-1].item()\n",
    "        clean_v, clean_l = backpatching_results['clean_accs']\n",
    "\n",
    "        # Share of the gap between the two clean accuracies that back-patching closes.\n",
    "        # The gap can be exactly zero, in which case the ratio is undefined: report nan\n",
    "        # (which also fails the range check below) instead of dividing by zero.\n",
    "        acc_gap = clean_l - clean_v\n",
    "        relative_diff = (bp_best_acc - clean_v) / acc_gap if acc_gap != 0 else float('nan')\n",
    "        if 0 < relative_diff <= 1.0:\n",
    "            relative_diffs.append(relative_diff)\n",
    "        print(f\"{model_name[:4]}\\t\\t{task_name[:10]}\\t\\t{clean_v :.3f}\\t\\t{clean_l :.3f}\\t\\t{bp_best_acc :.3f} ({relative_diff :.3f})\")\n",
    "\n",
    "print(\"Average relative diff: \", np.mean(relative_diffs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Circuit Discovery"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the attribution scores per layer per position (summed across components)\n",
    "\n",
    "fs = 28 # Font size\n",
    "\n",
    "task_name = SUPPORTED_TASKS[0]\n",
    "model_name = MODEL_NAMES[0]\n",
    "\n",
    "# Everything that differs between the textual ('l') and the visual ('vl') figure.\n",
    "# `ticks` are positions in the full sequence; they are shifted by the data start once the heatmap is cropped.\n",
    "MODALITY_FIGURE_SPECS = {\n",
    "    'l': dict(\n",
    "        get_positions=get_text_sequence_positions,\n",
    "        ticks=[20, 30],\n",
    "        data_label='Data (Text)',\n",
    "        title=\"Textual Task Patching Effects\",\n",
    "        layout=dict(\n",
    "            width=530,\n",
    "            coloraxis_showscale=False,  # Hide the colorbar\n",
    "        ),\n",
    "    ),\n",
    "    'vl': dict(\n",
    "        get_positions=get_image_positions,\n",
    "        ticks=[70, 100],\n",
    "        data_label='Data (Image)',\n",
    "        title=\"Visual Task Patching Effects\",\n",
    "        layout=dict(\n",
    "            width=1200,\n",
    "            coloraxis_colorbar=dict(\n",
    "                thickness=25,  # Adjust the thickness of the color bar\n",
    "                len=1.05       # Adjust the length of the color bar\n",
    "            ),\n",
    "            coloraxis_showscale=True,\n",
    "        ),\n",
    "    ),\n",
    "}\n",
    "\n",
    "for modality in ['l', 'vl']:\n",
    "    spec = MODALITY_FIGURE_SPECS[modality]\n",
    "\n",
    "    scores = torch.load(f'./data/{task_name}/results/{model_name}/node_scores/nap_ig_{modality}_ig=5_metric=LD.pt', weights_only=True)\n",
    "    scores = {k: v.abs() for (k, v) in scores.items()}\n",
    "    # One 'mlp.hook_post' entry per layer; the leading dimension of every entry is the sequence position\n",
    "    n_layers = sum(1 for k in scores if 'mlp.hook_post' in k)\n",
    "    seq_len = next(iter(scores.values())).shape[0]\n",
    "\n",
    "    # Total absolute score (MLP + attention) of every (layer, position) cell\n",
    "    summed_scores_per_layer_per_pos = torch.zeros(n_layers, seq_len)\n",
    "    for layer in range(n_layers):\n",
    "        mlp_hook_key = f'blocks.{layer}.mlp.hook_post'\n",
    "        attn_hook_key = f'blocks.{layer}.attn.hook_z'\n",
    "        mlp_scores, attn_scores = scores[mlp_hook_key], scores[attn_hook_key]\n",
    "        for pos in range(seq_len):\n",
    "            summed_scores_per_layer_per_pos[layer, pos] = mlp_scores[pos].sum() + attn_scores[pos].sum()\n",
    "\n",
    "    # Crop everything before the data so the heatmap starts at the first data position\n",
    "    start_of_data = spec['get_positions'](model_name, task_name)[0]\n",
    "    summed_scores_per_layer_per_pos = summed_scores_per_layer_per_pos[:, start_of_data:]\n",
    "    fig = px.imshow(summed_scores_per_layer_per_pos, color_continuous_scale=\"Blues\")\n",
    "\n",
    "    tickvals = [tick - start_of_data for tick in spec['ticks']] + [seq_len - 1 - start_of_data]\n",
    "    ticktext = [spec['data_label'], 'Query', 'Generation']\n",
    "    fig.update_xaxes(tickvals=tickvals, ticktext=ticktext, title=dict(text='Position', font=dict(size=fs)), tickfont=dict(size=fs - 4))\n",
    "    fig.update_yaxes(tickvals=list(range(5, n_layers, 5)), title=dict(text='Layer', font=dict(size=fs)), tickfont=dict(size=fs - 4))\n",
    "    fig.update_layout(\n",
    "        title=dict(text=spec['title'], font=dict(size=fs), x=0.5, y=0.99),\n",
    "        xaxis_tickangle=0,\n",
    "        margin=dict(l=0, r=0, t=30, b=0),  # Remove margins\n",
    "        **spec['layout'],\n",
    "    )\n",
    "    fig.show()\n",
    "\n",
    "    # pio.write_image(fig, f\"./figures/{model_name}_{task_name}_attr_scores_per_layer_per_pos_{modality}.png\")\n",
    "    # pio.write_image(fig, f\"./figures/{model_name}_{task_name}_attr_scores_per_layer_per_pos_{modality}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side textual / visual attribution heatmaps for every (task, model) pair\n",
    "\n",
    "fs = 28  # Font size\n",
    "\n",
    "for task_name in SUPPORTED_TASKS:\n",
    "    for model_name in MODEL_NAMES:\n",
    "        fig = make_subplots(\n",
    "            rows=1, cols=2,\n",
    "            column_widths=[0.4, 0.6],\n",
    "            subplot_titles=[\"Textual Task Patching Effects\", \"Visual Task Patching Effects\"],\n",
    "        )\n",
    "        # The subplot titles are annotations; enlarge them once per figure\n",
    "        fig.update_annotations(font_size=fs)\n",
    "\n",
    "        for col, modality in enumerate(['l', 'vl'], start=1):\n",
    "            scores = torch.load(f'./data/{task_name}/results/{model_name}/node_scores/nap_ig_{modality}_ig=5_metric=LD.pt', weights_only=True)\n",
    "            scores = {k: v.abs() for (k, v) in scores.items()}\n",
    "            n_layers = len([k for k in scores.keys() if 'mlp.hook_post' in k])\n",
    "            seq_len = scores[list(scores.keys())[0]].shape[0]\n",
    "\n",
    "            # Total absolute score (MLP + attention) of every (layer, position) cell\n",
    "            summed_scores_per_layer_per_pos = torch.zeros(n_layers, seq_len)\n",
    "            for layer in range(n_layers):\n",
    "                mlp_hook_key = f'blocks.{layer}.mlp.hook_post'\n",
    "                attn_hook_key = f'blocks.{layer}.attn.hook_z'\n",
    "                for pos in range(seq_len):\n",
    "                    summed_scores_per_layer_per_pos[layer, pos] = scores[mlp_hook_key][pos].sum() + scores[attn_hook_key][pos].sum()\n",
    "\n",
    "            if modality == 'l':\n",
    "                start_of_data, end_of_data = get_text_sequence_positions(model_name, task_name)\n",
    "                coloraxis = \"coloraxis1\"\n",
    "            else:\n",
    "                start_of_data, end_of_data = get_image_positions(model_name, task_name)\n",
    "                coloraxis = \"coloraxis2\"\n",
    "\n",
    "            # Crop everything before the data: heatmap column 0 is now sequence position `start_of_data`,\n",
    "            # so the 'Q' tick has to be shifted by the same amount to stay on the end of the data.\n",
    "            # (assumes end_of_data is an index into the full sequence, like start_of_data - TODO confirm)\n",
    "            summed_scores_per_layer_per_pos = summed_scores_per_layer_per_pos[:, start_of_data:]\n",
    "            tickvals = [end_of_data - start_of_data]\n",
    "            ticktext = ['Q']\n",
    "\n",
    "            fig.add_trace(\n",
    "                go.Heatmap(\n",
    "                    z=summed_scores_per_layer_per_pos.numpy(),\n",
    "                    coloraxis=coloraxis,\n",
    "                ),\n",
    "                row=1, col=col\n",
    "            )\n",
    "\n",
    "            fig.update_xaxes(\n",
    "                tickvals=tickvals,\n",
    "                ticktext=ticktext,\n",
    "                tickfont=dict(size=fs - 4),\n",
    "                row=1, col=col\n",
    "            )\n",
    "            fig.update_yaxes(\n",
    "                tickvals=list(range(5, n_layers, 5)),\n",
    "                title=dict(text=\"Layer\", font=dict(size=fs)),\n",
    "                tickfont=dict(size=fs - 4),\n",
    "                row=1, col=col\n",
    "            )\n",
    "\n",
    "        # Separate colour axes per subplot; only the right-hand one shows its colorbar\n",
    "        fig.update_layout(\n",
    "            coloraxis1=dict(colorscale=\"Blues\", showscale=False),\n",
    "            coloraxis2=dict(colorscale=\"Blues\", colorbar=dict(thickness=25, len=1.05)),\n",
    "            width=1600,\n",
    "            height=600,\n",
    "            margin=dict(l=0, r=0, t=50, b=0),\n",
    "        )\n",
    "\n",
    "        print(model_name, task_name)\n",
    "        fig.show()\n",
    "        pio.write_image(fig, f\"./figures/appendix_heatmaps/{model_name}_{task_name}.png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Faithfulness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Faithfulness curves of all tasks, one subplot per model (solid = L circuit, dotted = VL circuit)\n",
    "\n",
    "fs = 20  # Font size\n",
    "metric = DEFAULT_METRIC\n",
    "n_models = len(MODEL_NAMES)\n",
    "\n",
    "\n",
    "def _with_title_font(annotation):\n",
    "    \"\"\"Rebuild an annotation with the enlarged title font; one without a 'text' property is returned untouched.\"\"\"\n",
    "    if 'text' not in annotation:\n",
    "        return annotation\n",
    "    return dict(\n",
    "        font=dict(size=fs),  # Increase font size for subplot titles\n",
    "        showarrow=False,\n",
    "        text=annotation['text'],\n",
    "        x=annotation['x'],\n",
    "        xanchor='center',\n",
    "        xref=annotation['xref'],\n",
    "        y=annotation['y'],\n",
    "        yanchor=annotation['yanchor'],\n",
    "        yref=annotation['yref'],\n",
    "    )\n",
    "\n",
    "\n",
    "fig = make_subplots(\n",
    "    rows=1, cols=n_models,\n",
    "    shared_yaxes=True,\n",
    "    subplot_titles=[VISUALIZED_MODEL_NAMES[name] for name in MODEL_NAMES],\n",
    ")\n",
    "\n",
    "for col, model_name in enumerate(MODEL_NAMES, start=1):\n",
    "    faiths_cf, line_titles_cf = [], []\n",
    "\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        results_dir = f\"./data/{task_name}/results/{model_name}\"\n",
    "\n",
    "        # L circuit: titled with the task name, which marks it as the solid line below\n",
    "        results_path = f\"{results_dir}/faithfulness_{metric}_l_node_circuit.pt\"\n",
    "        percentages, faiths_l_cf, faiths_l_mask = torch.load(results_path, weights_only=True)\n",
    "        faiths_cf.append(faiths_l_cf.diag()[1:])\n",
    "        line_titles_cf.append(task_name.replace('_', ' ').capitalize())\n",
    "\n",
    "        # VL circuit: empty title, which marks it as the dotted line below\n",
    "        results_path = f\"{results_dir}/faithfulness_{metric}_vl_node_circuit.pt\"\n",
    "        percentages, faiths_vl_cf, faiths_vl_mask = torch.load(results_path, weights_only=True)\n",
    "        faiths_cf.append(faiths_vl_cf.diag()[1:])\n",
    "        line_titles_cf.append(\"\")\n",
    "\n",
    "    # Curves come in (L, VL) pairs per task, so `i // 2` gives both curves of a task the same colour\n",
    "    for i, (faith, title) in enumerate(zip(faiths_cf, line_titles_cf)):\n",
    "        line_style = 'dot' if title == '' else 'solid'\n",
    "        trace = go.Scatter(\n",
    "            x=percentages[1:],\n",
    "            y=faith.numpy(),\n",
    "            mode='lines',\n",
    "            line=dict(dash=line_style, color=COLORBLIND_COLORS[i // 2]),\n",
    "        )\n",
    "        fig.add_trace(trace, row=1, col=col)\n",
    "\n",
    "# Read the subplot-title annotations before the layout is updated\n",
    "restyled_annotations = [_with_title_font(annotation) for annotation in fig['layout']['annotations']]\n",
    "\n",
    "fig.update_layout(\n",
    "    width=1100,\n",
    "    height=300,\n",
    "    showlegend=False,\n",
    "    legend=dict(orientation=\"h\", yanchor=\"bottom\", y=-0.9, xanchor=\"center\", x=0.5),\n",
    "    margin=dict(l=10, r=10, t=30, b=30),  # Adjust margins\n",
    "    # Reduce spacing between subplots (three hard-coded domains, one per model)\n",
    "    xaxis=dict(domain=[0.0, 0.32]),\n",
    "    xaxis2=dict(domain=[0.34, 0.65]),\n",
    "    xaxis3=dict(domain=[0.67, 1.0]),\n",
    "    annotations=restyled_annotations,\n",
    ")\n",
    "\n",
    "for col in range(1, n_models + 1):\n",
    "    fig.update_xaxes(title=dict(text=\"Circuit size (% Components)\", font=dict(size=fs)), type=\"log\", row=1, col=col, tickvals=[0.01, 0.1, 1], tickfont=dict(size=fs - 4))\n",
    "    if col != 1:\n",
    "        # Later subplots: no y tick labels and an unlabelled threshold line\n",
    "        fig.update_yaxes(showticklabels=False, row=1, col=col)\n",
    "        fig.add_hline(y=0.8, line_color=\"black\", line_dash=\"dot\", annotation_text=\"\", annotation_position=\"top left\", row=1, col=col)\n",
    "    else:\n",
    "        # First subplot: carries the y-axis title, the ticks and the threshold label\n",
    "        fig.update_yaxes(title=dict(text=\"Faithfulness\", font=dict(size=fs)), row=1, col=col, tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0], tickfont=dict(size=fs - 4))\n",
    "        fig.add_hline(y=0.8, line_color=\"black\", line_dash=\"dot\", annotation_font=dict(size=fs-7), annotation_text=\"High faithfulness<br>threshold\", annotation_position=\"top left\", row=1, col=col)\n",
    "\n",
    "fig.show()\n",
    "\n",
    "pio.write_image(fig, \"./figures/faithfulness-all-models-and-tasks.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Intersections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unify_l_vl(result_dict, key_suffix):\n",
    "    return round((result_dict[f\"l_{key_suffix}\"] + result_dict[f\"vl_{key_suffix}\"]) / 2, 3)\n",
    "\n",
    "\n",
    "SPLIT_DQG = True\n",
    "NO_POS_IN_DATA = True\n",
    "nps = \"_no_pos\" if NO_POS_IN_DATA else \"\" # nps == no pos suffix\n",
    "\n",
    "W = 0.4\n",
    "x_vals_all = [[] for _ in MODEL_NAMES]  # to collect x-values per subplot\n",
    "\n",
    "\n",
    "fig = make_subplots(rows=1, cols=len(MODEL_NAMES), subplot_titles=[m.capitalize() for m in MODEL_NAMES])\n",
    "\n",
    "for model_idx, model_name in enumerate(MODEL_NAMES):\n",
    "    task_intersections, baseline_intersections = [], []\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        # Simulate loading data - replace with your actual loading\n",
    "        intersections_dict = torch.load(f\"./data/{task_name}/results/{model_name}/intersection_results.pt\", weights_only=True)\n",
    "        D_iou = (unify_l_vl(intersections_dict, f'D_neurons_iou{nps}') + unify_l_vl(intersections_dict, f'D_head_iou{nps}')) / 2\n",
    "        Q_iou = (unify_l_vl(intersections_dict, 'Q_neurons_iou') + unify_l_vl(intersections_dict, 'Q_head_iou')) / 2\n",
    "        G_iou = (unify_l_vl(intersections_dict, 'G_neurons_iou') + unify_l_vl(intersections_dict, 'G_head_iou')) / 2\n",
    "\n",
    "        if SPLIT_DQG:\n",
    "            task_intersections.append((D_iou, Q_iou, G_iou))\n",
    "        else:\n",
    "            avg_iou = (D_iou + Q_iou + G_iou) / 3\n",
    "            task_intersections.append(avg_iou)\n",
    "\n",
    "        D_baseline = (unify_l_vl(intersections_dict, f'D_neurons_baseline{nps}') + unify_l_vl(intersections_dict, f'D_head_baseline{nps}')) / 2\n",
    "        Q_baseline = (unify_l_vl(intersections_dict, 'Q_neurons_baseline') + unify_l_vl(intersections_dict, 'Q_head_baseline')) / 2\n",
    "        G_baseline = (unify_l_vl(intersections_dict, 'G_neurons_baseline') + unify_l_vl(intersections_dict, 'G_head_baseline')) / 2\n",
    "        if SPLIT_DQG:\n",
    "            baseline_intersections.append((D_baseline, Q_baseline, G_baseline))\n",
    "        else:\n",
    "            avg_baseline = (D_baseline + Q_baseline + G_baseline) / 3\n",
    "            baseline_intersections.append(avg_baseline)\n",
    "\n",
    "    for i, task_name in enumerate(SUPPORTED_TASKS):\n",
    "        presented_task_name = task_name.replace('_', '<br>').capitalize()\n",
    "        if SPLIT_DQG:\n",
    "            for j, dqg in enumerate(['D', 'Q', 'G']):\n",
    "                fig.add_trace(go.Bar(\n",
    "                    x=[f\"{presented_task_name} {dqg} I\"],\n",
    "                    y=[task_intersections[i][j]],\n",
    "                    marker_color='green',\n",
    "                    name=f\"{task_name} {dqg} Intersection\",\n",
    "                    showlegend=False,\n",
    "                    width=W,\n",
    "                ), row=1, col=model_idx + 1)\n",
    "\n",
    "                fig.add_trace(go.Bar(\n",
    "                    x=[f\"{presented_task_name} {dqg} B\"],\n",
    "                    y=[baseline_intersections[i][j]],\n",
    "                    marker_color='red',\n",
    "                    marker_pattern_shape='x',\n",
    "                    name=f\"{task_name} {dqg} Baseline\",\n",
    "                    showlegend=False,\n",
    "                    width=W,\n",
    "                ), row=1, col=model_idx + 1)\n",
    "        else:\n",
    "            x_inter = f\"{presented_task_name}\"\n",
    "            x_base = f\"{presented_task_name} Base\"\n",
    "            x_vals_all[model_idx].extend([x_inter, x_base])\n",
    "\n",
    "            fig.add_trace(go.Bar(\n",
    "                x=[x_inter],\n",
    "                y=[task_intersections[i]],\n",
    "                marker_color='green',\n",
    "                name=f\"{task_name} Intersection\",\n",
    "                showlegend=False,\n",
    "                width=W,\n",
    "            ), row=1, col=model_idx + 1)\n",
    "            fig.add_trace(go.Bar(\n",
    "                x=[x_base],\n",
    "                y=[baseline_intersections[i]],\n",
    "                marker_color='red',\n",
    "                marker_pattern_shape=\"x\",\n",
    "                name=f\"{task_name} Baseline\",\n",
    "                showlegend=False,\n",
    "                width=W,\n",
    "            ), row=1, col=model_idx + 1)\n",
    "\n",
    "# Update y-axis visibility\n",
    "for i in range(1, len(MODEL_NAMES) + 1):\n",
    "    if i > 1:\n",
    "        fig.update_yaxes(showticklabels=False, range=[0, 1.05], row=1, col=i)\n",
    "    else:\n",
    "        fig.update_yaxes(title_text=\"IoU\", tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], range=[0, 1.05], row=1, col=i)\n",
    "\n",
    "\n",
    "# Update x-axes: hide \"Base\" bar labels\n",
    "for i in range(1, len(MODEL_NAMES) + 1):\n",
    "    tickvals = x_vals_all[i - 1]\n",
    "    ticktext = [label if \"Base\" not in label else \"\" for label in tickvals]\n",
    "    fig.update_xaxes(\n",
    "        tickvals=tickvals,\n",
    "        ticktext=ticktext,\n",
    "        row=1,\n",
    "        col=i,\n",
    "    )\n",
    "\n",
    "fig.add_hline(y=1.0, line_dash=\"dot\", line_color=\"black\", row=1, col='all')\n",
    "\n",
    "# Update bar width\n",
    "fig.update_traces(width=0.8)\n",
    "\n",
    "# Update layout\n",
    "fig.update_layout(\n",
    "    barmode='group',\n",
    "    width=1200,\n",
    "    height=400,\n",
    ")\n",
    "fig.update_layout(\n",
    "    margin=dict(l=10, r=0, t=20, b=0),\n",
    "    width=1000,\n",
    "    height=400,\n",
    "    xaxis=dict(domain=[0.0, 0.32]),\n",
    "    xaxis2=dict(domain=[0.34, 0.65]),\n",
    "    xaxis3=dict(domain=[0.67, 0.99])\n",
    ")\n",
    "fig.show()\n",
    "\n",
    "pio.write_image(fig, f\"./figures/intersections_DQG={SPLIT_DQG}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unify_l_vl(result_dict, key_suffix):\n",
    "    return round((result_dict[f\"l_{key_suffix}\"] + result_dict[f\"vl_{key_suffix}\"]) / 2, 3)\n",
    "\n",
    "\n",
    "SPLIT_DQG = True\n",
    "NO_POS_IN_DATA = True\n",
    "nps = \"_no_pos\" if NO_POS_IN_DATA else \"\" # nps == no pos suffix\n",
    "\n",
    "W = 0.4\n",
    "x_vals_all = [[] for _ in MODEL_NAMES]  # to collect x-values per subplot\n",
    "\n",
    "\n",
    "fig = make_subplots(rows=1, cols=len(MODEL_NAMES), subplot_titles=[m.capitalize() for m in MODEL_NAMES])\n",
    "\n",
    "for model_idx, model_name in enumerate(MODEL_NAMES):\n",
    "    task_intersections, baseline_intersections = [], []\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        # Simulate loading data - replace with your actual loading\n",
    "        intersections_dict = torch.load(f\"./data/{task_name}/results/{model_name}/intersection_results.pt\", weights_only=True)\n",
    "        D_iou = (unify_l_vl(intersections_dict, f'D_neurons_iou{nps}') + unify_l_vl(intersections_dict, f'D_head_iou{nps}')) / 2\n",
    "        Q_iou = (unify_l_vl(intersections_dict, 'Q_neurons_iou') + unify_l_vl(intersections_dict, 'Q_head_iou')) / 2\n",
    "        G_iou = (unify_l_vl(intersections_dict, 'G_neurons_iou') + unify_l_vl(intersections_dict, 'G_head_iou')) / 2\n",
    "\n",
    "        if SPLIT_DQG:\n",
    "            task_intersections.append((D_iou, Q_iou, G_iou))\n",
    "        else:\n",
    "            avg_iou = (D_iou + Q_iou + G_iou) / 3\n",
    "            task_intersections.append(avg_iou)\n",
    "\n",
    "        D_baseline = (unify_l_vl(intersections_dict, f'D_neurons_baseline{nps}') + unify_l_vl(intersections_dict, f'D_head_baseline{nps}')) / 2\n",
    "        Q_baseline = (unify_l_vl(intersections_dict, 'Q_neurons_baseline') + unify_l_vl(intersections_dict, 'Q_head_baseline')) / 2\n",
    "        G_baseline = (unify_l_vl(intersections_dict, 'G_neurons_baseline') + unify_l_vl(intersections_dict, 'G_head_baseline')) / 2\n",
    "        if SPLIT_DQG:\n",
    "            baseline_intersections.append((D_baseline, Q_baseline, G_baseline))\n",
    "        else:\n",
    "            avg_baseline = (D_baseline + Q_baseline + G_baseline) / 3\n",
    "            baseline_intersections.append(avg_baseline)\n",
    "\n",
    "    normalized_intersections = (torch.tensor(task_intersections) - torch.tensor(baseline_intersections) / 1.0 - torch.tensor(baseline_intersections)).clamp(0.01, 1)\n",
    "    print('Mean across tasks: ', normalized_intersections.mean(dim=0))\n",
    "    print('Mean across positions: ', normalized_intersections.mean(dim=1))\n",
    "    for i, task_name in enumerate(SUPPORTED_TASKS):\n",
    "        presented_task_name = task_name.replace('_', '<br>').capitalize()\n",
    "        if SPLIT_DQG:\n",
    "            x_inter = f\"{presented_task_name}\"\n",
    "            x_vals_all[model_idx].extend([x_inter])\n",
    "            for j, dqg in enumerate(['D', 'Q', 'G']):\n",
    "                normalized_intersection = normalized_intersections[i][j].item()\n",
    "                fig.add_trace(go.Bar(\n",
    "                    x=[x_inter + dqg],\n",
    "                    y=[normalized_intersection],\n",
    "                    marker_color=COLORBLIND_COLORS[j],\n",
    "                    name=f\"{task_name} {dqg} Intersection\",\n",
    "                    showlegend=False,\n",
    "                    width=W,\n",
    "                ), row=1, col=model_idx + 1)\n",
    "        else:\n",
    "            x_inter = f\"{presented_task_name}\"\n",
    "            x_vals_all[model_idx].extend([x_inter])\n",
    "            normalized_intersection = normalized_intersections[i].item()\n",
    "            fig.add_trace(go.Bar(\n",
    "                x=[x_inter],\n",
    "                y=[normalized_intersection],\n",
    "                marker_color='green',\n",
    "                name=f\"{task_name} Intersection\",\n",
    "                showlegend=False,\n",
    "                width=W,\n",
    "            ), row=1, col=model_idx + 1)\n",
    "\n",
    "# Update y-axis visibility\n",
    "for i in range(1, len(MODEL_NAMES) + 1):\n",
    "    fig.update_yaxes(range=[0, 1.05], row=1, col=i)\n",
    "    if i > 1:\n",
    "        fig.update_yaxes(showticklabels=False, tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], row=1, col=i)\n",
    "    else:\n",
    "        fig.update_yaxes(title_text=\"Normalized IoU\", tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], row=1, col=i)\n",
    "\n",
    "\n",
    "for i in range(1, len(MODEL_NAMES) + 1):\n",
    "    tickvals = x_vals_all[i - 1]\n",
    "    ticktext = [label for label in tickvals]\n",
    "    fig.update_xaxes(\n",
    "        tickvals=tickvals,\n",
    "        ticktext=ticktext,\n",
    "        row=1,\n",
    "        col=i,\n",
    "    )\n",
    "\n",
    "fig.add_hline(y=1.0, line_dash=\"dot\", line_color=\"black\", row=1, col='all')\n",
    "\n",
    "# Update bar width\n",
    "fig.update_traces(width=0.8)\n",
    "\n",
    "# Update layout\n",
    "fig.update_layout(\n",
    "    barmode='group',\n",
    "    margin=dict(l=10, r=0, t=20, b=5),\n",
    "    width=1000,\n",
    "    height=200,\n",
    "    xaxis=dict(domain=[0.0, 0.32]),\n",
    "    xaxis2=dict(domain=[0.34, 0.65]),\n",
    "    xaxis3=dict(domain=[0.67, 0.99]),\n",
    ")\n",
    "fig.show()\n",
    "\n",
    "pio.write_image(fig, f\"./figures/intersections_normalized_DQG={SPLIT_DQG}.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Interchange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "metric = DEFAULT_METRIC\n",
    "model_name = MODEL_NAMES[0]\n",
    "BASELINE_KEY, VALUE_KEY, SKYLINE_KEY = 'Random Components (Baseline)', 'Modality Switch Faithfulness', 'Clean Circuit Faithfulness'\n",
    "avg = lambda k1, k2: ((results_dict[k1] + results_dict[k2]) / 2).item()\n",
    "show_splits = \"DQG\"\n",
    "\n",
    "tasks = []\n",
    "values = []\n",
    "types = []\n",
    "if \"D\" in show_splits:\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        tasks += [task_name.replace('_', ' ').capitalize() + \" D\"] * 3\n",
    "        types += [BASELINE_KEY, VALUE_KEY, SKYLINE_KEY]\n",
    "        results_dict = torch.load(f\"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_interchanges_{metric}.pt\", weights_only=True)\n",
    "        values += [avg('DR_QV_LV', 'DR_QL_LL'), avg('DL_QV_LV', 'DV_QL_LL'), avg('DV_QV_LV', 'DL_QL_LL')]\n",
    "\n",
    "if \"Q\" in show_splits:\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        tasks += [task_name.replace('_', ' ').capitalize() + \" Q\"] * 3 # 3 for skyline, values, baseline\n",
    "        types += [BASELINE_KEY, VALUE_KEY, SKYLINE_KEY]\n",
    "        results_dict = torch.load(f\"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_interchanges_{metric}.pt\", weights_only=True)\n",
    "        values += [avg('DV_QR_LV', 'DL_QR_LL'), avg('DV_QL_LV', 'DL_QV_LL'), avg('DV_QV_LV', 'DL_QL_LL')]\n",
    "\n",
    "if \"G\" in show_splits:\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        tasks += [task_name.replace('_', ' ').capitalize() + \" G\"] * 3 # 3 for skyline, values, baseline\n",
    "        types += [BASELINE_KEY, VALUE_KEY, SKYLINE_KEY]\n",
    "        results_dict = torch.load(f\"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_interchanges_{metric}.pt\", weights_only=True)\n",
    "        values += [avg('DV_QV_LR', 'DL_QL_LR'), avg('DV_QV_LL', 'DL_QL_LV'), avg('DV_QV_LV', 'DL_QL_LL')]\n",
    "\n",
    "\n",
    "data = pd.DataFrame({\n",
    "    'Task': tasks,\n",
    "    'Value': values,\n",
    "    'Type': types\n",
    "})\n",
    "\n",
    "# Define colors\n",
    "colors = {BASELINE_KEY: 'lightgray', VALUE_KEY: 'steelblue', SKYLINE_KEY: 'lightblue'}\n",
    "\n",
    "# Define patterns for baselines\n",
    "patterns = {BASELINE_KEY: '//', SKYLINE_KEY: '\\\\\\\\'}\n",
    "\n",
    "# Create the barplot\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "\n",
    "sns.barplot(x='Task', y='Value', hue='Type', data=data, palette=colors, dodge=True, ax=ax)\n",
    "\n",
    "\n",
    "# Set y-axis limits\n",
    "ax.set_ylim(0, 1)\n",
    "\n",
    "# Remove duplicate x-axis labels\n",
    "handles, labels = ax.get_legend_handles_labels()\n",
    "new_handles = []\n",
    "new_labels = []\n",
    "seen_labels = set()\n",
    "for handle, label in zip(handles, labels):\n",
    "    if label not in seen_labels:\n",
    "        new_handles.append(handle)\n",
    "        new_labels.append(label)\n",
    "        seen_labels.add(label)\n",
    "ax.legend(new_handles, new_labels, loc='lower center', bbox_to_anchor=(0.5, -0.12), ncol=3, frameon=False, fontsize='medium')\n",
    "\n",
    "ax.set_xlabel('', fontsize='large')\n",
    "ax.set_ylabel('Circuit Faithfulness', fontsize='large')\n",
    "ax.tick_params(axis='x', labelrotation=45)\n",
    "plt.title(f'Faithfulness when patching {\" / \".join(show_splits)} sub-circuits', fontsize='large')\n",
    "# plt.tight_layout()\n",
    "plt.show()  \n",
    "\n",
    "# Save the plot as a PDF\n",
    "fig.savefig(f\"./figures/{MODEL_NAMES[0]}_interchange_faithfulness.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = DEFAULT_METRIC\n",
    "model_name = MODEL_NAMES[0]\n",
    "avg = lambda k1, k2: ((results_dict[k1] + results_dict[k2]) / 2).item()\n",
    "\n",
    "W = 0.4\n",
    "x_vals_all = [[] for _ in MODEL_NAMES]  # to collect x-values per subplot\n",
    "fig = make_subplots(rows=1, cols=len(MODEL_NAMES), subplot_titles=[VISUALIZED_MODEL_NAMES[model_name] for model_name in MODEL_NAMES])\n",
    "\n",
    "for model_idx, model_name in enumerate(MODEL_NAMES):\n",
    "    task_intersections, baseline_intersections = [], []\n",
    "\n",
    "    scores = []\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        try:\n",
    "            results_dict = torch.load(f\"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_interchanges_{metric}.pt\", weights_only=True)\n",
    "            d_baseline, d_skyline, d_value = avg('DR_QV_LV', 'DR_QL_LL'), avg('DV_QV_LV', 'DL_QL_LL'), avg('DL_QV_LV', 'DV_QL_LL')\n",
    "            q_baseline, q_skyline, q_value = avg('DV_QR_LV', 'DL_QR_LL'),           avg('DV_QV_LV', 'DL_QL_LL'), avg('DV_QL_LV', 'DL_QV_LL'),\n",
    "            g_baseline, g_skyline, g_value = avg('DV_QV_LR', 'DL_QL_LR'),           avg('DV_QV_LV', 'DL_QL_LL'), avg('DV_QV_LL', 'DL_QL_LV')\n",
    "            scores.append([(d_value - d_baseline) / (d_skyline - d_baseline), (q_value - q_baseline) / (q_skyline - q_baseline), (g_value - g_baseline) / (g_skyline - g_baseline)])\n",
    "        except Exception as e:\n",
    "            print(e)\n",
    "            scores.append(0)\n",
    "    print(scores)\n",
    "    scores = torch.tensor(scores).clamp(min=0.0)\n",
    "\n",
    "    for i, task_name in enumerate(SUPPORTED_TASKS):\n",
    "        presented_task_name = task_name.replace('_', '<br>').capitalize()\n",
    "        x_inter = f\"{presented_task_name}\"\n",
    "        x_vals_all[model_idx].extend([x_inter])\n",
    "        for j, dqg in enumerate(['D', 'Q', 'G']):\n",
    "            fig.add_trace(go.Bar(\n",
    "                x=[x_inter + dqg],\n",
    "                y=[scores[i][j].item()],\n",
    "                marker_color=COLORBLIND_COLORS[j],\n",
    "                name=f\"{task_name} {dqg}\",\n",
    "                showlegend=False,\n",
    "                width=W,\n",
    "            ), row=1, col=model_idx + 1)\n",
    "        \n",
    "\n",
    "# Update y-axis visibility\n",
    "for i in range(1, len(MODEL_NAMES) + 1):\n",
    "    fig.update_yaxes(range=[0, 1.05], row=1, col=i)\n",
    "    if i > 1:\n",
    "        fig.update_yaxes(showticklabels=False, tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], row=1, col=i)\n",
    "    else:\n",
    "        fig.update_yaxes(title_text=\"Interchange Faithfulness\", tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], row=1, col=i)\n",
    "\n",
    "\n",
    "for i in range(1, len(MODEL_NAMES) + 1):\n",
    "    tickvals = x_vals_all[i - 1]\n",
    "    ticktext = [label for label in tickvals]\n",
    "    fig.update_xaxes(\n",
    "        tickvals=tickvals,\n",
    "        ticktext=ticktext,\n",
    "        row=1,\n",
    "        col=i,\n",
    "    )\n",
    "\n",
    "fig.add_hline(y=1.0, line_dash=\"dot\", line_color=\"black\", row=1, col='all')\n",
    "\n",
    "# Update bar width\n",
    "fig.update_traces(width=0.8)\n",
    "\n",
    "# Update layout\n",
    "fig.update_layout(\n",
    "    barmode='group',\n",
    "    margin=dict(l=10, r=0, t=20, b=5),\n",
    "    width=1000,\n",
    "    height=200,\n",
    "    xaxis=dict(domain=[0.0, 0.32]),\n",
    "    xaxis2=dict(domain=[0.34, 0.65]),\n",
    "    xaxis3=dict(domain=[0.67, 0.99]),\n",
    ")\n",
    "fig.show()\n",
    "\n",
    "pio.write_image(fig, f\"./figures/interchange_faithfulness_normalized.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Backpatching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show pre-backpatching similarities of visual image patches to text tokens from the parallel textual sequences\n",
    "\n",
    "all_model_similarities = {(model_name, task_name): torch.load(f\"./data/{task_name}/results/{model_name}/similarities_of_vl_activations_to_text_seq_tokens_k=0.05_use_unembed=True.pt\", weights_only=True).cpu() for model_name in MODEL_NAMES for task_name in SUPPORTED_TASKS}\n",
    "\n",
    "# Show all tasks of a specific model on one plot\n",
    "model_sims = {model_name: torch.stack([all_model_similarities[(model_name, task_name)] for task_name in SUPPORTED_TASKS]) for model_name in MODEL_NAMES}\n",
    "for model_name in MODEL_NAMES:\n",
    "    fig = go.Figure()\n",
    "    for i, task_name in enumerate(SUPPORTED_TASKS):\n",
    "        fig.add_scatter(x=list(range(model_sims[model_name].shape[1])), y=model_sims[model_name][i].numpy(), mode='lines', name=task_name.replace('_', ' ').capitalize(), line=dict(color=COLORBLIND_COLORS[i]))\n",
    "    fig.update_layout(\n",
    "        title=f'Similarity of image patch activations with text token unembeddings<br>({model_name})',\n",
    "        xaxis_title='Layer',\n",
    "        yaxis_title='Similarity',\n",
    "        width=600,\n",
    "        legend=dict(\n",
    "            orientation=\"h\",\n",
    "            yanchor=\"bottom\",\n",
    "            y=-0.3,\n",
    "            xanchor=\"center\",\n",
    "            x=0.5\n",
    "        )\n",
    "    )\n",
    "    fig.update_yaxes(range=[0, 0.3])\n",
    "    fig.show()\n",
    "\n",
    "    pio.write_image(fig, f\"./figures/{model_name}_similarities_to_text_seq_tokens.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print looped results\n",
    "fs = 24\n",
    "\n",
    "looped_results = defaultdict(lambda: [0] * 10)\n",
    "for model_name in MODEL_NAMES:\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        backpatching_results, src_layer_range, dst_layer_range  = torch.load(f\"./data/{task_name}/results/{model_name}/backpatching_results.pt\", weights_only=False)\n",
    "        for cfg in backpatching_results.keys():\n",
    "            if len(cfg) != 3:\n",
    "                # Take only looped results (have 3 keys)\n",
    "                continue\n",
    "            looped_results[(model_name, task_name)][cfg[2]] = backpatching_results[cfg]\n",
    "\n",
    "looped_results = sorted(looped_results.items())\n",
    "\n",
    "fig = make_subplots(\n",
    "    rows=1, cols=len(MODEL_NAMES),\n",
    "    subplot_titles=[VISUALIZED_MODEL_NAMES[model_name] for model_name in MODEL_NAMES],\n",
    "    shared_yaxes=True\n",
    ")\n",
    "\n",
    "task_colors = {task_name: COLORBLIND_COLORS[i] for i, task_name in enumerate(SUPPORTED_TASKS)}\n",
    "\n",
    "for col, model_name in enumerate(MODEL_NAMES, start=1):\n",
    "    for idx, ((looped_model_name, task_name), results) in enumerate(looped_results):\n",
    "        if looped_model_name != model_name:\n",
    "            continue\n",
    "        fig.add_trace(\n",
    "            go.Scatter(\n",
    "                x=list(range(len(results))),\n",
    "                y=results,\n",
    "                mode='lines',\n",
    "                name=f\"{task_name}\",\n",
    "                line=dict(color=task_colors[task_name])\n",
    "            ),\n",
    "            row=1, col=col\n",
    "        )\n",
    "fig.update_layout(\n",
    "    margin=dict(l=40, r=40, t=40, b=40),\n",
    "    width=1000,\n",
    "    height=300,\n",
    "    showlegend=False\n",
    ")\n",
    "\n",
    "for col, model_name in enumerate(MODEL_NAMES, start=1):\n",
    "    fig.update_xaxes(title=dict(text=\"Back-patching iteration\", font=dict(size=fs)), row=1, col=col, tickfont=dict(size=fs-4))\n",
    "    fig.update_yaxes(title=dict(text=\"Accuracy\", font=dict(size=fs)), row=1, col=col, tickfont=dict(size=fs-4))\n",
    "    fig.update_annotations(font_size=fs)\n",
    "\n",
    "fig.show()\n",
    "\n",
    "pio.write_image(fig, f\"./figures/iterative-backpatching-results.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Control vs Normal backpatching results\n",
    "\n",
    "ordered_keys = []\n",
    "control_results = OrderedDict()\n",
    "for model_name in MODEL_NAMES:\n",
    "    for task_name in SUPPORTED_TASKS:\n",
    "        backpatching_results, src_layer_range, dst_layer_range  = torch.load(f\"./data/{task_name}/results/{model_name}/backpatching_results.pt\", weights_only=False)\n",
    "        # Comparing to control (L->L backpatching; Should (hopefully) lead to a smaller improvement)\n",
    "        for cfg in backpatching_results.keys():\n",
    "            if len(cfg) != 2:\n",
    "                continue\n",
    "            if cfg[0] is not True:\n",
    "                continue\n",
    "\n",
    "            # Remove non-valid settings (i.e. dst >= src) and subtract clean accuracies\n",
    "            backpatching_diffs = backpatching_results[cfg][0].view(-1)[(backpatching_results[cfg][0].view(-1) > 0)] - backpatching_results[\"clean_accs\"][0]\n",
    "            control_backpatching_diffs = backpatching_results[cfg][1].view(-1)[(backpatching_results[cfg][1].view(-1) > 0)] - backpatching_results[\"clean_accs\"][1]\n",
    "\n",
    "            bp_better_than_control_percent = (backpatching_diffs >= control_backpatching_diffs).float().mean()\n",
    "            control_results[(model_name, task_name, cfg[1])] = bp_better_than_control_percent\n",
    "\n",
    "# Sort the keys in control_results\n",
    "\n",
    "# Extract x_labels and y_values\n",
    "x_labels = [f\"{key[0]}-{key[1]}-{key[2]}\" for key in control_results.keys()]\n",
    "y_values = [control_results[key].item() * 100 for key in control_results.keys()]\n",
    "# Assign colors based on task_name\n",
    "colors = [COLORBLIND_COLORS[SUPPORTED_TASKS.index(key[1])] for key in control_results.keys()]\n",
    "# Create the scatter plot\n",
    "fig = go.Figure()\n",
    "# Add scatter points with repeating shapes\n",
    "shapes = ['circle', 'square', 'star']\n",
    "fig.add_trace(go.Scatter(\n",
    "    x=x_labels,\n",
    "    y=y_values,\n",
    "    mode='markers',\n",
    "    marker=dict(color=colors, size=10, symbol=[shapes[i % len(shapes)] for i in range(len(x_labels))]),\n",
    "    name=\"Control Results\"\n",
    "))\n",
    "# Add dotted lines\n",
    "fig.add_vline(x=14.5, line_color=\"black\")\n",
    "fig.add_vline(x=29.5, line_color=\"black\")\n",
    "fig.add_hline(y=50, line_dash=\"dot\", line_color=\"gray\", annotation_text=\"Random\", annotation_position=\"top left\", annotation_font=dict(size=16))\n",
    "fig.add_hline(y=0, line_dash=\"dot\", line_color=\"black\", annotation_text=\"Control Advantage\", annotation_position=\"top left\", annotation_font=dict(size=16))\n",
    "fig.add_hline(y=100, line_dash=\"dot\", line_color=\"green\", annotation_text=\"BP Advantage\", annotation_position=\"top left\", annotation_font=dict(size=16))\n",
    "# Update layout\n",
    "\n",
    "fig.update_xaxes(tickvals=[], ticktext=[])\n",
    "fig.update_layout(\n",
    "    yaxis=dict(title=\"Back-patching<br>beats control<br>percentage\", range=[-10, 110]),\n",
    "    width=1000,\n",
    "    height=200,\n",
    "    showlegend=False,\n",
    "    margin=dict(l=0, r=5, t=5, b=5),\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "pio.write_image(fig, f\"./figures/control-vs-backpatching-results.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "vlm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
