{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ar import ActivationReasoning, LogicConfig\n",
    "import pandas as pd\n",
    "from countries import tri_color_countries_subset, prompt_r2c_prefix, prompt_r2c_CoT_prefix\n",
    "import json\n",
    "from ar.utils import load_implicit_train_data, load_train_data\n",
    "\n",
    "def bold_orange(text):\n",
    "    return \"\\033[1;33m\" + str(text) + \"\\033[0m\"\n",
    "\n",
    "implicit_data = load_implicit_train_data()\n",
    "data = load_train_data()\n",
    "\n",
    "test_trains_ex, test_labels_ex = (\n",
    "    [train[0] for train in data[\"test\"]],\n",
    "    [train[1] for train in data[\"test\"]],\n",
    ")\n",
    "test_trains_im, test_labels_im = (\n",
    "    [train[0] for train in implicit_data[\"test\"]],\n",
    "    [train[1] for train in implicit_data[\"test\"]],\n",
    ")\n",
    "\n",
    "\n",
    "implicit_concept_dict = {\n",
    "    'red': ['like a ruby', 'like a tomato', 'like a stop sign', 'like a cherry', 'like a strawberry'],\n",
    "    'yellow': ['like a banana', 'like a sunflower', 'like a lemon'],\n",
    "    'white': ['like fresh snow',],\n",
    "    'orange': ['like a tangerine'],\n",
    "}\n",
    "explicit_concept_dict = {\n",
    "    \"red\": [\"red\"],\n",
    "    \"yellow\": [\"yellow\"],\n",
    "    \"white\": [\"white\"],\n",
    "    \"orange\": [\"orange\"],\n",
    "    \"blue\": [\"blue\"],\n",
    "    \"green\": [\"green\"],\n",
    "    \"gold\": [\"gold\"],\n",
    "    \"black\": [\"black\"],\n",
    "}\n",
    "\n",
    "colors = set(color for color_tuple in tri_color_countries_subset.keys() for color in color_tuple)\n",
    "countries = list(tri_color_countries_subset.values())\n",
    "\n",
    "\n",
    "experiment_dir = \"output/experiments/explicit_concepts_gemma/v3.0\"\n",
    "# experiment_dir = \"output/experiments/explicit_concepts_llama/v3.0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_config = {\n",
    "#     \"model_name\": \"meta-llama/Meta-Llama-3.1-8B\",\n",
    "#     \"sae_name\": \"EleutherAI/sae-llama-3.1-8b-64x\",\n",
    "#     \"layer\": 23,\n",
    "#     \"hookpoint\": \"layers.23\",\n",
    "# }\n",
    "model_config = {\n",
    "    \"model_name\": \"google/gemma-2-9b\",\n",
    "    \"sae_name\": \"gemma-scope-9b-pt-res-canonical\",\n",
    "    \"hookpoint\": \"layer_20/width_131k/canonical\",\n",
    "    \"layer\": 20,\n",
    "}\n",
    "# model_config = {\n",
    "#     \"model_name\": \"google/gemma-2-9b-it\",\n",
    "#     \"sae_name\": \"gemma-scope-9b-pt-res-canonical\",\n",
    "#     \"hookpoint\": \"layer_20/width_131k/canonical\",\n",
    "#     \"layer\": 20,\n",
    "# }\n",
    "\n",
    "\n",
    "rules = tri_color_countries_subset\n",
    "search_config = LogicConfig(\n",
    "    search_concept_type=\"word\",\n",
    "    search_concept_token=\"all\",\n",
    "    search_strategy=\"top_k\",\n",
    "    search_top_k_order=\"unique_first\",\n",
    "    search_top_k=10,\n",
    ")\n",
    "\n",
    "ar_model = ActivationReasoning(\n",
    "    rules=rules,\n",
    "    **model_config,\n",
    "    cache_dir=f\"{experiment_dir}/sae_latents\",\n",
    "    config=search_config,\n",
    ")\n",
    "\n",
    "t_d = [q for q, l in data['train']]\n",
    "ar_model.search(inputs=t_d, reset_cache=False, batch_size=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## AL Configs\n",
    "if model_config[\"model_name\"] == \"meta-llama/Meta-Llama-3.1-8B\":\n",
    "    print(\"Loading Llama3.1 8B configs\")\n",
    "    logic_config_al_im = LogicConfig(  # llama implicit\n",
    "        steering_factor=0.43,\n",
    "        steering_top_k_rule=10,\n",
    "        detection_top_k_output=50,\n",
    "        detection_top_k_concepts=3,\n",
    "        detection_threshold=1.45,\n",
    "        search_top_k_order=\"unique_first\",\n",
    "        search_top_k=10,\n",
    "    )\n",
    "    logic_config_al_ex = LogicConfig(  # llama explicit\n",
    "        steering_factor=0.4,\n",
    "        steering_top_k_rule=10,\n",
    "        detection_top_k_output=50,\n",
    "        detection_top_k_concepts=3,\n",
    "        detection_threshold=1.45,\n",
    "        search_top_k_order=\"unique_first\",\n",
    "        search_top_k=10,\n",
    "    )\n",
    "elif model_config[\"model_name\"] == \"google/gemma-2-9b\":\n",
    "    print(\"Loading Gemma2 9B configs\")\n",
    "    logic_config_al_im = LogicConfig(  # gemma\n",
    "        steering_factor=0.75,\n",
    "        steering_top_k_rule=10,\n",
    "        detection_top_k_output=50,\n",
    "        detection_top_k_concepts=2,\n",
    "        search_top_k_order=\"unique_only\",\n",
    "    )\n",
    "    logic_config_al_ex = LogicConfig(  # gemma\n",
    "        steering_factor=0.8,\n",
    "        steering_top_k_rule=10,\n",
    "        detection_top_k_output=50,\n",
    "        detection_top_k_concepts=2,\n",
    "        search_top_k_order=\"unique_only\",\n",
    "    )\n",
    "elif model_config[\"model_name\"] == \"google/gemma-2-9b-it\":\n",
    "    print(\"Loading Gemma2 9B configs\")\n",
    "    logic_config_al_im = LogicConfig(  # gemma\n",
    "        steering_factor=0.75,\n",
    "        steering_top_k_rule=10,\n",
    "        detection_top_k_output=50,\n",
    "        detection_top_k_concepts=2,\n",
    "        search_top_k_order=\"unique_only\",\n",
    "    )\n",
    "    logic_config_al_ex = LogicConfig(  # gemma\n",
    "        steering_factor=0.8,\n",
    "        steering_top_k_rule=10,\n",
    "        detection_top_k_output=50,\n",
    "        detection_top_k_concepts=2,\n",
    "        search_top_k_order=\"unique_only\",\n",
    "    )\n",
    "else:\n",
    "    print(\"No AL config loaded\")\n",
    "\n",
    "## Baseline Configs\n",
    "logic_config_no_steering = LogicConfig(\n",
    "    steering_factor=0,\n",
    ")\n",
    "logic_config_sae = LogicConfig(\n",
    "    detection_top_k_output=1, \n",
    "    detection_top_k_concepts=1,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(\n",
    "    ar_model: ActivationReasoning, test_trains, logic_config, model_hyp, batch_size=20, model_conf:dict={},\n",
    "):\n",
    "    # Llama detection config\n",
    "    \n",
    "    if model_conf[\"model_name\"] == \"meta-llama/Meta-Llama-3.1-8B\":\n",
    "        color_concept_dict = {'black': {'indices': [26080, 257901, 191610, 91171, 0, 0, 0, 0, 0, 0], 'weights': [1.3942278623580933, 1.1774739027023315, 0.5122178196907043, 0.30215781927108765, 0, 0, 0, 0, 0, 0]}, 'blue': {'indices': [251272, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [3.810839891433716, 0, 0, 0, 0, 0, 0, 0, 0, 0]}, 'gold': {'indices': [37152, 24536, 1681, 0, 0, 0, 0, 0, 0, 0], 'weights': [0.5879451632499695, 0.27811482548713684, 0.2714183032512665, 0, 0, 0, 0, 0, 0, 0]}, 'green': {'indices': [38115, 166587, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [2.5125248432159424, 1.5216304063796997, 0, 0, 0, 0, 0, 0, 0, 0]}, 'orange': {'indices': [250617, 91937, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [0.6042734980583191, 0.39289236068725586, 0, 0, 0, 0, 0, 0, 0, 0]}, 'red': {'indices': [213660, 92980, 96689, 0, 0, 0, 0, 0, 0, 0], 'weights': [9.03711223602295, 1.7927311658859253, 1.3311803340911865, 0, 0, 0, 0, 0, 0, 0]}, 'white': {'indices': [114310, 16107, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [6.623536586761475, 2.470682382583618, 0, 0, 0, 0, 0, 0, 0, 0]}, 'yellow': {'indices': [237080, 178076, 26092, 0, 0, 0, 0, 0, 0, 0], 'weights': [1.6243911981582642, 1.6102770566940308, 0.7696337103843689, 0, 0, 0, 0, 0, 0, 0]}}\n",
    "    elif model_conf[\"model_name\"] == \"google/gemma-2-9b\":\n",
    "        color_concept_dict = {'yellow': {'indices': [55891, 39669, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [10.290057182312012, 4.201287746429443, 0,0,0,0, 0.0, 0.0, 0.0, 0.0, 0]}}\n",
    "    else:\n",
    "        color_concept_dict = {}\n",
    "        \n",
    "    ar_model.configure(config=logic_config)\n",
    "    for color in color_concept_dict:\n",
    "        ar_model._al_concepts.concept_dict[color] = color_concept_dict[color]\n",
    "        \n",
    "    al_outs, al_meta = ar_model.generate(\n",
    "        test_trains,\n",
    "        model_hyp=model_hyp,\n",
    "        verbose=False,\n",
    "        return_meta_data=True,\n",
    "        batch_size=batch_size,\n",
    "    )\n",
    "    return al_outs, al_meta\n",
    "\n",
    "def detect(\n",
    "    ar_model: ActivationReasoning, test_trains, logic_config, batch_size=10, model_conf:dict={},\n",
    "):\n",
    "    \n",
    "    if model_conf[\"model_name\"] == \"meta-llama/Meta-Llama-3.1-8B\":\n",
    "        color_concept_dict = {'black': {'indices': [26080, 257901, 191610, 91171, 0, 0, 0, 0, 0, 0], 'weights': [1.3942278623580933, 1.1774739027023315, 0.5122178196907043, 0.30215781927108765, 0, 0, 0, 0, 0, 0]}, 'blue': {'indices': [251272, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [3.810839891433716, 0, 0, 0, 0, 0, 0, 0, 0, 0]}, 'gold': {'indices': [37152, 24536, 1681, 0, 0, 0, 0, 0, 0, 0], 'weights': [0.5879451632499695, 0.27811482548713684, 0.2714183032512665, 0, 0, 0, 0, 0, 0, 0]}, 'green': {'indices': [38115, 166587, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [2.5125248432159424, 1.5216304063796997, 0, 0, 0, 0, 0, 0, 0, 0]}, 'orange': {'indices': [250617, 91937, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [0.6042734980583191, 0.39289236068725586, 0, 0, 0, 0, 0, 0, 0, 0]}, 'red': {'indices': [213660, 92980, 96689, 0, 0, 0, 0, 0, 0, 0], 'weights': [9.03711223602295, 1.7927311658859253, 1.3311803340911865, 0, 0, 0, 0, 0, 0, 0]}, 'white': {'indices': [114310, 16107, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [6.623536586761475, 2.470682382583618, 0, 0, 0, 0, 0, 0, 0, 0]}, 'yellow': {'indices': [237080, 178076, 26092, 0, 0, 0, 0, 0, 0, 0], 'weights': [1.6243911981582642, 1.6102770566940308, 0.7696337103843689, 0, 0, 0, 0, 0, 0, 0]}}\n",
    "    elif model_conf[\"model_name\"] == \"google/gemma-2-9b\":\n",
    "        color_concept_dict = {'yellow': {'indices': [55891, 39669, 0, 0, 0, 0, 0, 0, 0, 0], 'weights': [10.290057182312012, 4.201287746429443, 0,0,0,0, 0.0, 0.0, 0.0, 0.0, 0]}}\n",
    "    else:\n",
    "        color_concept_dict = {}\n",
    "    \n",
    "    ar_model.configure(config=logic_config)\n",
    "    for color in color_concept_dict:\n",
    "        ar_model._al_concepts.concept_dict[color] = color_concept_dict[color]\n",
    "\n",
    "    al_outs = ar_model.detect(\n",
    "        test_trains,\n",
    "        verbose=False,\n",
    "        batch_size=batch_size,\n",
    "    )\n",
    "    return al_outs\n",
    "\n",
    "\n",
    "def reason_accs(test_labels, default_output, al_output, al_rules):\n",
    "    countries = list(set(test_labels))\n",
    "    output_accuracy = pd.DataFrame(index=countries, columns=[\"Default\", \"AL\", \"Rules\"])\n",
    "\n",
    "    for country in countries:\n",
    "        country_indices = [i for i, label in enumerate(test_labels) if label == country]\n",
    "\n",
    "        # Check if country appears in model output\n",
    "        default_correct = sum(country in default_output[i] for i in country_indices)\n",
    "        al_correct = sum(country in al_output[i] for i in country_indices)\n",
    "        rules_correct = sum(\n",
    "            country in al_rules[i] and len(al_rules[i]) == 1 for i in country_indices\n",
    "        )\n",
    "        # Calculate accuracy\n",
    "        default_acc = default_correct / len(country_indices) * 100\n",
    "        al_acc = al_correct / len(country_indices) * 100\n",
    "        rules_acc = rules_correct / len(country_indices) * 100\n",
    "\n",
    "        output_accuracy.loc[country] = [\n",
    "            f\"{default_acc:.2f}%\",\n",
    "            f\"{al_acc:.2f}%\",\n",
    "            f\"{rules_acc:.2f}%\",\n",
    "        ]\n",
    "\n",
    "    # Calculate overall accuracy\n",
    "    overall_default = (\n",
    "        sum(label in default_output[i] for i, label in enumerate(test_labels))\n",
    "        / len(test_labels)\n",
    "        * 100\n",
    "    )\n",
    "    overall_al = (\n",
    "        sum(label in al_output[i] for i, label in enumerate(test_labels))\n",
    "        / len(test_labels)\n",
    "        * 100\n",
    "    )\n",
    "\n",
    "    overall_rules = (\n",
    "        sum(\n",
    "            label in al_rules[i] and len(al_rules[i]) == 1\n",
    "            for i, label in enumerate(test_labels)\n",
    "        )\n",
    "        / len(test_labels)\n",
    "        * 100\n",
    "    )\n",
    "    output_accuracy.loc[\"Overall\"] = [\n",
    "        f\"{overall_default:.2f}%\",\n",
    "        f\"{overall_al:.2f}%\",\n",
    "        f\"{overall_rules:.2f}%\",\n",
    "    ]\n",
    "\n",
    "    return output_accuracy\n",
    "\n",
    "\n",
    "def detection_accs(test_trains, sae_detection, al_detection, concept_dict, tokenizer):\n",
    "    # Create reverse mapping from implicit to original concept\n",
    "    reverse_mapping = {}\n",
    "    for original, implicit_list in concept_dict.items():\n",
    "        for implicit in implicit_list:\n",
    "            reverse_mapping[implicit] = original\n",
    "\n",
    "    # Create a DataFrame for detection accuracy\n",
    "    concept_list = [\n",
    "        concept for concepts in concept_dict.values() for concept in concepts\n",
    "    ]\n",
    "\n",
    "    # Create position-aware detection accuracy calculation\n",
    "    detection_accuracy = pd.DataFrame(index=concept_list, columns=[\"SAE\", \"AL\"])\n",
    "\n",
    "    for concept in concept_list:\n",
    "        sample_indices = [\n",
    "            i for i, sample in enumerate(test_trains) if concept in sample\n",
    "        ]\n",
    "        if not sample_indices:\n",
    "            detection_accuracy.loc[concept] = [0, 0]\n",
    "            continue\n",
    "\n",
    "        original_concept = reverse_mapping[concept]\n",
    "        # Track correct detections (accounting for position)\n",
    "        sae_correct = 0\n",
    "        al_correct = 0\n",
    "\n",
    "        for idx in sample_indices:\n",
    "            # Tokenize the input text\n",
    "            input_text = test_trains[idx]\n",
    "            tokenized = tokenizer(\n",
    "                input_text, return_offsets_mapping=True, padding_side=\"right\"\n",
    "            )\n",
    "\n",
    "            # Find where the implicit concept appears in the input text\n",
    "            start_pos = input_text.find(concept)\n",
    "            if start_pos == -1:\n",
    "                continue\n",
    "\n",
    "            end_pos = start_pos + len(concept)\n",
    "\n",
    "            # Find which token positions correspond to this text span\n",
    "            concept_token_indices = []\n",
    "            for token_idx, (token_start, token_end) in enumerate(\n",
    "                tokenized.offset_mapping\n",
    "            ):\n",
    "                # Check for overlap between token span and concept span\n",
    "                if token_end > start_pos and token_start < end_pos:\n",
    "                    concept_token_indices.append(token_idx)\n",
    "\n",
    "            \n",
    "            # Check if the concept was detected at any of these positions\n",
    "            for detection in sae_detection[idx][\"concepts\"]:\n",
    "                detected_concept = detection\n",
    "                detected_positions = [\n",
    "                    pos\n",
    "                    for c, pos in zip(\n",
    "                        sae_detection[idx][\"concepts\"], sae_detection[idx][\"position\"]\n",
    "                    )\n",
    "                    if c == detected_concept\n",
    "                ]\n",
    "\n",
    "                if detected_concept == original_concept and any(\n",
    "                    pos in concept_token_indices for pos in detected_positions\n",
    "                ):\n",
    "                    sae_correct += 1\n",
    "                    break\n",
    "\n",
    "            for detection in al_detection[idx][\"concepts\"]:\n",
    "                detected_concept = detection\n",
    "                detected_positions = [\n",
    "                    pos\n",
    "                    for c, pos in zip(\n",
    "                        al_detection[idx][\"concepts\"], al_detection[idx][\"position\"]\n",
    "                    )\n",
    "                    if c == detected_concept\n",
    "                ]\n",
    "                # print(f\"AL detected concept '{detected_concept}' at positions {detected_positions}\")\n",
    "                if detected_concept == original_concept and any(\n",
    "                    pos in concept_token_indices for pos in detected_positions\n",
    "                ):\n",
    "                    al_correct += 1\n",
    "                    break\n",
    "\n",
    "        # Calculate accuracy with position constraints\n",
    "        sae_acc = sae_correct / len(sample_indices) * 100 if sample_indices else 0\n",
    "        al_acc = al_correct / len(sample_indices) * 100 if sample_indices else 0\n",
    "\n",
    "        detection_accuracy.loc[concept] = [f\"{sae_acc:.2f}%\", f\"{al_acc:.2f}%\"]\n",
    "        # break\n",
    "\n",
    "    al_values = detection_accuracy[\"AL\"].str.rstrip(\"%\").astype(float)\n",
    "    default_values = detection_accuracy[\"SAE\"].str.rstrip(\"%\").astype(float)\n",
    "\n",
    "    # Calculate the mean\n",
    "    mean_al = al_values.mean()\n",
    "    mean_d = default_values.mean()\n",
    "    return detection_accuracy, mean_d, mean_al\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Eval"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reasoning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_hyp = {'do_sample':False, 'temperature':None, 'top_k':None, 'top_p':None, \"max_new_tokens\":5}\n",
    "\n",
    "print(bold_orange(\"Explicit Concepts AL Generation:\"))\n",
    "# al_outs, al_meta = ar_model.generate(test_trains_ex, model_hyp=model_hyp, logic_config=logic_config_al, verbose=False, return_meta_data=True, batch_size=20)\n",
    "al_outs, al_meta = generate(\n",
    "    ar_model=ar_model,\n",
    "    test_trains=test_trains_ex,\n",
    "    model_hyp=model_hyp,\n",
    "    logic_config=logic_config_al_ex,\n",
    "    batch_size=20,\n",
    "    model_conf=model_config,\n",
    ")\n",
    "with open(f\"{experiment_dir}/al_outs_ex.json\", 'w') as f:\n",
    "    json.dump(al_outs, f)\n",
    "al_rules = [meta[\"rules\"][0] for meta in al_meta]\n",
    "with open(f\"{experiment_dir}/al_rules_ex.json\", \"w\") as f:\n",
    "    json.dump(al_rules, f)\n",
    "    \n",
    "print(bold_orange(\"Implicit Concepts AL Generation:\"))\n",
    "al_outs,  al_meta = generate(\n",
    "    ar_model=ar_model,\n",
    "    test_trains=test_trains_im,\n",
    "    model_hyp=model_hyp,\n",
    "    logic_config=logic_config_al_im,\n",
    "    batch_size=20,\n",
    "    model_conf=model_config\n",
    ")\n",
    "al_rules = [meta[\"rules\"][0] for meta in al_meta]\n",
    "with open(f\"{experiment_dir}/al_rules_im.json\", \"w\") as f:\n",
    "    json.dump(al_rules, f)\n",
    "with open(f\"{experiment_dir}/al_outs_im.json\", \"w\") as f:\n",
    "    json.dump(al_outs, f)\n",
    "\n",
    "\n",
    "print(bold_orange(\"Explicit Concepts Default Generation:\"))\n",
    "default_test_trains_ex = [\n",
    "    prompt_r2c_prefix\n",
    "    + q.replace(\n",
    "        \"Trains are painted by their country’s flag of origin.\\nYour goal is to identify the correct country based on the train’s distinctive color coding.\\n\",\n",
    "        \"\",\n",
    "    )\n",
    "    for q in test_trains_ex\n",
    "]\n",
    "default_outs = ar_model.generate(\n",
    "    default_test_trains_ex,\n",
    "    model_hyp=model_hyp,\n",
    "    logic_config=logic_config_no_steering,\n",
    "    verbose=False, batch_size=20\n",
    ")\n",
    "with open(f\"{experiment_dir}/default_outs_ex.json\", \"w\") as f:\n",
    "    json.dump(default_outs, f)\n",
    "\n",
    "print(bold_orange(\"Implicit Concepts Default Generation:\"))\n",
    "default_test_trains_im = [\n",
    "    prompt_r2c_prefix\n",
    "    + q.replace(\n",
    "        \"Trains are painted by their country’s flag of origin.\\nYour goal is to identify the correct country based on the train’s distinctive color coding.\\n\", \"\"\n",
    "    )\n",
    "    for q in test_trains_im\n",
    "]\n",
    "default_outs = ar_model.generate(\n",
    "    default_test_trains_im,\n",
    "    model_hyp=model_hyp,\n",
    "    logic_config=logic_config_no_steering,\n",
    "    verbose=False, batch_size=20\n",
    ")\n",
    "with open(f\"{experiment_dir}/default_outs_im.json\", \"w\") as f:\n",
    "    json.dump(default_outs, f)\n",
    "\n",
    "# default it\n",
    "print(bold_orange(\"Explicit Concepts Default IT Generation:\"))\n",
    "default_test_trains_ex = [\n",
    "    prompt_r2c_prefix\n",
    "    + q.replace(\n",
    "        \"Trains are painted by their country’s flag of origin.\\nYour goal is to identify the correct country based on the train’s distinctive color coding.\\n\",\n",
    "        \"\",\n",
    "    )\n",
    "    for q in test_trains_ex\n",
    "]\n",
    "default_it_outs = ar_model.generate(\n",
    "    default_test_trains_ex,\n",
    "    model_hyp=model_hyp,\n",
    "    logic_config=logic_config_no_steering,\n",
    "    verbose=False, batch_size=20\n",
    ")\n",
    "with open(f\"{experiment_dir}/default_it_outs_ex.json\", \"w\") as f:\n",
    "    json.dump(default_it_outs, f)\n",
    "\n",
    "print(bold_orange(\"Implicit Concepts Default IT Generation:\"))\n",
    "default_test_trains_im = [\n",
    "    prompt_r2c_prefix\n",
    "    + q.replace(\n",
    "        \"Trains are painted by their country’s flag of origin.\\nYour goal is to identify the correct country based on the train’s distinctive color coding.\\n\", \"\"\n",
    "    )\n",
    "    for q in test_trains_im\n",
    "]\n",
    "default_it_outs = ar_model.generate(\n",
    "    default_test_trains_im,\n",
    "    model_hyp=model_hyp,\n",
    "    logic_config=logic_config_no_steering,\n",
    "    verbose=False,\n",
    "    batch_size=20,\n",
    ")\n",
    "with open(f\"{experiment_dir}/default_it_outs_im.json\", \"w\") as f:\n",
    "    json.dump(default_it_outs, f)\n",
    "\n",
    "# CoT \n",
    "model_hyp = {\n",
    "    \"do_sample\": False,\n",
    "    \"temperature\": None,\n",
    "    \"top_k\": None,\n",
    "    \"top_p\": None,\n",
    "    \"max_new_tokens\": 100,\n",
    "}\n",
    "print(bold_orange(\"Explicit Concepts Loud CoT Generation:\"))\n",
    "default_test_trains_ex = [\n",
    "    prompt_r2c_CoT_prefix\n",
    "    + q.replace(\n",
    "        \"Trains are painted by their country’s flag of origin.\\nYour goal is to identify the correct country based on the train’s distinctive color coding.\\n\",\n",
    "        \"\",\n",
    "    ).replace(\n",
    "        \"Due to the train's distinctive color coding, I believe it originates from the country of\",\n",
    "        \"\",\n",
    "    )\n",
    "    for q in test_trains_ex\n",
    "]\n",
    "CoT_outs = ar_model.generate(\n",
    "    default_test_trains_ex,\n",
    "    model_hyp=model_hyp,\n",
    "    logic_config=logic_config_no_steering,\n",
    "    verbose=False,\n",
    "    batch_size=20,\n",
    ")\n",
    "with open(f\"{experiment_dir}/CoT_outs_ex.json\", \"w\") as f:\n",
    "    json.dump(CoT_outs, f)\n",
    "\n",
    "print(bold_orange(\"Implicit Concepts Loud CoT Generation:\"))\n",
    "default_test_trains_im = [\n",
    "    prompt_r2c_CoT_prefix\n",
    "    + q.replace(\n",
    "        \"Trains are painted by their country’s flag of origin.\\nYour goal is to identify the correct country based on the train’s distinctive color coding.\\n\",\n",
    "        \"\",\n",
    "    ).replace(\n",
    "        \"Due to the train's distinctive color coding, I believe it originates from the country of\", \"\"\n",
    "    )\n",
    "    for q in test_trains_im\n",
    "]\n",
    "CoT_outs = ar_model.generate(\n",
    "    default_test_trains_im,\n",
    "    model_hyp=model_hyp,\n",
    "    logic_config=logic_config_no_steering,\n",
    "    verbose=False,\n",
    "    batch_size=20,\n",
    ")\n",
    "with open(f\"{experiment_dir}/CoT_outs_im.json\", \"w\") as f:\n",
    "    json.dump(CoT_outs, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reason_accs(test_labels, outputs):\n",
    "    \"\"\"\n",
    "    Compute per-country and overall accuracies for multiple model outputs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    test_labels : Sequence[str]\n",
    "        True label per sample (e.g., country for each row).\n",
    "    outputs : Dict[str, Union[Sequence[Sequence[str]], Dict]]\n",
    "        Mapping from column name -> predictions OR a config dict:\n",
    "          - If value is a sequence (len == len(test_labels)), each element is an\n",
    "            iterable of predicted labels for that sample. A prediction counts as\n",
    "            correct if the true label is in that iterable.\n",
    "          - If value is a dict, it must contain:\n",
    "                {\n",
    "                  \"preds\": Sequence[Sequence[str]],\n",
    "                  \"require_singleton\": bool  # if True, correct only if preds[i]\n",
    "                                           # contains exactly one label AND it\n",
    "                                           # equals the true label (replicates\n",
    "                                           # your original Rules logic)\n",
    "                }\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        Index: countries + \"Overall\"\n",
    "        Columns: outputs' keys\n",
    "        Values: percentage strings like \"83.33%\".\n",
    "    \"\"\"\n",
    "    # Normalize outputs into a common structure: {col: (preds, require_singleton)}\n",
    "    norm = {}\n",
    "    for col, val in outputs.items():\n",
    "        if isinstance(val, dict):\n",
    "            preds = val.get(\"preds\")\n",
    "            require_singleton = bool(val.get(\"require_singleton\", False))\n",
    "        else:\n",
    "            preds = val\n",
    "            require_singleton = False\n",
    "        if len(preds) != len(test_labels):\n",
    "            raise ValueError(\n",
    "                f\"Length mismatch for '{col}': \"\n",
    "                f\"{len(preds)} preds vs {len(test_labels)} labels.\"\n",
    "            )\n",
    "        norm[col] = (preds, require_singleton)\n",
    "\n",
    "    countries = sorted(set(test_labels))\n",
    "    df = pd.DataFrame(index=countries, columns=list(norm.keys()))\n",
    "\n",
    "    # Helper to determine correctness for a single sample\n",
    "    def _is_correct(true_label, pred_seq, require_singleton):\n",
    "        # pred_seq should be an iterable of labels; treat str as a singleton prediction\n",
    "        if isinstance(pred_seq, str):\n",
    "            return  true_label in pred_seq\n",
    "        try:\n",
    "            in_pred = true_label in pred_seq\n",
    "            if require_singleton:\n",
    "                try:\n",
    "                    return in_pred and (len(pred_seq) == 1)\n",
    "                except TypeError:\n",
    "                    # If pred_seq has no len() (e.g., generator), coerce to list\n",
    "                    pred_seq = list(pred_seq)\n",
    "                    return (true_label in pred_seq) and (len(pred_seq) == 1)\n",
    "            return in_pred\n",
    "        except TypeError:\n",
    "            # Non-iterable prediction; fall back to equality\n",
    "            return pred_seq == true_label\n",
    "\n",
    "    # Per-country accuracies\n",
    "    for country in countries:\n",
    "        indices = [i for i, y in enumerate(test_labels) if y == country]\n",
    "        n = len(indices)\n",
    "        for col, (preds, require_singleton) in norm.items():\n",
    "            correct = sum(\n",
    "                _is_correct(country, preds[i], require_singleton) for i in indices\n",
    "            )\n",
    "            acc = correct / n * 100 if n else 0.0\n",
    "            df.loc[country, col] = f\"{acc:.2f}%\"\n",
    "\n",
    "    # Overall accuracies\n",
    "    overall = {}\n",
    "    N = len(test_labels)\n",
    "    for col, (preds, require_singleton) in norm.items():\n",
    "        correct = sum(\n",
    "            _is_correct(true, preds[i], require_singleton)\n",
    "            for i, true in enumerate(test_labels)\n",
    "        )\n",
    "        overall[col] = f\"{(correct / N * 100):.2f}%\"\n",
    "    df.loc[\"Overall\"] = pd.Series(overall)\n",
    "\n",
    "    return df\n",
    "\n",
    "# No feature extraction, take features from llama v1.9 and gemma v2.0\n",
    "experiment_dir = \"output/experiments/explicit_concepts_gemma/v3.0\"\n",
    "# Alternative run directory:\n",
    "# experiment_dir = \"output/experiments/explicit_concepts_llama/v3.0\"\n",
    "\n",
    "\n",
    "def load_experiment_json(stem):\n",
    "    \"\"\"Return the parsed JSON stored at `{experiment_dir}/{stem}.json`.\n",
    "\n",
    "    `experiment_dir` is read from the notebook namespace at call time, so\n",
    "    re-pointing it above and re-running the cell switches the experiment.\n",
    "    \"\"\"\n",
    "    with open(f\"{experiment_dir}/{stem}.json\", \"r\") as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "# The large-model baselines (default_large_it_outs_*, CoT_large_outs_*) are\n",
    "# currently disabled; add them to the dicts below to include them again.\n",
    "\n",
    "print(bold_orange(\"Explicit Concepts Results:\"))\n",
    "al_outs = load_experiment_json(\"al_outs_ex\")\n",
    "default_outs = load_experiment_json(\"default_outs_ex\")\n",
    "default_it_outs = load_experiment_json(\"default_it_outs_ex\")\n",
    "CoT_outs = load_experiment_json(\"CoT_outs_ex\")\n",
    "al_rules = load_experiment_json(\"al_rules_ex\")\n",
    "output_accuracy = reason_accs(\n",
    "    test_labels_ex,\n",
    "    {\n",
    "        \"Default\": default_outs,\n",
    "        \"Default IT\": default_it_outs,\n",
    "        \"CoT\": CoT_outs,\n",
    "        \"AL\": al_outs,\n",
    "        \"Rules\": {\"preds\": al_rules, \"require_singleton\": True},\n",
    "    },\n",
    ")\n",
    "print(output_accuracy)\n",
    "\n",
    "print(bold_orange(\"\\nImplicit Concepts Results:\"))\n",
    "al_outs = load_experiment_json(\"al_outs_im\")\n",
    "default_outs = load_experiment_json(\"default_outs_im\")\n",
    "default_it_outs = load_experiment_json(\"default_it_outs_im\")\n",
    "CoT_outs = load_experiment_json(\"CoT_outs_im\")\n",
    "al_rules = load_experiment_json(\"al_rules_im\")\n",
    "# The *_im predictions are scored against the labels of the implicit test\n",
    "# split (test_labels_im), not the explicit one.\n",
    "output_accuracy = reason_accs(\n",
    "    test_labels_im,\n",
    "    {\n",
    "        \"Default\": default_outs,\n",
    "        \"Default IT\": default_it_outs,\n",
    "        \"CoT\": CoT_outs,\n",
    "        \"AL\": al_outs,\n",
    "        \"Rules\": {\"preds\": al_rules, \"require_singleton\": True},\n",
    "    },\n",
    ")\n",
    "print(output_accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `model_hyp` is not referenced anywhere in this cell; the\n",
    "# detect calls below pass `model_config` instead. Confirm which is intended.\n",
    "model_hyp = dict(do_sample=False, temperature=None, top_k=None, top_p=None)\n",
    "\n",
    "\n",
    "def run_and_save_detection(samples, logic_config, out_path):\n",
    "    \"\"\"Run `detect` on `samples` with `logic_config` and dump the result.\n",
    "\n",
    "    The detections are written as JSON to `out_path` and also returned.\n",
    "    `ar_model` and `model_config` are taken from the notebook namespace.\n",
    "    (An earlier version of this cell called `ar_model.detect(...,\n",
    "    verbose=False)` directly.)\n",
    "    \"\"\"\n",
    "    detections = detect(\n",
    "        ar_model, samples, logic_config=logic_config, model_conf=model_config\n",
    "    )\n",
    "    with open(out_path, \"w\") as out_file:\n",
    "        json.dump(detections, out_file)\n",
    "    return detections\n",
    "\n",
    "\n",
    "print(bold_orange(\"Explicit Concepts:\"))\n",
    "al_detection = run_and_save_detection(\n",
    "    test_trains_ex, logic_config_al_ex, f\"{experiment_dir}/al_detections_ex.json\"\n",
    ")\n",
    "sae_detection = run_and_save_detection(\n",
    "    test_trains_ex, logic_config_sae, f\"{experiment_dir}/sae_detections_ex.json\"\n",
    ")\n",
    "\n",
    "print(bold_orange(\"Implicit Concepts:\"))\n",
    "al_detection = run_and_save_detection(\n",
    "    test_trains_im, logic_config_al_im, f\"{experiment_dir}/al_detections_im.json\"\n",
    ")\n",
    "sae_detection = run_and_save_detection(\n",
    "    test_trains_im, logic_config_sae, f\"{experiment_dir}/sae_detections_im.json\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per evaluation: section title, AL detections file, SAE detections\n",
    "# file, the test samples that were scanned, and the concept dictionary.\n",
    "detection_evals = (\n",
    "    (\n",
    "        \"Explicit Concepts Detection Results:\",\n",
    "        f\"{experiment_dir}/al_detections_ex.json\",\n",
    "        f\"{experiment_dir}/sae_detections_ex.json\",\n",
    "        test_trains_ex,\n",
    "        explicit_concept_dict,\n",
    "    ),\n",
    "    (\n",
    "        \"\\nImplicit Concepts Detection Results:\",\n",
    "        f\"{experiment_dir}/al_detections_im.json\",\n",
    "        f\"{experiment_dir}/sae_detections_im.json\",\n",
    "        test_trains_im,\n",
    "        implicit_concept_dict,\n",
    "    ),\n",
    ")\n",
    "\n",
    "for section_title, al_path, sae_path, eval_samples, eval_concepts in detection_evals:\n",
    "    print(bold_orange(section_title))\n",
    "    with open(al_path, \"r\") as f:\n",
    "        al_detections = json.load(f)\n",
    "    with open(sae_path, \"r\") as f:\n",
    "        sae_detections = json.load(f)\n",
    "\n",
    "    # detection_accs takes the SAE detections before the AL detections and\n",
    "    # returns (table, mean SAE value, mean AL value) in that order.\n",
    "    detection_accuracy, mean_d, mean_al = detection_accs(\n",
    "        eval_samples,\n",
    "        sae_detections,\n",
    "        al_detections,\n",
    "        eval_concepts,\n",
    "        ar_model.tokenizer,\n",
    "    )\n",
    "    print(detection_accuracy)\n",
    "    print(f\"Mean AL detection accuracy: {mean_al:.2f}%\")\n",
    "    print(f\"Mean SAE detection accuracy: {mean_d:.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sanity Checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def report_concept_coverage(kind, samples, concept_dict):\n",
    "    \"\"\"Print how many of `samples` contain each surface form in `concept_dict`.\n",
    "\n",
    "    Args:\n",
    "        kind: Label used in the summary line (\"explicit\" or \"implicit\").\n",
    "        samples: Test samples to scan; membership is checked with `in`.\n",
    "        concept_dict: Maps a concept name to a list of surface forms.\n",
    "\n",
    "    The final total is the sum of the per-concept counts, so a sample that\n",
    "    contains several concepts is counted once per concept it contains.\n",
    "    \"\"\"\n",
    "    print(f\"Total test samples: {len(samples)}\")\n",
    "    sample_with_concept = 0\n",
    "    for concepts in concept_dict.values():\n",
    "        for concept in concepts:\n",
    "            n_matches = sum(1 for sample in samples if concept in sample)\n",
    "            print(f\"Concept: {concept}, Found in {n_matches} samples\")\n",
    "            sample_with_concept += n_matches\n",
    "    print(f\"Total samples with {kind} concepts: {sample_with_concept}\")\n",
    "\n",
    "\n",
    "print(\"Samples containing explicit concepts:\")\n",
    "report_concept_coverage(\"explicit\", test_trains_ex, explicit_concept_dict)\n",
    "\n",
    "# The implicit section reports the size of the implicit test set it scans.\n",
    "print(\"\\nSamples containing implicit concepts:\")\n",
    "report_concept_coverage(\"implicit\", test_trains_im, implicit_concept_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
