{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import re\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from matplotlib.transforms import Bbox\n",
    "from sklearn.metrics import (\n",
    "    ConfusionMatrixDisplay,\n",
    "    classification_report,\n",
    "    confusion_matrix,\n",
    ")\n",
    "from tqdm import tqdm\n",
    "\n",
    "from ar import ActivationReasoning, LogicConfig\n",
    "from ar.config import LogicConfig\n",
    "from ProntoQA_logic import create_rules\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Eval Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_between(label: str, value: str, width: int, min_gap: int = 1) -> str:\n",
    "    \"\"\"Return `label + spaces + value` with total length == width.\"\"\"\n",
    "    min_needed = len(label) + len(value) + min_gap\n",
    "    if width < min_needed:\n",
    "        # Not enough room: truncate the label to fit, keep value intact.\n",
    "        lab_len = max(0, width - len(value) - min_gap)\n",
    "        label = label[:lab_len]\n",
    "        gap = min_gap\n",
    "    else:\n",
    "        gap = width - len(label) - len(value)\n",
    "    return f\"{label}{' ' * gap}{value}\"\n",
    "\n",
    "\n",
    "def make_replacer(label: str, value: str):\n",
    "    def replacer(m: re.Match) -> str:\n",
    "        width = len(m.group(0))  # exact length of the matched chunk\n",
    "        return pad_between(label, value, width)\n",
    "\n",
    "    return replacer\n",
    "\n",
    "def plot_report_and_cm_with_textbox(labels, predictions, title, cm_labels, filename):\n",
    "    \"\"\"\n",
    "    Plots the classification report as a text box and a confusion matrix in a single figure.\n",
    "\n",
    "    Args:\n",
    "        labels (list): The true labels.\n",
    "        predictions (list): The predicted labels.\n",
    "        title (str): The main title for the plot.\n",
    "        cm_labels (list): Labels for the confusion matrix.\n",
    "        filename (str): The name of the PDF file to save.\n",
    "    \"\"\"\n",
    "    # Create the figure and two subplots\n",
    "    fig, (ax2, ax1) = plt.subplots(1, 2, figsize=(12, 6))\n",
    "    fig.suptitle(title, fontsize=16)\n",
    "\n",
    "    # --- Plot the Confusion Matrix on the first subplot (ax1) ---\n",
    "    cm = confusion_matrix(labels, predictions, labels=cm_labels)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=cm_labels)\n",
    "    disp.plot(ax=ax1)\n",
    "    ax1.set_title(\"Confusion Matrix\")\n",
    "\n",
    "    # --- Generate the Classification Report as a string ---\n",
    "    report_string = classification_report(\n",
    "        labels, predictions, labels=cm_labels, zero_division=0.0, digits=4\n",
    "    )\n",
    "    pattern = r\"micro avg\\s+\\d+\\.\\d+\\s+\\d+\\.\\d+\\s+\\d+\\.\\d+\"\n",
    "    report_string = re.sub(\n",
    "        pattern,\n",
    "        make_replacer(\" accuracy\", f\"{(cm[0, 0] + cm[1, 1]) / len(labels):.4f}\"),\n",
    "        report_string,\n",
    "    )\n",
    "    # --- Plot the Classification Report as a text box on the second subplot (ax2) ---\n",
    "    # To get a monospaced font, which helps with alignment\n",
    "    plt.rcParams[\"font.family\"] = \"monospace\"\n",
    "\n",
    "    # Add the text to the subplot.\n",
    "    # The first two arguments (0.05, 0.95) are the x and y coordinates of the text box.\n",
    "    # We use a box for better visual separation.\n",
    "    props = dict(boxstyle=\"round,pad=0.5\", facecolor=\"wheat\", alpha=0.5)\n",
    "    ax2.text(\n",
    "        0.05,\n",
    "        0.95,\n",
    "        report_string,\n",
    "        transform=ax2.transAxes,\n",
    "        fontsize=10,\n",
    "        verticalalignment=\"top\",\n",
    "        bbox=props,\n",
    "    )\n",
    "\n",
    "    # Hide the axes of the second subplot to show only the text\n",
    "    ax2.axis(\"off\")\n",
    "    ax2.set_title(\"Classification Report\")\n",
    "\n",
    "    # Save the figure\n",
    "    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Adjust layout to make room for suptitle\n",
    "    plt.savefig(filename, format=\"pdf\", bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "    fig.savefig(\n",
    "        filename.replace(\".pdf\", \"\") + \"_CR.pdf\",\n",
    "        format=\"pdf\",\n",
    "        bbox_inches=Bbox([[0.25, 3.1], [5.0, 4.9]]\n",
    "        ),\n",
    "    )\n",
    "\n",
    "    fig.savefig(\n",
    "        filename.replace(\".pdf\", \"\") + \"_CM.pdf\",\n",
    "        format=\"pdf\",\n",
    "        bbox_inches=Bbox([[6.1, 0.1], [11.6, 5.1]]),\n",
    "    )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval(model, test_ds, do_al:bool=True, al_steering_factor:float=0.5, al_verbose:bool=False, do_base:bool=True, num_samples:int=2000):\n",
    "    Labels = []\n",
    "    AL_output = []\n",
    "    AL_output_rules = []\n",
    "    AL_results = []\n",
    "    AL_results_rules = []\n",
    "    AL_steering_results = []\n",
    "    AL_no_steering_results = []\n",
    "    BASE_output = []\n",
    "    BASE_results = []\n",
    "    tqdm_bar = tqdm(range(num_samples), desc=\"\", unit=\"iter\")\n",
    "\n",
    "    model_hyp = {\n",
    "        \"do_sample\": False,\n",
    "        \"temperature\": None,\n",
    "        \"top_k\": None,\n",
    "        \"top_p\": None,\n",
    "        \"max_new_tokens\": 1,\n",
    "    }\n",
    "\n",
    "    for i in tqdm_bar:\n",
    "        rules = create_rules(test_ds[\"questions+queries\"][i], verbose=False)\n",
    "        model.config.reasoner_rules_checking = \"complex\"\n",
    "        model._reset_reasoner(rules)\n",
    "\n",
    "        Labels.append(test_ds[\"answers\"][i].lower())\n",
    "        \n",
    "        if do_al:\n",
    "            model.config.steering_factor = al_steering_factor\n",
    "            out_AL = model._generate(\n",
    "                prompt=test_ds[\"queries\"][i : i + 1],\n",
    "                model_hyp=model_hyp,\n",
    "                verbose=al_verbose,\n",
    "                return_activated_rules=True,\n",
    "            )\n",
    "            if al_verbose:\n",
    "                print(f\"Output: {out_AL}, GT: {test_ds['answers'][i].lower()}\")\n",
    "            AL_output.append(out_AL[0][0].lower())\n",
    "            is_correct_al = test[\"answers\"][i].lower() in out_AL[0][0].lower()\n",
    "            AL_results.append(is_correct_al)\n",
    "\n",
    "            try:\n",
    "                is_correct_rule = test_ds[\"answers\"][i].lower() in out_AL[1][0][0].lower()\n",
    "            except:\n",
    "                is_correct_rule = False\n",
    "\n",
    "            AL_output_rules.append(out_AL[1][0][0].lower() if out_AL[1][0] else None)\n",
    "            AL_results_rules.append(is_correct_rule)\n",
    "\n",
    "            if is_correct_rule:\n",
    "                AL_steering_results.append(is_correct_al)\n",
    "            else:\n",
    "                AL_no_steering_results.append(is_correct_al)\n",
    "        \n",
    "        if do_base:\n",
    "            model.config.steering_factor = 0\n",
    "            out_B = model._generate(\n",
    "                prompt=[test_ds[\"questions\"][i] + \" \" + test_ds[\"queries\"][i]],\n",
    "                model_hyp=model_hyp,\n",
    "                verbose=False,\n",
    "                return_activated_rules=True,\n",
    "            )\n",
    "            BASE_results.append(test_ds[\"answers\"][i].lower() in out_B[0][0].lower())\n",
    "            BASE_output.append(out_B[0][0].lower())\n",
    "\n",
    "        # print(\"AL: \", out_AL, test[\"answers\"][i])\n",
    "        # print(\"Base: \", out_B, test[\"answers\"][i])\n",
    "        tqdm_bar.set_postfix(\n",
    "            BASE_Acc=f\"{sum(BASE_results) / len(BASE_results) if len(BASE_results) != 0 else 0:.2f}\",\n",
    "            AL_Acc=f\"{sum(AL_results) / len(AL_results) if len(AL_results) != 0 else 0:.2f}\",\n",
    "            Rule_Acc=f\"{sum(AL_results_rules) / len(AL_results_rules) if len(AL_results_rules) != 0 else 0:.2f}\",\n",
    "            AL_ACC_S=f\"{sum(AL_steering_results) / len(AL_steering_results) if len(AL_steering_results) != 0 else 0:.2f} of {len(AL_steering_results)}\",\n",
    "            AL_ACC_NoS=f\"{sum(AL_no_steering_results) / len(AL_no_steering_results) if len(AL_no_steering_results) != 0 else 0:.2f} of {len(AL_no_steering_results)}\",\n",
    "        )\n",
    "    return (\n",
    "        Labels,\n",
    "        AL_output,\n",
    "        AL_output_rules,\n",
    "        AL_results,\n",
    "        AL_results_rules,\n",
    "        AL_steering_results,\n",
    "        AL_no_steering_results,\n",
    "        BASE_output,\n",
    "        BASE_results,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LLama3.1 8B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "concepts = [\n",
    "    \"true\",\n",
    "    \"false\",\n",
    "    \"not\",\n",
    "    \"is\",\n",
    "    \"wumpus\",\n",
    "    \"yumpus\",\n",
    "    \"zumpus\",\n",
    "    \"rompus\",\n",
    "    \"numpus\",\n",
    "    \"vumpus\",\n",
    "    \"impus\",\n",
    "    \"shumpus\",\n",
    "    \"terpus\",\n",
    "    \"sterpus\",\n",
    "    \"orpus\",\n",
    "    \"lempus\",\n",
    "    \"umpus\",\n",
    "    \"grimpus\",\n",
    "    \"gorpus\",\n",
    "    \"rimpus\",\n",
    "    \"lorpus\",\n",
    "    \"brimpus\",\n",
    "    \"max\",\n",
    "    \"orange\",\n",
    "    \"large\",\n",
    "    \"wooden\",\n",
    "    \"opaque\",\n",
    "    \"dull\",\n",
    "    \"spicy\",\n",
    "    \"hot\",\n",
    "    \"mean\",\n",
    "    \"discordant\",\n",
    "    \"moderate\",\n",
    "    \"windy\",\n",
    "    \"dumpus\",\n",
    "    \"jompus\",\n",
    "    \"fae\",\n",
    "    \"feisty\",\n",
    "    \"sweet\",\n",
    "    \"fruity\",\n",
    "    \"temperate\",\n",
    "    \"angry\",\n",
    "    \"slow\",\n",
    "    \"overcast\",\n",
    "    \"tumpus\",\n",
    "    \"stella\",\n",
    "    \"liquid\",\n",
    "    \"happy\",\n",
    "    \"bright\",\n",
    "    \"sour\",\n",
    "    \"earthy\",\n",
    "    \"melodic\",\n",
    "    \"alex\",\n",
    "    \"blue\",\n",
    "    \"small\",\n",
    "    \"nervous\",\n",
    "    \"floral\",\n",
    "    \"amenable\",\n",
    "    \"rainy\",\n",
    "    \"rex\",\n",
    "    \"red\",\n",
    "    \"cold\",\n",
    "    \"aggressive\",\n",
    "    \"muffled\",\n",
    "    \"brown\",\n",
    "    \"transparent\",\n",
    "    \"fast\",\n",
    "    \"sunny\",\n",
    "    \"sam\",\n",
    "    \"snowy\",\n",
    "    \"kind\",\n",
    "    \"loud\",\n",
    "    \"sally\",\n",
    "    \"luminous\",\n",
    "    \"bitter\",\n",
    "    \"shy\",\n",
    "    \"metallic\",\n",
    "    \"wren\",\n",
    "    \"polly\",\n",
    "]\n",
    "\n",
    "llama3_1 = {\n",
    "    \"model_name\": \"meta-llama/Meta-Llama-3.1-8B\",\n",
    "    \"sae_name\": \"EleutherAI/sae-llama-3.1-8b-64x\",\n",
    "    \"layer\": 23,\n",
    "    \"hookpoint\": \"layers.23\",\n",
    "}\n",
    "\n",
    "llama3 = {\n",
    "    \"model_name\": \"meta-llama/Meta-Llama-3-8B\",\n",
    "    \"sae_name\": \"EleutherAI/sae-llama-3-8b-32x-v2\",\n",
    "    \"layer\": 11,\n",
    "    \"hookpoint\": \"layers.11\",\n",
    "}\n",
    "\n",
    "\n",
    "experiment_dir = \"output/experiments/prontoqa/v0.9\"\n",
    "os.makedirs(experiment_dir, exist_ok=True)\n",
    "config = LogicConfig(\n",
    "    search_concept_type=\"word\",\n",
    "    search_concept_token=\"all\",\n",
    "    search_strategy=\"top_k\",\n",
    "    search_top_k_order=\"unique_only\",\n",
    "    search_top_k=10,\n",
    "    detection_top_k_output=10,\n",
    "    detection_threshold=0,\n",
    "    detection_top_k_concepts=4,\n",
    "    steering_factor=0.5,\n",
    "    steering_top_k_rule=1,\n",
    "    steering_weighting_function=\"uniform\",  # \"log_decay\",\"uniform\"\n",
    "    steering_norm=2,\n",
    "    steering_methodology=\"mean_shift\",\n",
    "    reasoner_rules_checking=\"complex\",\n",
    ")\n",
    "ar_model = ActivationReasoning(\n",
    "    rules={},\n",
    "    concepts=concepts,\n",
    "    config=config,\n",
    "    cache_dir=f\"{experiment_dir}/sae_latents\",\n",
    "    **llama3_1,\n",
    "    verbose=False,\n",
    ")\n",
    "ar_model.concepts = [\n",
    "    \"true\",\n",
    "    \"false\",\n",
    "    \"not\",\n",
    "    \"is\",\n",
    "    # \"wumpus\",\n",
    "    \"yumpus\",\n",
    "    \"zumpus\",\n",
    "    \"rompus\",\n",
    "    \"numpus\",\n",
    "    \"vumpus\",\n",
    "    \"impus\",\n",
    "    \"shumpus\",\n",
    "    # \"terpus\",\n",
    "    \"sterpus\",\n",
    "    # \"orpus\",\n",
    "    \"lempus\",\n",
    "    # \"umpus\",\n",
    "    \"grimpus\",\n",
    "    \"gorpus\",\n",
    "    # \"rimpus\",\n",
    "    \"lorpus\",\n",
    "    \"brimpus\",\n",
    "    \"max\",\n",
    "    \"orange\",\n",
    "    \"large\",\n",
    "    \"wooden\",\n",
    "    \"opaque\",\n",
    "    \"dull\",\n",
    "    \"spicy\",\n",
    "    \"hot\",\n",
    "    \"mean\",\n",
    "    \"discordant\",\n",
    "    \"moderate\",\n",
    "    \"windy\",\n",
    "    \"dumpus\",\n",
    "    \"jompus\",\n",
    "    \"fae\",\n",
    "    \"feisty\",\n",
    "    \"sweet\",\n",
    "    \"fruity\",\n",
    "    \"temperate\",\n",
    "    \"angry\",\n",
    "    \"slow\",\n",
    "    \"overcast\",\n",
    "    \"tumpus\",\n",
    "    \"stella\",\n",
    "    \"liquid\",\n",
    "    \"happy\",\n",
    "    \"bright\",\n",
    "    \"sour\",\n",
    "    \"earthy\",\n",
    "    \"melodic\",\n",
    "    \"alex\",\n",
    "    \"blue\",\n",
    "    \"small\",\n",
    "    \"nervous\",\n",
    "    \"floral\",\n",
    "    \"amenable\",\n",
    "    \"rainy\",\n",
    "    \"rex\",\n",
    "    \"red\",\n",
    "    \"cold\",\n",
    "    \"aggressive\",\n",
    "    \"muffled\",\n",
    "    \"brown\",\n",
    "    \"transparent\",\n",
    "    \"fast\",\n",
    "    \"sunny\",\n",
    "    \"sam\",\n",
    "    \"snowy\",\n",
    "    \"kind\",\n",
    "    \"loud\",\n",
    "    \"sally\",\n",
    "    \"luminous\",\n",
    "    \"bitter\",\n",
    "    \"shy\",\n",
    "    \"metallic\",\n",
    "    \"wren\",\n",
    "    \"polly\",\n",
    "]\n",
    "logic_config = LogicConfig(\n",
    "    detection_top_k_output=2,  # Maybe try 3 or 4\n",
    "    detection_top_k_concepts=4,\n",
    "    steering_factor=0.5,\n",
    "    steering_top_k_rule=1,\n",
    "    steering_weighting_function=\"uniform\",  # \"log_decay\",\"uniform\"\n",
    "    steering_norm=2,\n",
    "    steering_methodology=\"mean_shift\",\n",
    "    reasoner_rules_checking=\"complex\",\n",
    ")\n",
    "ar_model.configure(config=logic_config)\n",
    "ar_model._al_concepts.concept_dict[\"true\"] = {\n",
    "    \"indices\": [\n",
    "        225923,\n",
    "        216333,\n",
    "        134579,\n",
    "        104019,\n",
    "        88141,\n",
    "        235961,\n",
    "        244936,\n",
    "        172278,\n",
    "        208219,\n",
    "        186291,\n",
    "    ],\n",
    "    \"weights\": [\n",
    "        4.601840496063232,\n",
    "        3.9762470722198486,\n",
    "        3.118572235107422,\n",
    "        2.716679096221924,\n",
    "        2.643275499343872,\n",
    "        1.9865232706069946,\n",
    "        1.3156009912490845,\n",
    "        2.0025057792663574,\n",
    "        1.3420261144638062,\n",
    "        0.9111782908439636,\n",
    "    ],\n",
    "}\n",
    "ar_model._al_concepts.concept_dict[\"false\"] = {\n",
    "    \"indices\": [\n",
    "        172278,\n",
    "        62182,\n",
    "        106143,\n",
    "        101823,\n",
    "        244865,\n",
    "        75729,\n",
    "        168504,\n",
    "        174992,\n",
    "        172278,\n",
    "        208219,\n",
    "        # 186291\n",
    "    ],\n",
    "    \"weights\": [\n",
    "        9.81,\n",
    "        3.9630446434020996,\n",
    "        2.7334697246551514,\n",
    "        1.8760185241699219,\n",
    "        1.6498314142227173,\n",
    "        1.4435627460479736,\n",
    "        1.056015968322754,\n",
    "        0.8861362934112549,\n",
    "        1.8535783290863037,\n",
    "        1.7445926666259766,\n",
    "        # 0.9542834758758545,\n",
    "    ],\n",
    "}\n",
    "ar_model.reset_conv()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 Hop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\n",
    "    \"./data/ProntoQA-5hop_0shot_random_10k.json\",\n",
    "    \"r\",\n",
    ") as f:\n",
    "    ds = json.load(f)\n",
    "\n",
    "test = {\n",
    "    \"questions+queries\": [],\n",
    "    \"questions\": [],\n",
    "    \"answers\": [],\n",
    "    \"queries\": [],\n",
    "    \"chains\": [],\n",
    "}\n",
    "for item in list(ds.items())[-2000:]:\n",
    "    test[\"questions+queries\"].append(\n",
    "        item[1][\"test_example\"][\"question\"] + \" \" + item[1][\"test_example\"][\"query\"]\n",
    "    )\n",
    "    test[\"questions\"].append(item[1][\"test_example\"][\"question\"])\n",
    "    test[\"answers\"].append(item[1][\"test_example\"][\"answer\"])\n",
    "    test[\"queries\"].append(\n",
    "        \"Answer True or False: \"\n",
    "        + item[1][\"test_example\"][\"query\"].replace(\"True or false:\", \"\")\n",
    "        + \"\\nAnswer:\"\n",
    "    )\n",
    "    test[\"chains\"].append(item[1][\"test_example\"][\"chain_of_thought\"])\n",
    "    # break\n",
    "\n",
    "print(\"Number of True:\", sum([1 if i == \"True\" else 0 for i in test[\"answers\"]]))\n",
    "print(\"Number of False:\", sum([1 if i == \"False\" else 0 for i in test[\"answers\"]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test[\"questions+queries\"][0], test[\"answers\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_rules(test[\"questions+queries\"][0] + \" True\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path = f\"./{experiment_dir}/\"\n",
    "Labels, AL_output, AL_output_rules, AL_results, AL_results_rules, AL_steering_results, AL_no_steering_results, BASE_output, BASE_results = eval(ar_model, test, do_al=True, do_base=True, num_samples=2000)\n",
    "# AL Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [l.strip() for l in AL_output],\n",
    "    \"LLM+AL Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    save_path + \"ProntoQA_5hop_AL_reasoning_Llama31-8B.pdf\",\n",
    ")\n",
    "\n",
    "# AL Rule Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [str(l).strip() for l in AL_output_rules],\n",
    "    \"AL Detection Stats\",\n",
    "    [\"true\", \"false\", \"None\"],\n",
    "    save_path + \"ProntoQA_5hop_AL_detection_Llama31-8B.pdf\",\n",
    ")\n",
    "\n",
    "# Baseline Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [str(l).strip() for l in BASE_output],\n",
    "    \"LLM Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    save_path + \"ProntoQA_5hop_LLM_reasoning_Llama31-8B.pdf\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 Hop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\n",
    "    \"./data/ProntoQA-3hop_0shot_random_2k.json\",\n",
    "    \"r\",\n",
    ") as f:\n",
    "    ds = json.load(f)\n",
    "\n",
    "test = {\n",
    "    \"questions+queries\": [],\n",
    "    \"questions\": [],\n",
    "    \"answers\": [],\n",
    "    \"queries\": [],\n",
    "    \"chains\": [],\n",
    "}\n",
    "for item in list(ds.items())[-2000:]:\n",
    "    test[\"questions+queries\"].append(\n",
    "        item[1][\"test_example\"][\"question\"] + \" \" + item[1][\"test_example\"][\"query\"]\n",
    "    )\n",
    "    test[\"questions\"].append(item[1][\"test_example\"][\"question\"])\n",
    "    test[\"answers\"].append(item[1][\"test_example\"][\"answer\"])\n",
    "    test[\"queries\"].append(\n",
    "        \"Answer True or False: \"\n",
    "        + item[1][\"test_example\"][\"query\"].replace(\"True or false:\", \"\")\n",
    "        + \"\\nAnswer:\"\n",
    "    )\n",
    "    test[\"chains\"].append(item[1][\"test_example\"][\"chain_of_thought\"])\n",
    "    # break\n",
    "\n",
    "print(\"Number of True:\", sum([1 if i == \"True\" else 0 for i in test[\"answers\"]]))\n",
    "print(\"Number of False:\", sum([1 if i == \"False\" else 0 for i in test[\"answers\"]]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    Labels,\n",
    "    AL_output,\n",
    "    AL_output_rules,\n",
    "    AL_results,\n",
    "    AL_results_rules,\n",
    "    AL_steering_results,\n",
    "    AL_no_steering_results,\n",
    "    BASE_output,\n",
    "    BASE_results,\n",
    ") = eval(ar_model, test, do_al=True, do_base=True, num_samples=2000)\n",
    "# AL Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [l.strip() for l in AL_output],\n",
    "    \"LLM+AL Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    save_path + \"ProntoQA_3hop_AL_reasoning_Llama31-8B.pdf\",\n",
    ")\n",
    "\n",
    "# AL Rule Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [str(l).strip() for l in AL_output_rules],\n",
    "    \"AL Detection Stats\",\n",
    "    [\"true\", \"false\", \"None\"],\n",
    "    save_path + \"ProntoQA_3hop_AL_detection_Llama31-8B.pdf\",\n",
    ")\n",
    "\n",
    "# Baseline Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [str(l).strip() for l in BASE_output],\n",
    "    \"LLM Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    save_path + \"ProntoQA_3hop_LLM_reasoning_Llama31-8B.pdf\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 Hop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\n",
    "    \"./data/ProntoQA-1hop_0shot_random_2k.json\",\n",
    "    \"r\",\n",
    ") as f:\n",
    "    ds = json.load(f)\n",
    "\n",
    "test = {\n",
    "    \"questions+queries\": [],\n",
    "    \"questions\": [],\n",
    "    \"answers\": [],\n",
    "    \"queries\": [],\n",
    "    \"chains\": [],\n",
    "}\n",
    "for item in list(ds.items())[-2000:]:\n",
    "    test[\"questions+queries\"].append(\n",
    "        item[1][\"test_example\"][\"question\"] + \" \" + item[1][\"test_example\"][\"query\"]\n",
    "    )\n",
    "    test[\"questions\"].append(item[1][\"test_example\"][\"question\"])\n",
    "    test[\"answers\"].append(item[1][\"test_example\"][\"answer\"])\n",
    "    test[\"queries\"].append(\n",
    "        \"Answer True or False: \"\n",
    "        + item[1][\"test_example\"][\"query\"].replace(\"True or false:\", \"\")\n",
    "        + \"\\nAnswer:\"\n",
    "    )\n",
    "    test[\"chains\"].append(item[1][\"test_example\"][\"chain_of_thought\"])\n",
    "    # break\n",
    "\n",
    "print(\"Number of True:\", sum([1 if i == \"True\" else 0 for i in test[\"answers\"]]))\n",
    "print(\"Number of False:\", sum([1 if i == \"False\" else 0 for i in test[\"answers\"]]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(\n",
    "    Labels,\n",
    "    AL_output,\n",
    "    AL_output_rules,\n",
    "    AL_results,\n",
    "    AL_results_rules,\n",
    "    AL_steering_results,\n",
    "    AL_no_steering_results,\n",
    "    BASE_output,\n",
    "    BASE_results,\n",
    ") = eval(ar_model, test, do_al=True, do_base=True, num_samples=2000)\n",
    "# AL Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [l.strip() for l in AL_output],\n",
    "    \"LLM+AL Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    save_path + \"ProntoQA_1hop_AL_reasoning_Llama31-8B.pdf\",\n",
    ")\n",
    "\n",
    "# AL Rule Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [str(l).strip() for l in AL_output_rules],\n",
    "    \"AL Detection Stats\",\n",
    "    [\"true\", \"false\", \"None\"],\n",
    "    save_path + \"ProntoQA_1hop_AL_detection_Llama31-8B.pdf\",\n",
    ")\n",
    "\n",
    "# Baseline Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [str(l).strip() for l in BASE_output],\n",
    "    \"LLM Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    save_path + \"ProntoQA_1hop_LLM_reasoning_Llama31-8B.pdf\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LLama3.1 8B Instruct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ar_model.cleanup()\n",
    "ar_model.cleanup()\n",
    "concepts = [\n",
    "    \"true\",\n",
    "    \"false\",\n",
    "    \"not\",\n",
    "    \"is\",\n",
    "    \"wumpus\",\n",
    "    \"yumpus\",\n",
    "    \"zumpus\",\n",
    "    \"rompus\",\n",
    "    \"numpus\",\n",
    "    \"vumpus\",\n",
    "    \"impus\",\n",
    "    \"shumpus\",\n",
    "    \"terpus\",\n",
    "    \"sterpus\",\n",
    "    \"orpus\",\n",
    "    \"lempus\",\n",
    "    \"umpus\",\n",
    "    \"grimpus\",\n",
    "    \"gorpus\",\n",
    "    \"rimpus\",\n",
    "    \"lorpus\",\n",
    "    \"brimpus\",\n",
    "    \"max\",\n",
    "    \"orange\",\n",
    "    \"large\",\n",
    "    \"wooden\",\n",
    "    \"opaque\",\n",
    "    \"dull\",\n",
    "    \"spicy\",\n",
    "    \"hot\",\n",
    "    \"mean\",\n",
    "    \"discordant\",\n",
    "    \"moderate\",\n",
    "    \"windy\",\n",
    "    \"dumpus\",\n",
    "    \"jompus\",\n",
    "    \"fae\",\n",
    "    \"feisty\",\n",
    "    \"sweet\",\n",
    "    \"fruity\",\n",
    "    \"temperate\",\n",
    "    \"angry\",\n",
    "    \"slow\",\n",
    "    \"overcast\",\n",
    "    \"tumpus\",\n",
    "    \"stella\",\n",
    "    \"liquid\",\n",
    "    \"happy\",\n",
    "    \"bright\",\n",
    "    \"sour\",\n",
    "    \"earthy\",\n",
    "    \"melodic\",\n",
    "    \"alex\",\n",
    "    \"blue\",\n",
    "    \"small\",\n",
    "    \"nervous\",\n",
    "    \"floral\",\n",
    "    \"amenable\",\n",
    "    \"rainy\",\n",
    "    \"rex\",\n",
    "    \"red\",\n",
    "    \"cold\",\n",
    "    \"aggressive\",\n",
    "    \"muffled\",\n",
    "    \"brown\",\n",
    "    \"transparent\",\n",
    "    \"fast\",\n",
    "    \"sunny\",\n",
    "    \"sam\",\n",
    "    \"snowy\",\n",
    "    \"kind\",\n",
    "    \"loud\",\n",
    "    \"sally\",\n",
    "    \"luminous\",\n",
    "    \"bitter\",\n",
    "    \"shy\",\n",
    "    \"metallic\",\n",
    "    \"wren\",\n",
    "    \"polly\",\n",
    "]\n",
    "\n",
    "llama3_1 = {\n",
    "    \"model_name\": \"meta-llama/Llama-3.1-8B-Instruct\",\n",
    "    \"sae_name\": \"EleutherAI/sae-llama-3.1-8b-64x\",\n",
    "    \"layer\": 23,\n",
    "    \"hookpoint\": \"layers.23\",\n",
    "}\n",
    "\n",
    "\n",
    "experiment_dir = \"output/experiments/prontoqa/v0.9\"\n",
    "os.makedirs(experiment_dir, exist_ok=True)\n",
    "config = LogicConfig(\n",
    "    search_concept_type=\"word\",\n",
    "    search_concept_token=\"all\",\n",
    "    search_strategy=\"top_k\",\n",
    "    search_top_k_order=\"unique_only\",\n",
    "    search_top_k=10,\n",
    "    detection_top_k_output=10,\n",
    "    detection_threshold=0,\n",
    "    detection_top_k_concepts=4,\n",
    "    steering_factor=0.5,\n",
    "    steering_top_k_rule=1,\n",
    "    steering_weighting_function=\"uniform\",  # \"log_decay\",\"uniform\"\n",
    "    steering_norm=2,\n",
    "    steering_methodology=\"mean_shift\",\n",
    "    reasoner_rules_checking=\"complex\",\n",
    ")\n",
    "ar_model = ActivationReasoning(\n",
    "    rules={},\n",
    "    concepts=concepts,\n",
    "    config=config,\n",
    "    cache_dir=f\"{experiment_dir}/sae_latents\",\n",
    "    **llama3_1,\n",
    "    verbose=False,\n",
    ")\n",
    "ar_model.concepts = [\n",
    "    \"true\",\n",
    "    \"false\",\n",
    "    \"not\",\n",
    "    \"is\",\n",
    "    # \"wumpus\",\n",
    "    \"yumpus\",\n",
    "    \"zumpus\",\n",
    "    \"rompus\",\n",
    "    \"numpus\",\n",
    "    \"vumpus\",\n",
    "    \"impus\",\n",
    "    \"shumpus\",\n",
    "    # \"terpus\",\n",
    "    \"sterpus\",\n",
    "    # \"orpus\",\n",
    "    \"lempus\",\n",
    "    # \"umpus\",\n",
    "    \"grimpus\",\n",
    "    \"gorpus\",\n",
    "    # \"rimpus\",\n",
    "    \"lorpus\",\n",
    "    \"brimpus\",\n",
    "    \"max\",\n",
    "    \"orange\",\n",
    "    \"large\",\n",
    "    \"wooden\",\n",
    "    \"opaque\",\n",
    "    \"dull\",\n",
    "    \"spicy\",\n",
    "    \"hot\",\n",
    "    \"mean\",\n",
    "    \"discordant\",\n",
    "    \"moderate\",\n",
    "    \"windy\",\n",
    "    \"dumpus\",\n",
    "    \"jompus\",\n",
    "    \"fae\",\n",
    "    \"feisty\",\n",
    "    \"sweet\",\n",
    "    \"fruity\",\n",
    "    \"temperate\",\n",
    "    \"angry\",\n",
    "    \"slow\",\n",
    "    \"overcast\",\n",
    "    \"tumpus\",\n",
    "    \"stella\",\n",
    "    \"liquid\",\n",
    "    \"happy\",\n",
    "    \"bright\",\n",
    "    \"sour\",\n",
    "    \"earthy\",\n",
    "    \"melodic\",\n",
    "    \"alex\",\n",
    "    \"blue\",\n",
    "    \"small\",\n",
    "    \"nervous\",\n",
    "    \"floral\",\n",
    "    \"amenable\",\n",
    "    \"rainy\",\n",
    "    \"rex\",\n",
    "    \"red\",\n",
    "    \"cold\",\n",
    "    \"aggressive\",\n",
    "    \"muffled\",\n",
    "    \"brown\",\n",
    "    \"transparent\",\n",
    "    \"fast\",\n",
    "    \"sunny\",\n",
    "    \"sam\",\n",
    "    \"snowy\",\n",
    "    \"kind\",\n",
    "    \"loud\",\n",
    "    \"sally\",\n",
    "    \"luminous\",\n",
    "    \"bitter\",\n",
    "    \"shy\",\n",
    "    \"metallic\",\n",
    "    \"wren\",\n",
    "    \"polly\",\n",
    "]\n",
    "# Detection and steering settings for the evaluation runs; applied through configure() below.\n",
    "logic_config = LogicConfig(\n",
    "    detection_top_k_output=3,  \n",
    "    detection_top_k_concepts=2,\n",
    "    steering_factor=0.5,\n",
    "    steering_top_k_rule=1,\n",
    "    steering_weighting_function=\"uniform\",  # \"log_decay\",\"uniform\"\n",
    "    steering_norm=2,\n",
    "    steering_methodology=\"mean_shift\",\n",
    "    reasoner_rules_checking=\"complex\",\n",
    ")\n",
    "ar_model.configure(config=logic_config)\n",
    "# Overwrite the stored latent indices/weights of the \"true\" concept by hand.\n",
    "# NOTE(review): this reaches into the private `_al_concepts` attribute - confirm there is\n",
    "# no public setter that should be used instead.\n",
    "ar_model._al_concepts.concept_dict[\"true\"] = {\n",
    "    \"indices\": [\n",
    "        225923,\n",
    "        216333,\n",
    "        134579,\n",
    "        104019,\n",
    "        88141,\n",
    "        235961,\n",
    "        244936,\n",
    "        172278,\n",
    "        208219,\n",
    "        186291,\n",
    "    ],\n",
    "    \"weights\": [\n",
    "        4.601840496063232,\n",
    "        3.9762470722198486,\n",
    "        3.118572235107422,\n",
    "        2.716679096221924,\n",
    "        2.643275499343872,\n",
    "        1.9865232706069946,\n",
    "        1.3156009912490845,\n",
    "        2.0025057792663574,\n",
    "        1.3420261144638062,\n",
    "        0.9111782908439636,\n",
    "    ],\n",
    "}\n",
    "# Same manual override for the \"false\" concept.\n",
    "# NOTE(review): index 172278 occurs twice in this list, and 172278 / 208219 are also\n",
    "# listed under \"true\" above - confirm the overlap is intended.\n",
    "ar_model._al_concepts.concept_dict[\"false\"] = {\n",
    "    \"indices\": [\n",
    "        172278,\n",
    "        62182,\n",
    "        106143,\n",
    "        101823,\n",
    "        244865,\n",
    "        75729,\n",
    "        168504,\n",
    "        174992,\n",
    "        172278,\n",
    "        208219,\n",
    "        # 186291\n",
    "    ],\n",
    "    # NOTE(review): 9.81 has far fewer digits than the other weights; presumably set by\n",
    "    # hand - confirm.\n",
    "    \"weights\": [\n",
    "        9.81,\n",
    "        3.9630446434020996,\n",
    "        2.7334697246551514,\n",
    "        1.8760185241699219,\n",
    "        1.6498314142227173,\n",
    "        1.4435627460479736,\n",
    "        1.056015968322754,\n",
    "        0.8861362934112549,\n",
    "        1.8535783290863037,\n",
    "        1.7445926666259766,\n",
    "        # 0.9542834758758545,\n",
    "    ],\n",
    "}\n",
    "# Called after the overrides; presumably clears conversation state - TODO confirm.\n",
    "ar_model.reset_conv()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 Hop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 5-hop ProntoQA file and build the evaluation split.\n",
    "# JSON is UTF-8 by specification, so the encoding is passed explicitly instead of\n",
    "# relying on the platform default.\n",
    "with open(\n",
    "    \"./data/ProntoQA-5hop_0shot_random_10k.json\",\n",
    "    \"r\",\n",
    "    encoding=\"utf-8\",\n",
    ") as f:\n",
    "    ds = json.load(f)\n",
    "\n",
    "# Parallel lists: position i of every list describes the same example.\n",
    "test = {\n",
    "    \"questions+queries\": [],\n",
    "    \"questions\": [],\n",
    "    \"answers\": [],\n",
    "    \"queries\": [],\n",
    "    \"chains\": [],\n",
    "}\n",
    "# The last (at most) 2000 entries, in file order, form the test split.\n",
    "for entry in list(ds.values())[-2000:]:\n",
    "    example = entry[\"test_example\"]\n",
    "    test[\"questions+queries\"].append(example[\"question\"] + \" \" + example[\"query\"])\n",
    "    test[\"questions\"].append(example[\"question\"])\n",
    "    test[\"answers\"].append(example[\"answer\"])\n",
    "    # Swap the dataset's own \"True or false:\" prefix for the prompt wording used here.\n",
    "    test[\"queries\"].append(\n",
    "        \"Answer True or False: \"\n",
    "        + example[\"query\"].replace(\"True or false:\", \"\")\n",
    "        + \"\\nAnswer:\"\n",
    "    )\n",
    "    test[\"chains\"].append(example[\"chain_of_thought\"])\n",
    "\n",
    "# Class balance of the split (a bool sums as 0/1).\n",
    "print(\"Number of True:\", sum(answer == \"True\" for answer in test[\"answers\"]))\n",
    "print(\"Number of False:\", sum(answer == \"False\" for answer in test[\"answers\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the AL-steered model and the plain baseline on the current `test` split, then\n",
    "# write one report/confusion-matrix PDF per prediction source.\n",
    "# `eval` is the notebook-level evaluation function defined outside this cell\n",
    "# (it shadows the Python builtin).\n",
    "save_path = f\"./{experiment_dir}/\"\n",
    "(\n",
    "    Labels,\n",
    "    AL_output,\n",
    "    AL_output_rules,\n",
    "    AL_results,\n",
    "    AL_results_rules,\n",
    "    AL_steering_results,\n",
    "    AL_no_steering_results,\n",
    "    BASE_output,\n",
    "    BASE_results,\n",
    ") = eval(\n",
    "    ar_model,\n",
    "    test,\n",
    "    do_al=True,\n",
    "    al_steering_factor=0.7,\n",
    "    do_base=True,\n",
    "    num_samples=2000,\n",
    ")\n",
    "\n",
    "# AL Output\n",
    "al_predictions = [text.strip() for text in AL_output]\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    al_predictions,\n",
    "    \"LLM+AL Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    f\"{save_path}ProntoQA_5hop_AL_reasoning_Llama31-8B-Instruct.pdf\",\n",
    ")\n",
    "\n",
    "# AL Rule Output (entries are stringified first, so a missing detection becomes \"None\")\n",
    "rule_predictions = [str(rule_out).strip() for rule_out in AL_output_rules]\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    rule_predictions,\n",
    "    \"AL Detection Stats\",\n",
    "    [\"true\", \"false\", \"None\"],\n",
    "    f\"{save_path}ProntoQA_5hop_AL_detection_Llama31-8B-Instruct.pdf\",\n",
    ")\n",
    "\n",
    "# Baseline Output\n",
    "base_predictions = [str(base_out).strip() for base_out in BASE_output]\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    base_predictions,\n",
    "    \"LLM Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    f\"{save_path}ProntoQA_5hop_LLM_reasoning_Llama31-8B-Instruct.pdf\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 & 1 Hop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the 3-hop split and then the 1-hop split with one shared pipeline.\n",
    "# `save_path` comes from the 5-hop cell above; `eval` is the notebook-level evaluation\n",
    "# function defined outside this cell (it shadows the Python builtin).\n",
    "for hop in (3, 1):\n",
    "    # JSON is UTF-8 by specification, so do not rely on the platform default encoding.\n",
    "    with open(\n",
    "        f\"./data/ProntoQA-{hop}hop_0shot_random_2k.json\",\n",
    "        \"r\",\n",
    "        encoding=\"utf-8\",\n",
    "    ) as f:\n",
    "        ds = json.load(f)\n",
    "\n",
    "    # Parallel lists: position i of every list describes the same example.\n",
    "    test = {\n",
    "        \"questions+queries\": [],\n",
    "        \"questions\": [],\n",
    "        \"answers\": [],\n",
    "        \"queries\": [],\n",
    "        \"chains\": [],\n",
    "    }\n",
    "    # Keep the last (at most) 2000 entries, in file order.\n",
    "    for entry in list(ds.values())[-2000:]:\n",
    "        example = entry[\"test_example\"]\n",
    "        test[\"questions+queries\"].append(example[\"question\"] + \" \" + example[\"query\"])\n",
    "        test[\"questions\"].append(example[\"question\"])\n",
    "        test[\"answers\"].append(example[\"answer\"])\n",
    "        # Swap the dataset's own \"True or false:\" prefix for the prompt wording used here.\n",
    "        test[\"queries\"].append(\n",
    "            \"Answer True or False: \"\n",
    "            + example[\"query\"].replace(\"True or false:\", \"\")\n",
    "            + \"\\nAnswer:\"\n",
    "        )\n",
    "        test[\"chains\"].append(example[\"chain_of_thought\"])\n",
    "\n",
    "    # Class balance of the split (a bool sums as 0/1).\n",
    "    print(\"Number of True:\", sum(answer == \"True\" for answer in test[\"answers\"]))\n",
    "    print(\"Number of False:\", sum(answer == \"False\" for answer in test[\"answers\"]))\n",
    "\n",
    "    (\n",
    "        Labels,\n",
    "        AL_output,\n",
    "        AL_output_rules,\n",
    "        AL_results,\n",
    "        AL_results_rules,\n",
    "        AL_steering_results,\n",
    "        AL_no_steering_results,\n",
    "        BASE_output,\n",
    "        BASE_results,\n",
    "    ) = eval(\n",
    "        ar_model, test, do_al=True, do_base=True, num_samples=2000, al_steering_factor=0.7\n",
    "    )\n",
    "\n",
    "    # AL Output\n",
    "    plot_report_and_cm_with_textbox(\n",
    "        Labels,\n",
    "        [out.strip() for out in AL_output],\n",
    "        \"LLM+AL Reasoning Stats\",\n",
    "        [\"true\", \"false\"],\n",
    "        save_path + f\"ProntoQA_{hop}hop_AL_reasoning_Llama31-8B-Instruct.pdf\",\n",
    "    )\n",
    "\n",
    "    # AL Rule Output\n",
    "    plot_report_and_cm_with_textbox(\n",
    "        Labels,\n",
    "        [str(out).strip() for out in AL_output_rules],\n",
    "        \"AL Detection Stats\",\n",
    "        [\"true\", \"false\", \"None\"],\n",
    "        save_path + f\"ProntoQA_{hop}hop_AL_detection_Llama31-8B-Instruct.pdf\",\n",
    "    )\n",
    "\n",
    "    # Baseline Output\n",
    "    plot_report_and_cm_with_textbox(\n",
    "        Labels,\n",
    "        [str(out).strip() for out in BASE_output],\n",
    "        \"LLM Reasoning Stats\",\n",
    "        [\"true\", \"false\"],\n",
    "        save_path + f\"ProntoQA_{hop}hop_LLM_reasoning_Llama31-8B-Instruct.pdf\",\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gemma 2 9B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably releases the previously loaded model before the Gemma one is built - TODO confirm.\n",
    "# NOTE(review): cleanup() is invoked twice back-to-back; the second call looks redundant -\n",
    "# confirm and drop it if cleanup() fully releases everything in one call.\n",
    "ar_model.cleanup()\n",
    "ar_model.cleanup()\n",
    "# Concept words passed to ActivationReasoning(concepts=...) further down in this cell.\n",
    "# NOTE(review): \"terpus\", \"orpus\", \"umpus\" and \"rimpus\" are substrings of other entries\n",
    "# (\"sterpus\", \"lorpus\", \"wumpus\", \"grimpus\"); presumably sub-word variants - confirm they\n",
    "# are wanted as separate concepts.\n",
    "concepts = [\n",
    "    \"true\",\n",
    "    \"false\",\n",
    "    \"not\",\n",
    "    \"is\",\n",
    "    \"wumpus\",\n",
    "    \"yumpus\",\n",
    "    \"zumpus\",\n",
    "    \"rompus\",\n",
    "    \"numpus\",\n",
    "    \"vumpus\",\n",
    "    \"impus\",\n",
    "    \"shumpus\",\n",
    "    \"terpus\",\n",
    "    \"sterpus\",\n",
    "    \"orpus\",\n",
    "    \"lempus\",\n",
    "    \"umpus\",\n",
    "    \"grimpus\",\n",
    "    \"gorpus\",\n",
    "    \"rimpus\",\n",
    "    \"lorpus\",\n",
    "    \"brimpus\",\n",
    "    \"max\",\n",
    "    \"orange\",\n",
    "    \"large\",\n",
    "    \"wooden\",\n",
    "    \"opaque\",\n",
    "    \"dull\",\n",
    "    \"spicy\",\n",
    "    \"hot\",\n",
    "    \"mean\",\n",
    "    \"discordant\",\n",
    "    \"moderate\",\n",
    "    \"windy\",\n",
    "    \"dumpus\",\n",
    "    \"jompus\",\n",
    "    \"fae\",\n",
    "    \"feisty\",\n",
    "    \"sweet\",\n",
    "    \"fruity\",\n",
    "    \"temperate\",\n",
    "    \"angry\",\n",
    "    \"slow\",\n",
    "    \"overcast\",\n",
    "    \"tumpus\",\n",
    "    \"stella\",\n",
    "    \"liquid\",\n",
    "    \"happy\",\n",
    "    \"bright\",\n",
    "    \"sour\",\n",
    "    \"earthy\",\n",
    "    \"melodic\",\n",
    "    \"alex\",\n",
    "    \"blue\",\n",
    "    \"small\",\n",
    "    \"nervous\",\n",
    "    \"floral\",\n",
    "    \"amenable\",\n",
    "    \"rainy\",\n",
    "    \"rex\",\n",
    "    \"red\",\n",
    "    \"cold\",\n",
    "    \"aggressive\",\n",
    "    \"muffled\",\n",
    "    \"brown\",\n",
    "    \"transparent\",\n",
    "    \"fast\",\n",
    "    \"sunny\",\n",
    "    \"sam\",\n",
    "    \"snowy\",\n",
    "    \"kind\",\n",
    "    \"loud\",\n",
    "    \"sally\",\n",
    "    \"luminous\",\n",
    "    \"bitter\",\n",
    "    \"shy\",\n",
    "    \"metallic\",\n",
    "    \"wren\",\n",
    "    \"polly\",\n",
    "]\n",
    "\n",
    "# Model/SAE presets; each one bundles the keyword arguments that select a model, an SAE\n",
    "# release, a layer and a hookpoint. Only `gemma_1` is unpacked into ActivationReasoning\n",
    "# in this cell.\n",
    "gemma_2_9b = dict(\n",
    "    model_name=\"google/gemma-2-9b\",\n",
    "    sae_name=\"gemma-scope-9b-pt-res\",\n",
    "    layer=30,\n",
    "    hookpoint=\"layer_30/width_131k/average_l0_170\",\n",
    ")\n",
    "\n",
    "gemma_1 = dict(\n",
    "    model_name=\"google/gemma-2-9b\",\n",
    "    sae_name=\"gemma-scope-9b-pt-res-canonical\",\n",
    "    hookpoint=\"layer_20/width_131k/canonical\",\n",
    "    layer=20,\n",
    ")\n",
    "\n",
    "gemma_2 = dict(\n",
    "    model_name=\"google/gemma-2-2b\",\n",
    "    sae_name=\"gemma-scope-2b-pt-res-canonical\",\n",
    "    hookpoint=\"layer_20/width_16k/canonical\",\n",
    "    layer=20,\n",
    ")\n",
    "\n",
    "gemma_3 = dict(\n",
    "    model_name=\"google/gemma-2-2b\",\n",
    "    sae_name=\"gemma-scope-2b-pt-res\",\n",
    "    hookpoint=\"layer_12/width_262k/average_l0_121\",\n",
    "    layer=12,\n",
    ")\n",
    "\n",
    "# Root directory for this experiment; the SAE latent cache below and the result PDFs\n",
    "# written by later cells are placed underneath it.\n",
    "experiment_dir = \"output/experiments/prontoqa/v0.8\"\n",
    "os.makedirs(experiment_dir, exist_ok=True)\n",
    "# Configuration handed to the model at construction time. The commented-out keyword\n",
    "# arguments appear to be an earlier alternative configuration kept for reference.\n",
    "config = LogicConfig(\n",
    "    search_concept_type=\"word\",\n",
    "    search_concept_token=\"all\",\n",
    "    search_strategy=\"top_k\",\n",
    "    search_top_k_order=\"unique_only\",\n",
    "    search_top_k=10,\n",
    "    detection_top_k_output=10,\n",
    "    detection_top_k_concepts=10,\n",
    "    # detection_threshold=0,\n",
    "    detection_threshold=7,\n",
    "    steering_top_k_rule=10,\n",
    "    # search_concept_type=\"word\",\n",
    "    # search_concept_token=\"all\",\n",
    "    # search_strategy=\"top_k\",\n",
    "    # search_top_k_order=\"unique_first\",\n",
    "    # search_top_k=10,\n",
    "    # # detection_top_k_output=10,\n",
    "    # # detection_threshold=0,\n",
    "    # # detection_top_k_concepts=4,\n",
    "    # steering_factor=0.5,\n",
    "    # steering_top_k_rule=1,\n",
    "    # steering_weighting_function=\"uniform\",  # \"log_decay\",\"uniform\"\n",
    "    # steering_norm=2,\n",
    "    # steering_methodology=\"mean_shift\",\n",
    "    # reasoner_rules_checking=\"complex\",\n",
    ")\n",
    "# Build the model with an empty rule set; `**gemma_1` supplies model_name, sae_name,\n",
    "# hookpoint and layer (google/gemma-2-9b, layer 20).\n",
    "ar_model = ActivationReasoning(\n",
    "    rules={},\n",
    "    concepts=concepts,\n",
    "    config=config,\n",
    "    cache_dir=f\"{experiment_dir}/sae_latents\",\n",
    "    **gemma_1,\n",
    "    verbose=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 5-hop ProntoQA file once and derive the train and test splits from it.\n",
    "# JSON is UTF-8 by specification, so the encoding is passed explicitly instead of\n",
    "# relying on the platform default.\n",
    "with open(\n",
    "    \"./data/ProntoQA-5hop_0shot_random_10k.json\",\n",
    "    \"r\",\n",
    "    encoding=\"utf-8\",\n",
    ") as f:\n",
    "    ds = json.load(f)\n",
    "\n",
    "entries = list(ds.values())\n",
    "\n",
    "# Parallel lists: position i of every list describes the same example.\n",
    "train = {\n",
    "    \"questions+queries+answers\": [],\n",
    "    \"queries+answers\": [],\n",
    "    \"questions+queries\": [],\n",
    "    \"questions\": [],\n",
    "    \"answers\": [],\n",
    "    \"queries\": [],\n",
    "    \"chains\": [],\n",
    "}\n",
    "\n",
    "# Train split: the first (at most) 8000 entries, in file order.\n",
    "for entry in entries[:8000]:\n",
    "    example = entry[\"test_example\"]\n",
    "    # Prompt shared by every \"Answer True or False\" variant: the dataset's own\n",
    "    # \"True or false:\" prefix is removed so it is not repeated after the new prefix.\n",
    "    prompt = \"Answer True or False: \" + example[\"query\"].replace(\"True or false:\", \"\")\n",
    "\n",
    "    train[\"questions+queries\"].append(example[\"question\"] + \" \" + example[\"query\"])\n",
    "    train[\"questions+queries+answers\"].append(\n",
    "        example[\"question\"] + \" \" + prompt + \"\\nAnswer: \" + example[\"answer\"]\n",
    "    )\n",
    "    train[\"questions\"].append(example[\"question\"])\n",
    "    train[\"answers\"].append(example[\"answer\"])\n",
    "    # Uses the same stripped prompt as the other variants and as test[\"queries\"];\n",
    "    # without the strip the text would read \"Answer True or False: True or false: ...\".\n",
    "    train[\"queries\"].append(prompt + \"\\nAnswer:\")\n",
    "    train[\"chains\"].append(example[\"chain_of_thought\"])\n",
    "    train[\"queries+answers\"].append(prompt + \"\\nAnswer: \" + example[\"answer\"])\n",
    "\n",
    "test = {\n",
    "    \"questions+queries\": [],\n",
    "    \"questions\": [],\n",
    "    \"answers\": [],\n",
    "    \"queries\": [],\n",
    "    \"chains\": [],\n",
    "}\n",
    "# Test split: the last (at most) 2000 entries, in file order.\n",
    "for entry in entries[-2000:]:\n",
    "    example = entry[\"test_example\"]\n",
    "    test[\"questions+queries\"].append(example[\"question\"] + \" \" + example[\"query\"])\n",
    "    test[\"questions\"].append(example[\"question\"])\n",
    "    test[\"answers\"].append(example[\"answer\"])\n",
    "    test[\"queries\"].append(\n",
    "        \"Answer True or False: \"\n",
    "        + example[\"query\"].replace(\"True or false:\", \"\")\n",
    "        + \"\\nAnswer:\"\n",
    "    )\n",
    "    test[\"chains\"].append(example[\"chain_of_thought\"])\n",
    "\n",
    "# Class balance of the test split (a bool sums as 0/1).\n",
    "print(\"Number of True:\", sum(answer == \"True\" for answer in test[\"answers\"]))\n",
    "print(\"Number of False:\", sum(answer == \"False\" for answer in test[\"answers\"]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ar_model.search(\n",
    "#     train[\"questions+queries+answers\"][:500], reset_cache=True, batch_size=16\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reassign the model's concept list. Compared with the `concepts` list used at\n",
    "# construction, every entry ending in \"pus\" is commented out; all other entries are kept.\n",
    "ar_model.concepts = [\n",
    "    \"true\",\n",
    "    \"false\",\n",
    "    \"not\",\n",
    "    \"is\",\n",
    "    # \"wumpus\",\n",
    "    # \"yumpus\",\n",
    "    # \"zumpus\",\n",
    "    # \"rompus\",\n",
    "    # \"numpus\",\n",
    "    # \"vumpus\",\n",
    "    # \"impus\",\n",
    "    # \"shumpus\",\n",
    "    # # \"terpus\",\n",
    "    # \"sterpus\",\n",
    "    # # \"orpus\",\n",
    "    # \"lempus\",\n",
    "    # # \"umpus\",\n",
    "    # \"grimpus\",\n",
    "    # \"gorpus\",\n",
    "    # # \"rimpus\",\n",
    "    # \"lorpus\",\n",
    "    # \"brimpus\",\n",
    "    \"max\",\n",
    "    \"orange\",\n",
    "    \"large\",\n",
    "    \"wooden\",\n",
    "    \"opaque\",\n",
    "    \"dull\",\n",
    "    \"spicy\",\n",
    "    \"hot\",\n",
    "    \"mean\",\n",
    "    \"discordant\",\n",
    "    \"moderate\",\n",
    "    \"windy\",\n",
    "    # \"dumpus\",\n",
    "    # \"jompus\",\n",
    "    \"fae\",\n",
    "    \"feisty\",\n",
    "    \"sweet\",\n",
    "    \"fruity\",\n",
    "    \"temperate\",\n",
    "    \"angry\",\n",
    "    \"slow\",\n",
    "    \"overcast\",\n",
    "    # \"tumpus\",\n",
    "    \"stella\",\n",
    "    \"liquid\",\n",
    "    \"happy\",\n",
    "    \"bright\",\n",
    "    \"sour\",\n",
    "    \"earthy\",\n",
    "    \"melodic\",\n",
    "    \"alex\",\n",
    "    \"blue\",\n",
    "    \"small\",\n",
    "    \"nervous\",\n",
    "    \"floral\",\n",
    "    \"amenable\",\n",
    "    \"rainy\",\n",
    "    \"rex\",\n",
    "    \"red\",\n",
    "    \"cold\",\n",
    "    \"aggressive\",\n",
    "    \"muffled\",\n",
    "    \"brown\",\n",
    "    \"transparent\",\n",
    "    \"fast\",\n",
    "    \"sunny\",\n",
    "    \"sam\",\n",
    "    \"snowy\",\n",
    "    \"kind\",\n",
    "    \"loud\",\n",
    "    \"sally\",\n",
    "    \"luminous\",\n",
    "    \"bitter\",\n",
    "    \"shy\",\n",
    "    \"metallic\",\n",
    "    \"wren\",\n",
    "    \"polly\",\n",
    "]\n",
    "# Detection and steering settings for the Gemma evaluation runs; applied through\n",
    "# configure() below.\n",
    "# NOTE(review): the search_* options set at construction are not repeated here apart from\n",
    "# search_top_k_order - confirm how configure() treats the omitted fields.\n",
    "logic_config = LogicConfig(\n",
    "    detection_top_k_output=7,  # Maybe try 3 or 4\n",
    "    detection_top_k_concepts=2,\n",
    "    detection_threshold=14,\n",
    "    steering_factor=0.5,\n",
    "    steering_top_k_rule=1,\n",
    "    steering_weighting_function=\"uniform\",  # \"log_decay\",\"uniform\"\n",
    "    steering_norm=2,\n",
    "    steering_methodology=\"mean_shift\",\n",
    "    reasoner_rules_checking=\"complex\",\n",
    "    search_top_k_order=\"unique_only\",\n",
    ")\n",
    "ar_model.configure(config=logic_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overwrite the stored latents of \"true\" and \"false\" by hand with a single latent each:\n",
    "# only the first index/weight is non-zero, the remaining nine slots are zeros.\n",
    "# NOTE(review): the zero entries presumably pad the lists to a fixed length of 10 - confirm\n",
    "# that index 0 with weight 0 is treated as padding and not as SAE latent 0.\n",
    "# NOTE(review): this writes to the private `_al_concepts` attribute.\n",
    "ar_model._al_concepts.concept_dict[\"true\"] = {\n",
    "    \"indices\": [90687, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
    "    \"weights\": [34.66799545288086, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
    "}\n",
    "ar_model._al_concepts.concept_dict[\"false\"] = {\n",
    "    \"indices\": [88978, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
    "    \"weights\": [29.999292373657227, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
    "}\n",
    "# Called after the overrides; presumably clears conversation state - TODO confirm.\n",
    "ar_model.reset_conv()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5 Hop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the AL-steered model and the plain baseline on the 5-hop `test` split and save one\n",
    "# report/confusion-matrix PDF per prediction source.\n",
    "# `eval` is the notebook-level evaluation function defined outside this cell\n",
    "# (it shadows the Python builtin).\n",
    "save_path = f\"./{experiment_dir}/\"\n",
    "(\n",
    "    Labels,\n",
    "    AL_output,\n",
    "    AL_output_rules,\n",
    "    AL_results,\n",
    "    AL_results_rules,\n",
    "    AL_steering_results,\n",
    "    AL_no_steering_results,\n",
    "    BASE_output,\n",
    "    BASE_results,\n",
    ") = eval(\n",
    "    ar_model,\n",
    "    test,\n",
    "    do_al=True,\n",
    "    do_base=True,\n",
    "    num_samples=2000,\n",
    "    al_verbose=False,\n",
    "    al_steering_factor=0.4,\n",
    ")\n",
    "# `save_path` already ends with \"/\", so the file names below carry no leading slash\n",
    "# (the previous \"/ProntoQA_...\" form produced a doubled separator in the path).\n",
    "\n",
    "# AL Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [out.strip() for out in AL_output],\n",
    "    \"LLM+AL Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    save_path + \"ProntoQA_5hop_AL_reasoning_Gemma2-9B.pdf\",\n",
    ")\n",
    "\n",
    "# AL Rule Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [str(out).strip() for out in AL_output_rules],\n",
    "    \"AL Detection Stats\",\n",
    "    [\"true\", \"false\", \"None\"],\n",
    "    save_path + \"ProntoQA_5hop_AL_detection_Gemma2-9B.pdf\",\n",
    ")\n",
    "\n",
    "# Baseline Output\n",
    "plot_report_and_cm_with_textbox(\n",
    "    Labels,\n",
    "    [str(out).strip() for out in BASE_output],\n",
    "    \"LLM Reasoning Stats\",\n",
    "    [\"true\", \"false\"],\n",
    "    save_path + \"ProntoQA_5hop_LLM_reasoning_Gemma2-9B.pdf\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 & 1 Hop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the 3-hop split and then the 1-hop split with one shared pipeline.\n",
    "# `save_path` comes from the 5-hop cell above; `eval` is the notebook-level evaluation\n",
    "# function defined outside this cell (it shadows the Python builtin).\n",
    "for hop in (3, 1):\n",
    "    # JSON is UTF-8 by specification, so do not rely on the platform default encoding.\n",
    "    with open(\n",
    "        f\"./data/ProntoQA-{hop}hop_0shot_random_2k.json\",\n",
    "        \"r\",\n",
    "        encoding=\"utf-8\",\n",
    "    ) as f:\n",
    "        ds = json.load(f)\n",
    "\n",
    "    # Parallel lists: position i of every list describes the same example.\n",
    "    test = {\n",
    "        \"questions+queries\": [],\n",
    "        \"questions\": [],\n",
    "        \"answers\": [],\n",
    "        \"queries\": [],\n",
    "        \"chains\": [],\n",
    "    }\n",
    "    # Keep the last (at most) 2000 entries, in file order.\n",
    "    for entry in list(ds.values())[-2000:]:\n",
    "        example = entry[\"test_example\"]\n",
    "        test[\"questions+queries\"].append(example[\"question\"] + \" \" + example[\"query\"])\n",
    "        test[\"questions\"].append(example[\"question\"])\n",
    "        test[\"answers\"].append(example[\"answer\"])\n",
    "        # Swap the dataset's own \"True or false:\" prefix for the prompt wording used here.\n",
    "        test[\"queries\"].append(\n",
    "            \"Answer True or False: \"\n",
    "            + example[\"query\"].replace(\"True or false:\", \"\")\n",
    "            + \"\\nAnswer:\"\n",
    "        )\n",
    "        test[\"chains\"].append(example[\"chain_of_thought\"])\n",
    "\n",
    "    # Class balance of the split (a bool sums as 0/1).\n",
    "    print(\"Number of True:\", sum(answer == \"True\" for answer in test[\"answers\"]))\n",
    "    print(\"Number of False:\", sum(answer == \"False\" for answer in test[\"answers\"]))\n",
    "\n",
    "    (\n",
    "        Labels,\n",
    "        AL_output,\n",
    "        AL_output_rules,\n",
    "        AL_results,\n",
    "        AL_results_rules,\n",
    "        AL_steering_results,\n",
    "        AL_no_steering_results,\n",
    "        BASE_output,\n",
    "        BASE_results,\n",
    "    ) = eval(\n",
    "        ar_model,\n",
    "        test,\n",
    "        do_al=True,\n",
    "        do_base=True,\n",
    "        num_samples=2000,\n",
    "        al_steering_factor=0.4,\n",
    "    )\n",
    "\n",
    "    # AL Output\n",
    "    plot_report_and_cm_with_textbox(\n",
    "        Labels,\n",
    "        [out.strip() for out in AL_output],\n",
    "        \"LLM+AL Reasoning Stats\",\n",
    "        [\"true\", \"false\"],\n",
    "        save_path + f\"ProntoQA_{hop}hop_AL_reasoning_Gemma2-9B.pdf\",\n",
    "    )\n",
    "\n",
    "    # AL Rule Output\n",
    "    plot_report_and_cm_with_textbox(\n",
    "        Labels,\n",
    "        [str(out).strip() for out in AL_output_rules],\n",
    "        \"AL Detection Stats\",\n",
    "        [\"true\", \"false\", \"None\"],\n",
    "        save_path + f\"ProntoQA_{hop}hop_AL_detection_Gemma2-9B.pdf\",\n",
    "    )\n",
    "\n",
    "    # Baseline Output\n",
    "    plot_report_and_cm_with_textbox(\n",
    "        Labels,\n",
    "        [str(out).strip() for out in BASE_output],\n",
    "        \"LLM Reasoning Stats\",\n",
    "        [\"true\", \"false\"],\n",
    "        save_path + f\"ProntoQA_{hop}hop_LLM_reasoning_Gemma2-9B.pdf\",\n",
    "    )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
