{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d27bb8e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07a34393",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d58ed106",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf89b325",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from nnsight import LanguageModel\n",
    "\n",
    "# lm = LanguageModel(\n",
    "#     model_key,\n",
    "#     device_map=\"auto\",\n",
    "#     dispatch=True,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3a80e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32269e8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, CountingTask, get_counterfactual_samples_interface\n",
    "\n",
    "#################################################################################\n",
    "# TASK_CLS = CountingTask\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 2\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=os.path.join(\n",
    "        env_utils.DEFAULT_DATA_DIR, \n",
    "        \"selection\", \n",
    "        # \"profession.json\"\n",
    "        # \"nationality.json\"\n",
    "        \"objects.json\"\n",
    "    )\n",
    ")\n",
    "\n",
    "print(select_task)\n",
    "\n",
    "test_sample = select_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    n_distractors=N_DISTRACTORS,\n",
    "    option_style=OPTION_STYLE,\n",
    ")\n",
    "print(test_sample.prompt(), \">>\", mt.tokenizer.decode(test_sample.ans_token_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "388f1913",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import get_counterfactual_samples_within_task\n",
    "\n",
    "counterfact_sampler = get_counterfactual_samples_interface[select_task.task_name]\n",
    "\n",
    "kwargs = {}\n",
    "if TASK_CLS == CountingTask:\n",
    "    kwargs[\"clean_n_options\"] = N_DISTRACTORS + 1\n",
    "    kwargs[\"patch_n_options\"] = N_DISTRACTORS + 1\n",
    "else:\n",
    "    kwargs[\"clean_n_distractors\"] = N_DISTRACTORS\n",
    "    kwargs[\"patch_n_distractors\"] = N_DISTRACTORS\n",
    "    kwargs[\"clean_prompt_template_idx\"] = prompt_template_idx\n",
    "    kwargs[\"patch_prompt_template_idx\"] = prompt_template_idx\n",
    "\n",
    "patch_sample, clean_sample = counterfact_sampler(\n",
    "    mt=mt,\n",
    "    task=select_task,\n",
    "    filter_by_lm_prediction=True,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    option_style=OPTION_STYLE,\n",
    "    distinct_options=True,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    **kwargs,\n",
    ")\n",
    "\n",
    "# patch_sample.default_option_style = \"single_line\"\n",
    "# clean_sample.default_option_style = \"numbered\"\n",
    "\n",
    "print(\"=\" * 80)\n",
    "\n",
    "print(patch_sample.prompt(), \">>\", mt.tokenizer.decode(patch_sample.ans_token_id))\n",
    "print(clean_sample.prompt(), \">>\", mt.tokenizer.decode(clean_sample.ans_token_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff7d6b06",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tokens import prepare_input\n",
    "\n",
    "clean_tokenized = prepare_input(tokenizer=mt, prompts=clean_sample.prompt())\n",
    "print(mt.tokenizer.decode(clean_tokenized.input_ids[0], skip_special_tokens=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6899bb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tokens import prepare_input\n",
    "from src.selection.utils import get_first_token_id\n",
    "from src.functional import interpret_logits, PatchSpec\n",
    "from itertools import product\n",
    "from src.utils.typing import TokenizerOutput, ArrayLike\n",
    "from typing import Optional, Union\n",
    "from src.functional import get_module_nnsight, untuple, get_hs, predict_next_token\n",
    "from src.selection.data import SelectionSample\n",
    "import random\n",
    "from src.selection.functional import find_quesmark_pos\n",
    "from src.selection.data import get_options_for_answer\n",
    "\n",
    "def layer_wise_patching(\n",
    "    mt: ModelandTokenizer,\n",
    "    patch_sample: SelectionSample,\n",
    "    clean_sample: SelectionSample,\n",
    "    map_indices: list[int] = {-2: -2, -1: -1},\n",
    "    consider_ques_pos: bool = True,\n",
    "):\n",
    "    clean_tokenized = prepare_input(\n",
    "        prompts=clean_sample.prompt(), tokenizer=mt, return_offsets_mapping=True\n",
    "    )\n",
    "    patch_tokenized = prepare_input(\n",
    "        prompts=patch_sample.prompt(), tokenizer=mt, return_offsets_mapping=True\n",
    "    )\n",
    "    clean_offset_mapping = clean_tokenized.pop(\"offset_mapping\")[0]\n",
    "    patch_offset_mapping = patch_tokenized.pop(\"offset_mapping\")[0]\n",
    "    if consider_ques_pos:\n",
    "        clean_ques_pos = find_quesmark_pos(\n",
    "            prompt=clean_sample.prompt(),\n",
    "            tokenizer=mt.tokenizer,\n",
    "            tokenized=clean_tokenized,\n",
    "            offset_mapping=clean_offset_mapping,\n",
    "        )\n",
    "        patch_ques_pos = find_quesmark_pos(\n",
    "            prompt=patch_sample.prompt(),\n",
    "            tokenizer=mt.tokenizer,\n",
    "            tokenized=patch_tokenized,\n",
    "            offset_mapping=patch_offset_mapping,\n",
    "        )\n",
    "        map_indices[patch_ques_pos] = clean_ques_pos\n",
    "\n",
    "    random_idx = random.choice(\n",
    "        list(\n",
    "            set(list(range(len(clean_sample.options))))\n",
    "            - {\n",
    "                patch_sample.obj_idx,\n",
    "                clean_sample.obj_idx,\n",
    "                clean_sample.metadata[\"track_type_obj_idx\"],\n",
    "            }\n",
    "        )\n",
    "    )\n",
    "\n",
    "    track_tokens = {\n",
    "        \"predicate_target\": clean_sample.metadata[\"track_type_obj_token_id\"],\n",
    "        \"clean_ans\": get_first_token_id(clean_sample.obj, mt.tokenizer, prefix=\" \"),\n",
    "        \"patch_ans\": get_first_token_id(patch_sample.obj, mt.tokenizer, prefix=\" \"),\n",
    "        \"patch_position\": get_first_token_id(\n",
    "            clean_sample.options[patch_sample.obj_idx], mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "        \"random_distractor\": get_first_token_id(\n",
    "            clean_sample.options[random_idx], mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "    }\n",
    "\n",
    "    ret = {\"track_tokens\": track_tokens}\n",
    "\n",
    "    logit_location = (mt.lm_head_name, -1)\n",
    "    cache_h_from_locations = list(product(mt.layer_names, list(map_indices.keys())))\n",
    "    # patch_locations = []\n",
    "    print(cache_h_from_locations)\n",
    "\n",
    "    patch_hs = get_hs(\n",
    "        mt=mt,\n",
    "        input=patch_tokenized,\n",
    "        locations=cache_h_from_locations + [logit_location],\n",
    "        return_dict=True,\n",
    "    )\n",
    "    patch_logits = patch_hs[logit_location]\n",
    "    patch_pred, patch_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=patch_logits,\n",
    "        interested_tokens=track_tokens.values(),\n",
    "    )\n",
    "    logger.debug(f\"patch_pred={[str(pred) for pred in patch_pred]}\")\n",
    "    logger.debug(f\"patch_track={patch_track}\")\n",
    "    ret[\"patch_pred\"] = patch_pred\n",
    "    ret[\"patch_track\"] = patch_track\n",
    "\n",
    "    clean_hs = get_hs(\n",
    "        mt=mt,\n",
    "        input=clean_tokenized,\n",
    "        locations=[logit_location],\n",
    "        return_dict=True,\n",
    "    )\n",
    "    clean_logits = clean_hs[logit_location]\n",
    "    clean_pred, clean_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=clean_logits,\n",
    "        interested_tokens=track_tokens.values(),\n",
    "    )\n",
    "    logger.debug(f\"clean_pred={[str(pred) for pred in clean_pred]}\")\n",
    "    logger.debug(f\"clean_track={clean_track}\")\n",
    "    ret[\"clean_pred\"] = clean_pred\n",
    "    ret[\"clean_track\"] = clean_track\n",
    "\n",
    "    interestested_tokens = list(track_tokens.values())\n",
    "    option_tokens = [\n",
    "        get_first_token_id(name=opt, tokenizer=mt.tokenizer, prefix=\" \")\n",
    "        for opt in get_options_for_answer(clean_sample)\n",
    "    ]\n",
    "    interestested_tokens = list(set(interestested_tokens) | set(option_tokens))\n",
    "    layer_wise_patching_results = {}\n",
    "    for layer in mt.layer_names:\n",
    "        patch_spec = []\n",
    "        for patch_tok_idx in map_indices.keys():\n",
    "            patch_spec.append(\n",
    "                PatchSpec(\n",
    "                    location=(layer, map_indices[patch_tok_idx]),\n",
    "                    patch=patch_hs[(layer, patch_tok_idx)],\n",
    "                )\n",
    "            )\n",
    "\n",
    "        int_hs = get_hs(\n",
    "            mt=mt,\n",
    "            input=clean_tokenized,\n",
    "            locations=[logit_location],\n",
    "            patches=patch_spec,\n",
    "            return_dict=True,\n",
    "        )\n",
    "        int_logits = int_hs[logit_location]\n",
    "\n",
    "        int_pred, int_track = interpret_logits(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            logits=int_logits,\n",
    "            interested_tokens=interestested_tokens,\n",
    "        )\n",
    "\n",
    "        logger.debug(f\"Layer {layer}: int_pred={[str(pred) for pred in int_pred]}\")\n",
    "        layer_wise_patching_results[layer] = {\n",
    "            \"int_pred\": int_pred,\n",
    "            \"int_track\": int_track,\n",
    "        }\n",
    "\n",
    "    ret[\"layer_wise_patching_results\"] = layer_wise_patching_results\n",
    "    return ret\n",
    "\n",
    "\n",
    "patching_result = layer_wise_patching(\n",
    "    mt=mt,\n",
    "    patch_sample=patch_sample,\n",
    "    clean_sample=clean_sample,\n",
    "    map_indices={-2: -2, -1: -1},\n",
    "    consider_ques_pos=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "081ae5d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 64\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    print(f\"sample {len(validation_set)+1} / {validation_limit}\")\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        # n_distractors=N_DISTRACTORS,\n",
    "        patch_n_distractors=N_DISTRACTORS,\n",
    "        clean_n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0a0463e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "results = []\n",
    "for clean, patch in tqdm(validation_set):\n",
    "    result = layer_wise_patching(\n",
    "        mt=mt,\n",
    "        patch_sample=patch,\n",
    "        clean_sample=clean,\n",
    "        map_indices={-2: -2, -1: -1},\n",
    "        consider_ques_pos=True,\n",
    "    )\n",
    "    results.append(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05c376d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# results = [patching_result]\n",
    "\n",
    "scores = {token_type: [] for token_type in results[0][\"track_tokens\"].keys()}\n",
    "causality = []\n",
    "for result in results:\n",
    "    clean_track = result[\"clean_track\"]\n",
    "    patch_track = result[\"patch_track\"]\n",
    "\n",
    "    for token_type in scores.keys():\n",
    "        layerwise_scores = []\n",
    "        token_id = result[\"track_tokens\"][token_type]\n",
    "        for layer_idx in range(mt.n_layer):\n",
    "            score = result[\"layer_wise_patching_results\"][mt.layer_names[layer_idx]][\"int_track\"][token_id][1].logit\n",
    "            layerwise_scores.append(score)\n",
    "        scores[token_type].append(layerwise_scores)\n",
    "    \n",
    "    predicate_target = result[\"track_tokens\"][\"predicate_target\"] \n",
    "    layer_wise_causality = []\n",
    "    for layer_idx in range(mt.n_layer):\n",
    "        int_track = result[\"layer_wise_patching_results\"][mt.layer_names[layer_idx]][\"int_track\"]\n",
    "        is_causal = float(list(int_track.keys())[0] == predicate_target)\n",
    "        layer_wise_causality.append(is_causal)\n",
    "    causality.append(layer_wise_causality)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119dfaa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Create a figure with two vertically stacked subplots (shared x-axis)\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 15\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 35\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=MEDIUM_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(\n",
    "    2, 1, figsize=(12, 8), sharex=True, gridspec_kw={\"height_ratios\": [3, 1]}\n",
    ")\n",
    "colors = {\n",
    "    \"predicate_target\": \"blue\",\n",
    "    \"clean_ans\": \"green\",\n",
    "    \"patch_ans\": \"red\",\n",
    "    \"patch_position\": \"purple\",\n",
    "    \"random_distractor\": \"gray\",\n",
    "}\n",
    "\n",
    "# Plot logits on the first panel\n",
    "lines = []\n",
    "for token_type, layerwise_scores_list in scores.items():\n",
    "    mean_scores = np.mean(layerwise_scores_list, axis=0)\n",
    "    sterr_scores = np.std(layerwise_scores_list, axis=0) / np.sqrt(\n",
    "        len(layerwise_scores_list)\n",
    "    )\n",
    "    line, = ax1.plot(mean_scores, label=f\"{token_type}\", alpha=0.8, color=colors[token_type])\n",
    "    ax1.fill_between(\n",
    "        range(len(mean_scores)),\n",
    "        mean_scores - sterr_scores,\n",
    "        mean_scores + sterr_scores,\n",
    "        alpha=0.1,\n",
    "        color=colors[token_type],\n",
    "    )\n",
    "    lines.append(line)\n",
    "\n",
    "ax1.set_ylabel(\"Logit(x)\")\n",
    "ax1.set_title(f\"Residual | {mt.name.split('/')[-1]}\")\n",
    "\n",
    "# Place the legend horizontally on top of the first panel\n",
    "ax1.legend(\n",
    "    handles=lines,\n",
    "    loc=\"lower center\",\n",
    "    bbox_to_anchor=(0.5, -0.1),\n",
    "    ncol=len(scores),\n",
    "    frameon=False,\n",
    "    fontsize=\"medium\"\n",
    ")\n",
    "\n",
    "# Plot causality on the second panel\n",
    "mean_causality = np.mean(causality, axis=0)\n",
    "sterr_causality = np.std(causality, axis=0) / np.sqrt(len(causality))\n",
    "ax2.plot(\n",
    "    mean_causality,\n",
    "    label=\"causality\",\n",
    "    color=colors[\"predicate_target\"],\n",
    "    linestyle=\"--\",\n",
    "    alpha=0.9,\n",
    "    linewidth=2,\n",
    ")\n",
    "ax2.fill_between(\n",
    "    range(len(mean_causality)),\n",
    "    mean_causality - sterr_causality,\n",
    "    mean_causality + sterr_causality,\n",
    "    color=colors[\"predicate_target\"],\n",
    "    alpha=0.1,\n",
    ")\n",
    "ax2.set_xlabel(\"Layer\")\n",
    "ax2.set_ylabel(\"Causality\")\n",
    "ax2.set_ylim(-0.1, 1.1)\n",
    "ax2.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "save_dir = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"residual\")\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "plt.savefig(os.path.join(save_dir, f\"template_{prompt_template_idx}.pdf\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83cf07ce",
   "metadata": {},
   "source": [
    "## Patching to check if there is an answer flag"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a4848a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.utils import verify_correct_option\n",
    "import random\n",
    "\n",
    "\n",
    "def get_counterfactual_sample_with_answer_flag(\n",
    "    mt: ModelandTokenizer,\n",
    "    select_task: SelectOneTask,\n",
    "    clean_category: str = \"fruit\",\n",
    "    patch_category: str = \"vehicle\",\n",
    "    flag_category: str = \"animal\",\n",
    "    n_other_distractors: int = 2,\n",
    "    prompt_template_idx=2,\n",
    "    option_style: str = OPTION_STYLE,\n",
    "    filter_by_lm_prediction: bool = True,\n",
    "    retry_count: int = 0,\n",
    "):\n",
    "    distractor_categories = random.sample(\n",
    "        list(\n",
    "            set(select_task.category_wise_examples.keys())\n",
    "            - {clean_category, patch_category, flag_category}\n",
    "        ),\n",
    "        k=n_other_distractors,\n",
    "    )\n",
    "\n",
    "    obj_to_category_map = {}\n",
    "    clean_obj = random.choice(select_task.category_wise_examples[clean_category])\n",
    "    patch_obj = random.choice(select_task.category_wise_examples[patch_category])\n",
    "    flag_obj = random.choice(select_task.category_wise_examples[flag_category])\n",
    "\n",
    "    obj_to_category_map[clean_obj] = clean_category\n",
    "    obj_to_category_map[patch_obj] = patch_category\n",
    "    obj_to_category_map[flag_obj] = flag_category\n",
    "\n",
    "    logger.info(f\"{clean_obj=}, {patch_obj=}, {flag_obj=}\")\n",
    "\n",
    "    options = [clean_obj, patch_obj, flag_obj]\n",
    "\n",
    "    for category in distractor_categories:\n",
    "        obj = random.choice(select_task.category_wise_examples[category])\n",
    "        obj_to_category_map[obj] = category\n",
    "        options.append(obj)\n",
    "\n",
    "    random.shuffle(options)\n",
    "\n",
    "    random_obj = random.choice(list(set(options) - {clean_obj, patch_obj, flag_obj}))\n",
    "    logger.info(f\"{random_obj=}\")\n",
    "    logger.info(f\"{options=}\")\n",
    "\n",
    "    clean_sample = SelectionSample(\n",
    "        obj=clean_obj,\n",
    "        obj_idx=options.index(clean_obj),\n",
    "        options=options,\n",
    "        answer=clean_obj,\n",
    "        category=clean_category,\n",
    "        ans_token_id=get_first_token_id(\n",
    "            name=clean_obj, tokenizer=mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "        prompt_template=select_task.prompt_templates[prompt_template_idx],\n",
    "        default_option_style=option_style,\n",
    "    )\n",
    "\n",
    "    patch_sample = SelectionSample(\n",
    "        obj=patch_obj,\n",
    "        obj_idx=options.index(patch_obj),\n",
    "        options=options,\n",
    "        answer=patch_obj,\n",
    "        category=patch_category,\n",
    "        ans_token_id=get_first_token_id(\n",
    "            name=patch_obj, tokenizer=mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "        prompt_template=select_task.prompt_templates[prompt_template_idx],\n",
    "        default_option_style=option_style,\n",
    "    )\n",
    "\n",
    "    # get a sample where the flag obj is the answer\n",
    "    flag_sample = select_task.get_random_sample(\n",
    "        mt=mt,\n",
    "        category=flag_category,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=option_style,\n",
    "        n_distractors=len(options) - 1,\n",
    "    )\n",
    "    flag_sample.options[flag_sample.options.index(flag_sample.obj)] = flag_obj\n",
    "    flag_sample.obj = flag_obj\n",
    "    flag_sample.ans_token_id = get_first_token_id(\n",
    "        name=flag_obj, tokenizer=mt.tokenizer, prefix=\" \"\n",
    "    )\n",
    "\n",
    "    distractor_samples = {}\n",
    "    # get samples where the distractor obj is NOT the answer\n",
    "    for distractor_obj, distractor_category in obj_to_category_map.items():\n",
    "        if distractor_obj == flag_obj:\n",
    "            # No need\n",
    "            continue\n",
    "        other_category = random.choice(\n",
    "            list(\n",
    "                set(select_task.category_wise_examples.keys())\n",
    "                - {clean_category, patch_category, flag_category, distractor_category}\n",
    "            )\n",
    "        )\n",
    "        obj_idx = random.choice(range(len(options)))\n",
    "        distractor_sample = select_task.get_random_sample(\n",
    "            mt=mt,\n",
    "            category=other_category,\n",
    "            obj_idx=obj_idx,\n",
    "            prompt_template_idx=prompt_template_idx,\n",
    "            option_style=option_style,\n",
    "            n_distractors=len(options) - 1,\n",
    "            exclude_distractor_categories=[\n",
    "                clean_category,\n",
    "                patch_category,\n",
    "                flag_category,\n",
    "                distractor_category,\n",
    "            ],\n",
    "            insert_distractor=[\n",
    "                (\n",
    "                    distractor_obj,\n",
    "                    random.choice(list(set(range(len(options))) - {obj_idx})),\n",
    "                )\n",
    "            ],\n",
    "        )\n",
    "        distractor_samples[distractor_obj] = distractor_sample\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        for sample in [clean_sample, patch_sample, flag_sample]:\n",
    "            is_correct, pred, track_objs = verify_correct_option(\n",
    "                mt=mt,\n",
    "                input=sample.prompt(),\n",
    "                target=sample.ans_token_id,\n",
    "                options=sample.options,\n",
    "            )\n",
    "\n",
    "            if not is_correct:\n",
    "                logger.error(\n",
    "                    f\"\"\"Sample = {sample}\n",
    "Top prediction {track_objs[list(track_objs.keys())[0]]} does not match the object {sample.obj}[{sample.ans_token_id}, \"{mt.tokenizer.decode(sample.ans_token_id)}\"].\n",
    "Retry count: {retry_count + 1}. Retrying ...\n",
    "\"\"\"\n",
    "                )\n",
    "                return get_counterfactual_sample_with_answer_flag(\n",
    "                    mt=mt,\n",
    "                    select_task=select_task,\n",
    "                    clean_category=clean_category,\n",
    "                    patch_category=patch_category,\n",
    "                    flag_category=flag_category,\n",
    "                    prompt_template_idx=prompt_template_idx,\n",
    "                    retry_count=retry_count + 1,\n",
    "                )\n",
    "\n",
    "        # for the distractors just make sure that the distractor obj is not the answer\n",
    "        for distractor_obj, distractor_sample in distractor_samples.items():\n",
    "            pred, track_objs = predict_next_token(\n",
    "                mt=mt,\n",
    "                inputs=distractor_sample.prompt(),\n",
    "                token_of_interest=[\n",
    "                    get_first_token_id(opt, mt.tokenizer, prefix=\" \")\n",
    "                    for opt in distractor_sample.options\n",
    "                ],\n",
    "            )\n",
    "            track_objs = track_objs[0]\n",
    "            if list(track_objs.keys())[0] == get_first_token_id(\n",
    "                distractor_obj, mt.tokenizer, prefix=\" \"\n",
    "            ):\n",
    "                logger.error(\n",
    "                    f\"\"\"Sample = {distractor_sample}\n",
    "Top prediction {track_objs[list(track_objs.keys())[0]]} matches the distractor object {distractor_obj}[{distractor_sample.ans_token_id}, \"{mt.tokenizer.decode(distractor_sample.ans_token_id)}\"].\n",
    "Retry count: {retry_count + 1}. Retrying ...\n",
    "\"\"\"\n",
    "                )\n",
    "                return get_counterfactual_sample_with_answer_flag(\n",
    "                    mt=mt,\n",
    "                    select_task=select_task,\n",
    "                    clean_category=clean_category,\n",
    "                    patch_category=patch_category,\n",
    "                    flag_category=flag_category,\n",
    "                    prompt_template_idx=prompt_template_idx,\n",
    "                    retry_count=retry_count + 1,\n",
    "                )\n",
    "\n",
    "    return clean_sample, patch_sample, flag_sample, distractor_samples\n",
    "\n",
    "\n",
    "clean_sample, patch_sample, flag_sample, distractor_samples = (\n",
    "    get_counterfactual_sample_with_answer_flag(\n",
    "        mt=mt,\n",
    "        select_task=select_task,\n",
    "        clean_category=\"fruit\",\n",
    "        patch_category=\"vehicle\",\n",
    "        flag_category=\"animal\",\n",
    "        n_other_distractors=2,\n",
    "        prompt_template_idx=2,\n",
    "        option_style=OPTION_STYLE,\n",
    "        retry_count=0,\n",
    "        filter_by_lm_prediction=True,\n",
    "    )\n",
    ")\n",
    "\n",
    "print(\"clean\")\n",
    "print(clean_sample.prompt(), \">>\", mt.tokenizer.decode(clean_sample.ans_token_id))\n",
    "\n",
    "print(\"patch\")\n",
    "print(patch_sample.prompt(), \">>\", mt.tokenizer.decode(patch_sample.ans_token_id))\n",
    "\n",
    "print(\"flag\")\n",
    "print(flag_sample.prompt(), \">>\", mt.tokenizer.decode(flag_sample.ans_token_id))\n",
    "\n",
    "print(\"------------------------------\")\n",
    "for distractor_obj, distractor_sample in distractor_samples.items():\n",
    "    print(\"distractor:\", distractor_obj)\n",
    "    print(\n",
    "        distractor_sample.prompt(),\n",
    "        \">>\",\n",
    "        mt.tokenizer.decode(distractor_sample.ans_token_id),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d7b6ca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tokens import find_token_range\n",
    "\n",
    "\n",
    "def locate_with_delim(prompt, option):\n",
    "    st = prompt.index(option)\n",
    "    return prompt[st : st + len(option) + 1]\n",
    "\n",
    "\n",
    "def layer_wise_patch_with_ablating_ans(\n",
    "    mt: ModelandTokenizer,\n",
    "    clean_sample: SelectionSample,\n",
    "    patch_sample: SelectionSample,\n",
    "    flag_sample: SelectionSample,\n",
    "    option_to_sample: dict[str, SelectionSample],\n",
    "    pred_token_indices=[-2, -1],\n",
    "    ablate_ans_flag: bool = True,\n",
    "):\n",
    "    interested_tokens = {\n",
    "        \"clean_ans\": get_first_token_id(clean_sample.obj, mt.tokenizer, prefix=\" \"),\n",
    "        \"patch_ans\": get_first_token_id(patch_sample.obj, mt.tokenizer, prefix=\" \"),\n",
    "        \"flag_ans\": get_first_token_id(flag_sample.obj, mt.tokenizer, prefix=\" \"),\n",
    "        \"random_obj\": get_first_token_id(\n",
    "            random.choice(\n",
    "                list(\n",
    "                    set(clean_sample.options)\n",
    "                    - {clean_sample.obj, patch_sample.obj, flag_sample.obj}\n",
    "                )\n",
    "            ),\n",
    "            mt.tokenizer,\n",
    "            prefix=\" \",\n",
    "        ),\n",
    "    }\n",
    "    clean_tokenized = prepare_input(\n",
    "        tokenizer=mt, prompts=clean_sample.prompt(), return_offsets_mapping=True\n",
    "    )\n",
    "    clean_offset_mapping = clean_tokenized.pop(\"offset_mapping\")[0]\n",
    "\n",
    "    logit_location = (mt.lm_head_name, -1)\n",
    "    # clean_run\n",
    "    clean_logits = get_hs(\n",
    "        mt=mt,\n",
    "        input=clean_tokenized,\n",
    "        locations=[logit_location],\n",
    "        return_dict=False,\n",
    "    ).squeeze()\n",
    "    clean_pred, clean_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=clean_logits,\n",
    "        interested_tokens=interested_tokens.values(),\n",
    "    )\n",
    "    logger.debug(f\"clean_pred={[str(pred) for pred in clean_pred]}\")\n",
    "    logger.debug(f\"clean_track={clean_track}\")\n",
    "\n",
    "    # patch_run\n",
    "    patch_tokenized = prepare_input(\n",
    "        tokenizer=mt, prompts=patch_sample.prompt(), return_offsets_mapping=True\n",
    "    )\n",
    "    patch_offset_mapping = patch_tokenized.pop(\"offset_mapping\")[0]\n",
    "    pred_locations = list(product(mt.layer_names, pred_token_indices))\n",
    "    pred_hs_from_patch = get_hs(\n",
    "        mt=mt,\n",
    "        input=patch_tokenized,\n",
    "        locations=pred_locations + [logit_location],\n",
    "        return_dict=True,\n",
    "    )\n",
    "    patch_logits = pred_hs_from_patch[logit_location]\n",
    "    patch_pred, patch_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=patch_logits,\n",
    "        interested_tokens=interested_tokens.values(),\n",
    "    )\n",
    "    logger.debug(f\"patch_pred={[str(pred) for pred in patch_pred]}\")\n",
    "    logger.debug(f\"patch_track={patch_track}\")\n",
    "\n",
    "    option_patches = []\n",
    "    # get flag hs\n",
    "    for option, sample in option_to_sample.items():\n",
    "        assert option in clean_sample.options\n",
    "        assert option in sample.options\n",
    "        flag_tokenized = prepare_input(\n",
    "            prompts=sample.prompt(), tokenizer=mt, return_offsets_mapping=True\n",
    "        )\n",
    "        flag_offset_mapping = flag_tokenized.pop(\"offset_mapping\")[0]\n",
    "        \n",
    "        # get flag option token range\n",
    "        flag_opt_range = find_token_range(\n",
    "            string=sample.prompt(),\n",
    "            substring=locate_with_delim(sample.prompt(), option),\n",
    "            tokenizer=mt.tokenizer,\n",
    "            offset_mapping=flag_offset_mapping,\n",
    "        )\n",
    "        logger.debug(f'flag_opt_range={flag_opt_range}, {mt.tokenizer.decode(flag_tokenized.input_ids[0][range(*flag_opt_range)])}')\n",
    "\n",
    "        # get clean option token range\n",
    "        clean_opt_range = find_token_range(\n",
    "            string=clean_sample.prompt(),\n",
    "            substring=locate_with_delim(clean_sample.prompt(), option),\n",
    "            tokenizer=mt.tokenizer,\n",
    "            offset_mapping=clean_offset_mapping,\n",
    "        )\n",
    "        logger.debug(f'clean_opt_range={clean_opt_range}, {mt.tokenizer.decode(clean_tokenized.input_ids[0][range(*clean_opt_range)])}')\n",
    "        assert (\n",
    "            flag_opt_range[1] - flag_opt_range[0]\n",
    "            == clean_opt_range[1] - clean_opt_range[0]\n",
    "        )\n",
    "\n",
    "        flag_opt_locations = list(\n",
    "            product(mt.layer_names, list(range(*flag_opt_range)))\n",
    "        )\n",
    "        flag_hs_from_flag = get_hs(\n",
    "            mt=mt,\n",
    "            input=flag_tokenized,\n",
    "            locations=flag_opt_locations + [logit_location],\n",
    "            return_dict=True,\n",
    "        )\n",
    "        flag_pred, flag_track = interpret_logits(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            logits=flag_hs_from_flag[logit_location],\n",
    "            interested_tokens=interested_tokens.values(),\n",
    "        )\n",
    "        logger.debug(f\"flag_pred={[str(pred) for pred in flag_pred]}\")\n",
    "        logger.debug(f\"flag_track={flag_track}\")\n",
    "        for flag_tok_idx, clean_tok_idx in zip(\n",
    "            range(*flag_opt_range), range(*clean_opt_range)\n",
    "        ):\n",
    "            option_patches.extend(\n",
    "                [\n",
    "                    PatchSpec(\n",
    "                        location=(layer, clean_tok_idx),\n",
    "                        patch=flag_hs_from_flag[(layer, flag_tok_idx)],\n",
    "                    )\n",
    "                    for layer in mt.layer_names\n",
    "                ]\n",
    "            )\n",
    "\n",
    "\n",
    "    layerwise_patching_results = {}\n",
    "    for layer in mt.layer_names:\n",
    "        patch_spec = []\n",
    "        for tok_idx in pred_token_indices:\n",
    "            patch_spec.append(\n",
    "                PatchSpec(\n",
    "                    location=(layer, tok_idx),\n",
    "                    patch=pred_hs_from_patch[(layer, tok_idx)],\n",
    "                )\n",
    "            )\n",
    "\n",
    "        int_hs = get_hs(\n",
    "            mt=mt,\n",
    "            input=clean_tokenized,\n",
    "            locations=[logit_location],\n",
    "            patches=patch_spec + option_patches if ablate_ans_flag else patch_spec,\n",
    "            return_dict=True,\n",
    "        )\n",
    "        int_logits = int_hs[logit_location]\n",
    "        int_pred, int_track = interpret_logits(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            logits=int_logits,\n",
    "            interested_tokens=interested_tokens.values(),\n",
    "        )\n",
    "\n",
    "        logger.debug(f\"Layer {layer}: int_pred={[str(pred) for pred in int_pred]}\")\n",
    "\n",
    "        layerwise_patching_results[layer] = {\n",
    "            \"int_pred\": int_pred,\n",
    "            \"int_track\": int_track,\n",
    "        }\n",
    "\n",
    "    return {\n",
    "        \"track_tokens\": interested_tokens,\n",
    "        \"clean_pred\": clean_pred,\n",
    "        \"clean_track\": clean_track,\n",
    "        \"patch_pred\": patch_pred,\n",
    "        \"patch_track\": patch_track,\n",
    "        \"flag_pred\": flag_pred,\n",
    "        \"flag_track\": flag_track,\n",
    "        \"layerwise_patching_results\": layerwise_patching_results,\n",
    "    }\n",
    "\n",
    "opt_to_sample = {opt: distractor_samples[opt] for opt in distractor_samples}\n",
    "opt_to_sample[flag_sample.obj] = flag_sample\n",
    "patching_result = layer_wise_patch_with_ablating_ans(\n",
    "    mt=mt,\n",
    "    clean_sample=clean_sample,\n",
    "    patch_sample=patch_sample,\n",
    "    flag_sample=flag_sample,\n",
    "    option_to_sample=opt_to_sample,\n",
    "    pred_token_indices=[-2, -1],\n",
    "    ablate_ans_flag=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5de1d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(clean_sample.prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e25cdaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "limit = 32\n",
    "prompt_template_idx = 3\n",
    "\n",
    "test_set = []\n",
    "while len(test_set) < limit:\n",
    "    print(f\"sample {len(test_set)+1} / {limit}\")\n",
    "    clean_category = random.choice(list(select_task.category_wise_examples.keys()))\n",
    "    patch_category = random.choice(\n",
    "        list(\n",
    "            set(select_task.category_wise_examples.keys()) - {clean_category}\n",
    "        )\n",
    "    )\n",
    "    flag_category = random.choice(\n",
    "        list(\n",
    "            set(select_task.category_wise_examples.keys())\n",
    "            - {clean_category, patch_category}\n",
    "        )\n",
    "    )\n",
    "    clean, patch, flag, opt_to_sample = get_counterfactual_sample_with_answer_flag(\n",
    "        mt=mt,\n",
    "        select_task=select_task,\n",
    "        clean_category=clean_category,\n",
    "        patch_category=patch_category,\n",
    "        flag_category=flag_category,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        retry_count=0,\n",
    "    )\n",
    "    test_set.append((clean, patch, flag, opt_to_sample))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89451743",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_sample, patch_sample, flag_sample, opt_to_sample  = test_set[0]\n",
    "print(clean_sample.prompt(), \">>\", clean_sample.obj)\n",
    "print(patch_sample.prompt(), \">>\", patch_sample.obj)\n",
    "print(flag_sample.prompt(), \">>\", flag_sample.obj)\n",
    "\n",
    "print(\"------------------------------\")\n",
    "for distractor_obj, distractor_sample in opt_to_sample.items():\n",
    "    print(\n",
    "        distractor_obj,\n",
    "        distractor_sample.prompt(),\n",
    "        \">>\",\n",
    "        distractor_sample.obj,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80e9cbbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "results = []\n",
    "\n",
    "###########################\n",
    "ablate_ans_flag = True\n",
    "############################\n",
    "\n",
    "for clean, patch, flag, opt_to_sample in tqdm(test_set):\n",
    "    opt_to_sample = {opt: opt_to_sample[opt] for opt in opt_to_sample}\n",
    "    opt_to_sample[flag.obj] = flag\n",
    "    result = layer_wise_patch_with_ablating_ans(\n",
    "        mt=mt,\n",
    "        clean_sample=clean,\n",
    "        patch_sample=patch,\n",
    "        flag_sample=flag,\n",
    "        option_to_sample=opt_to_sample,\n",
    "        pred_token_indices=[-2, -1],\n",
    "        ablate_ans_flag=ablate_ans_flag,\n",
    "    )\n",
    "    results.append(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "824c0323",
   "metadata": {},
   "outputs": [],
   "source": [
    "# results = [patching_result]\n",
    "# result\n",
    "\n",
    "scores = {token_type: [] for token_type in results[0][\"track_tokens\"].keys()}\n",
    "for result in results:\n",
    "    for token_type in scores.keys():\n",
    "        layerwise_scores = []\n",
    "        token_id = result[\"track_tokens\"][token_type]\n",
    "        for layer_idx in range(mt.n_layer):\n",
    "            score = result[\"layerwise_patching_results\"][mt.layer_names[layer_idx]][\"int_track\"][token_id][1].logit\n",
    "            layerwise_scores.append(score)\n",
    "        scores[token_type].append(layerwise_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13b81806",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "colors = {\n",
    "    \"clean_ans\": \"green\",\n",
    "    \"patch_ans\": \"red\",\n",
    "    \"flag_ans\": \"purple\",\n",
    "    \"random_obj\": \"gray\",\n",
    "}\n",
    "\n",
    "plt.figure(figsize=(12, 6))\n",
    "for token_type, layerwise_scores_list in scores.items():\n",
    "    # Compute mean and std deviation across results for each layer\n",
    "    mean_scores = np.mean(layerwise_scores_list, axis=0)\n",
    "    sterr_scores = np.std(layerwise_scores_list, axis=0) / np.sqrt(\n",
    "        len(layerwise_scores_list)\n",
    "    )\n",
    "\n",
    "    plt.plot(\n",
    "        mean_scores,\n",
    "        label=f\"{token_type}\",\n",
    "        alpha=0.7,\n",
    "        color=colors.get(token_type, None),\n",
    "    )\n",
    "    plt.fill_between(\n",
    "        range(len(mean_scores)),\n",
    "        mean_scores - sterr_scores,\n",
    "        mean_scores + sterr_scores,\n",
    "        alpha=0.1,\n",
    "        color=colors.get(token_type, None),\n",
    "    )\n",
    "\n",
    "file_name = f\"residual_{prompt_template_idx}\"\n",
    "if ablate_ans_flag:\n",
    "    file_name += \"_flag_ablate\"\n",
    "file_name += \".pdf\"\n",
    "\n",
    "save_dir = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"residual\")\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Logit(x)\")\n",
    "plt.title(f\"{file_name} | {mt.name.split('/')[-1]}\")\n",
    "# plt.title(f\"Residual + Ablate Flag | {mt.name.split('/')[-1]}\")\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(save_dir, file_name), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c8c57b8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c67e1043",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a3f5418",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13266c4f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
