{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-time environment setup. After `cd ..`, ./../nnpatch and ./../pycolors are two directory levels above this notebook;\n",
    "# they are installed in editable mode (-e). circuitsvis and python-dotenv are installed without their dependencies (--no-deps).\n",
    "!cd .. && pip install -e ./../nnpatch ./../pycolors && pip install -U transformers kaleido && pip install circuitsvis python-dotenv --no-deps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# kaleido is needed by fig.write_image(...) below. %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install -U kaleido\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fonts used by the exported figures (CMU Serif text, emoji model markers).\n",
    "# -y is required: a notebook shell cannot answer apt's interactive confirmation prompt, so apt would abort.\n",
    "!sudo apt install -y fonts-noto-color-emoji cm-super fonts-cmu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "# Make the parent of the current working directory importable (presumably where analysis/ and main.py live).\n",
    "sys.path.append(\"..\")\n",
    "import os\n",
    "import torch\n",
    "from nnsight import NNsight\n",
    "from tqdm.notebook import tqdm, trange\n",
    "\n",
    "# NOTE(review): names used later without an explicit import (e.g. to_standart_label) presumably come from this star import; confirm before narrowing it.\n",
    "from analysis.circuit_utils.visualisation import *\n",
    "from analysis.circuit_utils.decoding import get_decoding_args, get_data, generate_title, get_plot_prior_patch, get_plot_context_patch, get_plot_weightcp_patch, get_plot_weightpc_patch\n",
    "\n",
    "from main import load_model_and_tokenizer\n",
    "\n",
    "import pandas as pd\n",
    "from plotly.colors import n_colors\n",
    "from analysis.circuit_utils.visualisation import plot_das_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from the repository root (taken to be the notebook's parent directory). The root is resolved only on the\n",
    "# first run, so re-running this cell does not keep climbing the directory tree (relative paths like data/ and plots/ stay valid).\n",
    "import os\n",
    "if \"PROJECT_ROOT\" not in globals():\n",
    "    PROJECT_ROOT = os.path.abspath(\"..\")\n",
    "os.chdir(PROJECT_ROOT)\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONFIGS is otherwise first imported in a later cell; import it here so this cell also works on a fresh kernel.\n",
    "from generate_run_script import CONFIGS\n",
    "CONFIGS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pycolors import TailwindColorPalette, to_rgb\n",
    "from analysis.circuit_utils.visualisation import format_label\n",
    "import json\n",
    "from analysis.circuit_utils.visualisation import plot_das_results\n",
    "from model_utils.utils import construct_test_results_dir, construct_paths_and_dataset_kwargs\n",
    "\n",
    "from generate_run_script import CONFIGS\n",
    "\n",
    "# Evaluation subsplit used for each evaluation dataset name.\n",
    "DATASET2SUBSPLIT = dict(\n",
    "    BaseFakepedia=\"nodup_relpid\",\n",
    "    MultihopFakepedia=\"nodup_relpid\",\n",
    "    Arithmetic=\"d2ub9\",\n",
    ")\n",
    "\n",
    "def get_results_dir(model, use_instruct, finetuned, dataset, cwf, k, steering, seed=3, in_domain_demonstrations=False, aafp=False, afpp=\"end\"):\n",
    "    \"\"\"Return the evaluation results directory for one experimental setting.\n",
    "\n",
    "    The model directory is always derived from the BaseFakepedia / nodup_relpid training\n",
    "    configuration (TRAIN_SIZE=2048, PEFT=True with LoRA modules q/k/v/o_proj); ``finetuned``\n",
    "    is passed on as NO_TRAIN=not finetuned and ``cwf`` as CONTEXT_WEIGHT_FORMAT.\n",
    "    ``dataset``, ``k`` and ``steering`` then select the evaluation sub-directory inside it.\n",
    "    \"\"\"\n",
    "    model_cfg = CONFIGS[model]\n",
    "    checkpoint = model_cfg[\"instruct_model\"] if use_instruct else model_cfg[\"base_model\"]\n",
    "    (_, _, _, model_results_dir, _, _, _, _, _) = construct_paths_and_dataset_kwargs(\n",
    "        DATASET_NAME=\"BaseFakepedia\",\n",
    "        SUBSPLIT=\"nodup_relpid\",\n",
    "        SEED=seed,\n",
    "        TRAIN_SIZE=2048,\n",
    "        MODEL_ID=checkpoint.split(\"/\")[-1],\n",
    "        PEFT=True,\n",
    "        BATCH_SZ=model_cfg[\"bs\"],\n",
    "        GRAD_ACCUM=model_cfg[\"ga\"],\n",
    "        NO_TRAIN=not finetuned,\n",
    "        LORA_MODULES=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
    "        LOAD_IN_4BIT=False,\n",
    "        LOAD_IN_8BIT=False,\n",
    "        CONTEXT_WEIGHT_AT_END=False,\n",
    "        CONTEXT_WEIGHT_FORMAT=cwf,\n",
    "        ANSWER_FORMAT_PROMPT_POSITION=afpp,\n",
    "        ADD_ANSWER_FORMAT_PROMPT=aafp,\n",
    "    )\n",
    "    # Steering runs are always evaluated with the \"none\" context-weight format.\n",
    "    eval_cwf = \"none\" if steering else cwf\n",
    "    return construct_test_results_dir(\n",
    "        model_results_dir,\n",
    "        eval_name=dataset,\n",
    "        subsplit=DATASET2SUBSPLIT[dataset],\n",
    "        k_demonstrations=k,\n",
    "        context_weight_format=eval_cwf,\n",
    "        answer_format_prompt_position=afpp,\n",
    "        add_answer_format_prompt=aafp,\n",
    "        do_steering=steering,\n",
    "        steering_prior_value=float(model_cfg[\"prior_value\"]),\n",
    "        steering_context_value=float(model_cfg[\"context_value\"]),\n",
    "        steering_layer=model_cfg[\"steering_layer\"],\n",
    "        in_domain_demonstrations=in_domain_demonstrations,\n",
    "    )\n",
    "    \n",
    "def get_results(model, use_instruct, finetuned, dataset, cwf, k, steering, seed=3, in_domain_demonstrations=False, aafp=False, afpp=\"end\"):\n",
    "    \"\"\"Return the parsed ``metrics.json`` of one experimental setting (its path is printed first).\"\"\"\n",
    "    metrics_path = os.path.join(\n",
    "        get_results_dir(model, use_instruct, finetuned, dataset, cwf, k, steering, seed, in_domain_demonstrations, aafp, afpp),\n",
    "        \"metrics.json\",\n",
    "    )\n",
    "    print(metrics_path)\n",
    "    with open(metrics_path, \"r\") as handle:\n",
    "        return json.load(handle)\n",
    "\n",
    "def get_features(model, use_instruct, finetuned, dataset, cwf, k, steering, seed=3, in_domain_demonstrations=False, aafp=False, afpp=\"end\"):\n",
    "    \"\"\"Load ``features_{steering_layer}.pt`` of one experimental setting as a NumPy array.\n",
    "\n",
    "    The file is read from ``get_results_dir(...)`` and the tensor is deserialised straight onto\n",
    "    the CPU (``map_location=\"cpu\"``), so features saved from a GPU run also load without CUDA.\n",
    "    \"\"\"\n",
    "    model_cfg = CONFIGS[model]\n",
    "    results_dir = get_results_dir(model, use_instruct, finetuned, dataset, cwf, k, steering, seed, in_domain_demonstrations, aafp, afpp)\n",
    "    features_file = os.path.join(results_dir, f\"features_{model_cfg['steering_layer']}.pt\")\n",
    "    features = torch.load(features_file, map_location=\"cpu\")\n",
    "    return features.numpy()\n",
    "\n",
    "def get_configurations(model, seed=3, in_domain_demonstrations=False, aafp=False, afpp=\"end\"):\n",
    "    configurations = [\n",
    "        (\"INSTRUCT FT INSTRUCTION\", (model, True, True, \"BaseFakepedia\", \"instruction\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT FT FLOAT\", (model, True, True, \"BaseFakepedia\", \"float\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FT INSTRUCTION\", (model, False, True, \"BaseFakepedia\", \"instruction\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FT FLOAT\", (model, False, True, \"BaseFakepedia\", \"float\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT FS INSTRUCTION\", (model, True, False, \"BaseFakepedia\", \"instruction\", 10, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT FS FLOAT\", (model, True, False, \"BaseFakepedia\", \"float\", 10, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FS INSTRUCTION\", (model, False, False, \"BaseFakepedia\", \"instruction\", 10, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FS FLOAT\", (model, False, False, \"BaseFakepedia\", \"float\", 10, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT ZS INSTRUCTION\", (model, True, False, \"BaseFakepedia\", \"instruction\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT ZS FLOAT\", (model, True, False, \"BaseFakepedia\", \"float\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE ZS INSTRUCTION\", (model, False, False, \"BaseFakepedia\", \"instruction\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE ZS FLOAT\", (model, False, False, \"BaseFakepedia\", \"float\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "    ]\n",
    "    steering_configurations = [\n",
    "        (\"INSTRUCT FT INSTRUCTION\", (model, True, True, \"BaseFakepedia\", \"instruction\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT FT FLOAT\", (model, True, True, \"BaseFakepedia\", \"float\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FT INSTRUCTION\", (model, False, True, \"BaseFakepedia\", \"instruction\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FT FLOAT\", (model, False, True, \"BaseFakepedia\", \"float\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT FS INSTRUCTION\", (model, True, False, \"BaseFakepedia\", \"instruction\", 10, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT FS FLOAT\", (model, True, False, \"BaseFakepedia\", \"float\", 10, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FS INSTRUCTION\", (model, False, False, \"BaseFakepedia\", \"instruction\", 10, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FS FLOAT\", (model, False, False, \"BaseFakepedia\", \"float\", 10, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT ZS INSTRUCTION\", (model, True, False, \"BaseFakepedia\", \"instruction\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT ZS FLOAT\", (model, True, False, \"BaseFakepedia\", \"float\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE ZS INSTRUCTION\", (model, False, False, \"BaseFakepedia\", \"instruction\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE ZS FLOAT\", (model, False, False, \"BaseFakepedia\", \"float\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "    ]\n",
    "    return configurations, steering_configurations\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Performance Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "from pycolors import TailwindColorPalette, to_rgb\n",
    "from collections import defaultdict\n",
    "import pandas as pd\n",
    "\n",
    "MODEL_NAMES = {\n",
    "    \"Meta-Llama-3.1-8B-Instruct\": \"<b>Llama</b> 3.1 8B\",\n",
    "    \"gemma-2-9b-it\": \"<b>Gemma</b> 2 9B\",\n",
    "    \"Mistral-7B-Instruct-v0.3\": \"<b>Mistral</b> 7B\",\n",
    "}\n",
    "\n",
    "MODEL_NAMES = {\n",
    "    \"llama\": \"🦙\",\n",
    "    \"gemma\": \"💎\",\n",
    "    \"mistral\": \"🌬️\",\n",
    "}\n",
    "\n",
    "column_map = {\n",
    "    \"baseline\": \"Baseline: Intent Instruction\",\n",
    "    \"no_instruction\": \"Steering: No Instruction\",\n",
    "}\n",
    "metric_map = {\n",
    "    \"acc\": \"Accuracy\",\n",
    "    \"pair_acc\": \"PairAcc\"\n",
    "}\n",
    "def get_str_colors(color):\n",
    "    \"\"\"Format colour components (e.g. the output of to_rgb) as a Plotly string such as 'rgb(1,2,3)'.\"\"\"\n",
    "    return f\"rgb({','.join(map(str, color))})\"\n",
    "\n",
    "# Default bar colours, keyed like column_map. COLORS is not defined in this notebook itself (it presumably\n",
    "# comes from the star import in the first code cell); fall back to a fresh palette so this cell also runs without it.\n",
    "_default_palette = COLORS if \"COLORS\" in globals() else TailwindColorPalette()\n",
    "colors = {\n",
    "    'no_instruction': get_str_colors(to_rgb(_default_palette.get_shade(1, 500))),\n",
    "    'baseline': get_str_colors(to_rgb(_default_palette.get_shade(4, 300))),\n",
    "}\n",
    "\n",
    "def get_ds_configurations(model, dataset=\"BaseFakepedia\", seed=3, in_domain_demonstrations=False, aafp=False, afpp=\"end\"):\n",
    "    return [\n",
    "        (\"INSTRUCT FT INSTRUCTION\", (model, True, True, dataset, \"instruction\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FS INSTRUCTION\", (model, False, False, dataset, \"instruction\", 10, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT ZS INSTRUCTION\", (model, True, False, dataset, \"instruction\", 0, False, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "    ], [\n",
    "        (\"INSTRUCT FT INSTRUCTION\", (model, True, True, dataset, \"instruction\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"BASE FS INSTRUCTION\", (model, False, False, dataset, \"instruction\", 10, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "        (\"INSTRUCT ZS INSTRUCTION\", (model, True, False, dataset, \"instruction\", 0, True, seed, in_domain_demonstrations, aafp, afpp)),\n",
    "    ]\n",
    "\n",
    "DATASETS = [\"BaseFakepedia\", \"MultihopFakepedia\", \"Arithmetic\"]\n",
    "\n",
    "def plot_das_results_by_model(dataset, models=[\"gemma\", \"llama\", \"mistral\"], seed=3, in_domain_demonstrations=False, aafp=False, afpp=\"end\", metric='accuracy', no_y_axis=False, COLORS=TailwindColorPalette(), legend_loc=\"top\"):\n",
    "    \"\"\"Grouped bar chart of baseline vs. steering scores on one dataset, grouped by model.\n",
    "\n",
    "    For every model key in ``models`` and every setting of ``get_ds_configurations``, the\n",
    "    ``metric`` entry of the baseline run and of the matching steering run are drawn side by\n",
    "    side. The x-axis is two-level: setting label (inner) and model marker (outer).\n",
    "\n",
    "    Args:\n",
    "        dataset: evaluation dataset name (a key of DATASET2SUBSPLIT).\n",
    "        models: model keys, plotted in this order (the list is not modified).\n",
    "        metric: key read from each metrics.json; the y-axis title comes from ``metric_map``\n",
    "            and falls back to the key itself for metrics it does not list.\n",
    "        no_y_axis: hide the y tick labels and title and narrow the figure.\n",
    "        COLORS: palette the bar colours are taken from.\n",
    "        legend_loc: \"top\", \"middle\", any other value for a low legend, or None for no legend.\n",
    "\n",
    "    Returns:\n",
    "        The ``go.Figure``.\n",
    "    \"\"\"\n",
    "    series = [\"baseline\", \"no_instruction\"]\n",
    "    # Bar colours come from the COLORS argument.\n",
    "    bar_colors = {\n",
    "        \"no_instruction\": get_str_colors(to_rgb(COLORS.get_shade(1, 500))),\n",
    "        \"baseline\": get_str_colors(to_rgb(COLORS.get_shade(4, 300))),\n",
    "    }\n",
    "\n",
    "    fig = go.Figure()\n",
    "    # Hierarchical x-axis: x_labels[series] = [inner setting labels, outer model labels]\n",
    "    x_labels = defaultdict(lambda: [[], []])\n",
    "    y_values_dict = {column: [] for column in series}\n",
    "\n",
    "    for model in models:\n",
    "        model_name = MODEL_NAMES[model]\n",
    "        configurations, steering_configurations = get_ds_configurations(model, dataset, seed, in_domain_demonstrations, aafp, afpp)\n",
    "        for baseline, steering in zip(configurations, steering_configurations):\n",
    "            for column, (label, args) in ((\"baseline\", baseline), (\"no_instruction\", steering)):\n",
    "                y_values_dict[column].append(get_results(*args)[metric])\n",
    "                x_labels[column][0].append(format_label(to_standart_label(label)))\n",
    "                x_labels[column][1].append(model_name)\n",
    "\n",
    "    for column in series:\n",
    "        values = y_values_dict[column]\n",
    "        fig.add_trace(go.Bar(\n",
    "            name=column_map[column],\n",
    "            x=x_labels[column],\n",
    "            y=values,\n",
    "            text=[f'{v:.2f}' if v > 0.05 else '0' for v in values],\n",
    "            # Bars labelled \"0\" are too short to hold text, so their label goes outside\n",
    "            # (same threshold as the label text itself).\n",
    "            textposition=['auto' if v > 0.05 else 'outside' for v in values],\n",
    "            textfont=dict(size=20),\n",
    "            marker_color=bar_colors[column],\n",
    "        ))\n",
    "\n",
    "    fig.update_layout(\n",
    "        yaxis_title=metric_map.get(metric, metric),\n",
    "        font=dict(size=20),\n",
    "    )\n",
    "\n",
    "    if legend_loc is None:\n",
    "        fig.update_layout(showlegend=False)\n",
    "    else:\n",
    "        legend_y = {\"top\": 0.97, \"middle\": 0.4}.get(legend_loc, 0.27)\n",
    "        fig.update_layout(legend=dict(\n",
    "            groupclick=\"toggleitem\",\n",
    "            yanchor=\"top\",\n",
    "            y=legend_y,\n",
    "            xanchor=\"right\",\n",
    "            x=0.97,\n",
    "            orientation=\"v\",\n",
    "        ))\n",
    "\n",
    "    width = 750\n",
    "    fig.update_yaxes(range=[0.0, 1.0])\n",
    "    if no_y_axis:\n",
    "        width = 720\n",
    "        fig.update_layout(yaxis=dict(visible=True, showticklabels=False, ticks=\"\", title=\"\"))\n",
    "\n",
    "    fig.update_xaxes(tickfont=dict(size=29))\n",
    "    # Uniform bar-label size of at least 25; 'show' keeps labels even when they do not fit.\n",
    "    fig.update_layout(uniformtext_minsize=25, uniformtext_mode='show')\n",
    "    # CMU Serif matches plot_das_results_by_dataset and the fonts-cmu package installed at the top.\n",
    "    fig.update_layout(width=width, height=400, margin=dict(l=0, r=0, t=0, b=0), font_family=\"CMU Serif\")\n",
    "    return fig\n",
    "\n",
    "# Short outer x-axis label of each evaluation dataset.\n",
    "DATASET_NAMES = {\n",
    "    \"BaseFakepedia\": \"BF\",\n",
    "    \"MultihopFakepedia\": \"MH\",\n",
    "    \"Arithmetic\": \"AR\",\n",
    "}\n",
    "\n",
    "def plot_das_results_by_dataset(model, datasets=[\"BaseFakepedia\", \"MultihopFakepedia\", \"Arithmetic\"], metric='accuracy', no_y_axis=False, seed=3, in_domain_demonstrations=False, aafp=False, afpp=\"end\", COLORS=TailwindColorPalette(), legend_loc=\"top\"):\n",
    "    \"\"\"Grouped bar chart of baseline vs. steering scores for one model, grouped by dataset.\n",
    "\n",
    "    For every dataset in ``datasets`` and every setting of ``get_ds_configurations``, the\n",
    "    ``metric`` entry of the baseline run and of the matching steering run are drawn side by\n",
    "    side. The x-axis is two-level: setting label (inner) and DATASET_NAMES abbreviation (outer).\n",
    "\n",
    "    Args:\n",
    "        model: model key of CONFIGS.\n",
    "        datasets: evaluation datasets, plotted in this order (the list is not modified).\n",
    "        metric: key read from each metrics.json; the y-axis title comes from ``metric_map``\n",
    "            and falls back to the key itself for metrics it does not list.\n",
    "        no_y_axis: hide the y tick labels and title and narrow the figure.\n",
    "        COLORS: palette the bar colours are taken from.\n",
    "        legend_loc: \"top\", \"middle\", any other value for a low legend, or None for no legend.\n",
    "\n",
    "    Returns:\n",
    "        The ``go.Figure``.\n",
    "    \"\"\"\n",
    "    series = [\"baseline\", \"no_instruction\"]\n",
    "    # Bar colours come from the COLORS argument.\n",
    "    bar_colors = {\n",
    "        \"no_instruction\": get_str_colors(to_rgb(COLORS.get_shade(1, 500))),\n",
    "        \"baseline\": get_str_colors(to_rgb(COLORS.get_shade(4, 300))),\n",
    "    }\n",
    "\n",
    "    def get_num(num):\n",
    "        \"\"\"Bar label: two decimals above 0.1, '0.1' for values in (0.05, 0.1], otherwise '0'.\"\"\"\n",
    "        if num > 0.1:\n",
    "            return f'{num:.2f}'\n",
    "        if num > 0.05:\n",
    "            return '0.1'\n",
    "        return '0'\n",
    "\n",
    "    fig = go.Figure()\n",
    "    # Hierarchical x-axis: x_labels[series] = [inner setting labels, outer dataset labels]\n",
    "    x_labels = defaultdict(lambda: [[], []])\n",
    "    y_values_dict = {column: [] for column in series}\n",
    "\n",
    "    for dataset in datasets:\n",
    "        configurations, steering_configurations = get_ds_configurations(model, dataset, seed, in_domain_demonstrations, aafp, afpp)\n",
    "        dataset_name = DATASET_NAMES[dataset]\n",
    "        for baseline, steering in zip(configurations, steering_configurations):\n",
    "            for column, (label, args) in ((\"baseline\", baseline), (\"no_instruction\", steering)):\n",
    "                y_values_dict[column].append(get_results(*args)[metric])\n",
    "                x_labels[column][0].append(format_label(to_standart_label(label)))\n",
    "                x_labels[column][1].append(dataset_name)\n",
    "\n",
    "    for column in series:\n",
    "        values = y_values_dict[column]\n",
    "        fig.add_trace(go.Bar(\n",
    "            name=column_map[column],\n",
    "            x=x_labels[column],\n",
    "            y=values,\n",
    "            text=[get_num(v) for v in values],\n",
    "            # Labels of bars at or below 0.05 go outside the (nearly invisible) bar.\n",
    "            textposition=['auto' if v > 0.05 else 'outside' for v in values],\n",
    "            textfont=dict(size=20),\n",
    "            marker_color=bar_colors[column],\n",
    "        ))\n",
    "\n",
    "    fig.update_layout(\n",
    "        yaxis_title=metric_map.get(metric, metric),\n",
    "        font=dict(size=20),\n",
    "    )\n",
    "\n",
    "    if legend_loc is None:\n",
    "        fig.update_layout(showlegend=False)\n",
    "    else:\n",
    "        legend_y = {\"top\": 0.97, \"middle\": 0.4}.get(legend_loc, 0.27)\n",
    "        fig.update_layout(legend=dict(\n",
    "            groupclick=\"toggleitem\",\n",
    "            yanchor=\"top\",\n",
    "            y=legend_y,\n",
    "            xanchor=\"right\",\n",
    "            x=0.97,\n",
    "            orientation=\"v\",\n",
    "        ))\n",
    "\n",
    "    width = 760\n",
    "    fig.update_yaxes(range=[0.0, 1.0])\n",
    "    if no_y_axis:\n",
    "        width = 720\n",
    "        fig.update_layout(yaxis=dict(visible=True, showticklabels=False, ticks=\"\", title=\"\"))\n",
    "\n",
    "    fig.update_xaxes(tickfont=dict(size=29))\n",
    "    # Uniform bar-label size of at least 25; 'show' keeps labels even when they do not fit.\n",
    "    fig.update_layout(uniformtext_minsize=25, uniformtext_mode='show')\n",
    "    fig.update_layout(width=width, height=400, margin=dict(l=0, r=0, t=0, b=0), font_family=\"CMU Serif\")\n",
    "    return fig\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write_image expects the target directory to exist.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "fig = plot_das_results_by_model(\"BaseFakepedia\", legend_loc=None, no_y_axis=False, metric='pair_acc', COLORS=TailwindColorPalette())\n",
    "fig.update_layout(font_family=\"CMU Serif\", width=760)\n",
    "fig.write_image(\"plots/pair_acc_BF_models.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write_image expects the target directory to exist.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "fig = plot_das_results_by_dataset(\"llama\", legend_loc=\"top\", datasets=[\"BaseFakepedia\", \"MultihopFakepedia\", \"Arithmetic\"], no_y_axis=False, metric='pair_acc', COLORS=TailwindColorPalette())\n",
    "fig.update_layout(font_family=\"CMU Serif\", width=760)\n",
    "fig.write_image(\"plots/generalization_llama_pair_acc.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write_image expects the target directory to exist.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "fig = plot_das_results_by_dataset(\"mistral\", legend_loc=\"top\", datasets=[\"BaseFakepedia\", \"MultihopFakepedia\", \"Arithmetic\"], no_y_axis=False, metric='pair_acc', COLORS=TailwindColorPalette())\n",
    "fig.update_layout(font_family=\"CMU Serif\", width=760)\n",
    "fig.write_image(\"plots/generalization_mistral_pair_acc.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write_image expects the target directory to exist.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "fig = plot_das_results_by_dataset(\"gemma\", legend_loc=\"top\", datasets=[\"BaseFakepedia\", \"MultihopFakepedia\", \"Arithmetic\"], no_y_axis=False, metric='pair_acc', COLORS=TailwindColorPalette())\n",
    "fig.update_layout(font_family=\"CMU Serif\", width=760)\n",
    "fig.write_image(\"plots/generalization_gemma_pair_acc.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from the repository root (taken to be the notebook's parent directory). The root is resolved only on the\n",
    "# first run, so re-running this cell does not keep climbing the directory tree (relative paths like data/ and plots/ stay valid).\n",
    "import os\n",
    "if \"PROJECT_ROOT\" not in globals():\n",
    "    PROJECT_ROOT = os.path.abspath(\"..\")\n",
    "os.chdir(PROJECT_ROOT)\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import pandas as pd\n",
    "\n",
    "device = \"cpu\"\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "\n",
    "# The intent labels are read from a single results directory (Meta-Llama-3.1-8B-Instruct, BaseFakepedia, k0,\n",
    "# cwf_instruction) and reused for every model's features in the plots below.\n",
    "# NOTE(review): this assumes every setting's features follow the same test examples in the same order; confirm.\n",
    "df = pd.read_csv(\"data/BaseFakepedia/BaseFakepedia_nodup_relpid-ts2048/3/models/Meta-Llama-3.1-8B-Instruct-NT/results/BaseFakepedia-sp_nodup_relpid-k0_OOD-cwf_instruction/test.csv\")\n",
    "# Boolean mask: True where weight_context == 1.\n",
    "labels = df[\"weight_context\"].to_numpy() == 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the (label, get_results-argument tuple) pairs for Llama.\n",
    "configs, steering_configs = get_configurations(model=\"llama\", in_domain_demonstrations=False)\n",
    "configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from plotly.colors import n_colors\n",
    "from plotly.subplots import make_subplots\n",
    "from pycolors import TailwindColorPalette, to_rgb\n",
    "import numpy as np\n",
    "\n",
    "def plot_feature_distribution(configs, labels, COLORS=TailwindColorPalette()):\n",
    "    \"\"\"Plot each setting's feature distribution next to its PairAcc.\n",
    "\n",
    "    Left panel: for every ``(name, args)`` in ``configs``, two horizontal half-violins of the\n",
    "    flattened ``get_features(*args)`` values, split by ``labels`` into \"intent = ctx\" (True)\n",
    "    and \"intent = prior\" (False). Right panel: one bar with ``get_results(*args)[\"pair_acc\"]``.\n",
    "    An annotation reports the Pearson correlation, across settings, between PairAcc and the\n",
    "    absolute difference of the two group means.\n",
    "\n",
    "    Args:\n",
    "        configs: list of ``(name, args)`` pairs, e.g. from ``get_configurations``.\n",
    "        labels: boolean mask with one entry per flattened feature value.\n",
    "        COLORS: palette the violin and bar colours are taken from.\n",
    "\n",
    "    Returns:\n",
    "        The plotly figure.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a setting's number of feature values differs from ``len(labels)``.\n",
    "    \"\"\"\n",
    "    # Import from the public namespace; scipy.stats.stats is a deprecated alias.\n",
    "    from scipy.stats import pearsonr\n",
    "\n",
    "    # Coerce to a boolean mask: with an integer 0/1 array, data_line[labels] would index by\n",
    "    # position and ~labels would yield negative indices instead of selecting the other group.\n",
    "    labels = np.asarray(labels, dtype=bool)\n",
    "    names = [name for name, _ in configs]\n",
    "    data = [get_features(*values).flatten() for _, values in configs]\n",
    "    for name, data_line in zip(names, data):\n",
    "        if data_line.shape[0] != labels.shape[0]:\n",
    "            raise ValueError(f\"{name}: {data_line.shape[0]} feature values but {labels.shape[0]} labels\")\n",
    "    accuracies = {name: get_results(*values)[\"pair_acc\"] for name, values in configs}\n",
    "\n",
    "    def rgb_strings(rgb_tuples):\n",
    "        return [f\"rgb({','.join(str(int(x)) for x in c)})\" for c in rgb_tuples]\n",
    "\n",
    "    n = len(names)\n",
    "    colors_accs = rgb_strings([to_rgb(COLORS.get_shade(4, 300)) for _ in range(n)])\n",
    "    # A slight shade gradient across settings within each label group.\n",
    "    colors_zero = rgb_strings(n_colors(to_rgb(COLORS.get_shade(1, 500)), to_rgb(COLORS.get_shade(1, 600)), n))\n",
    "    colors_one = rgb_strings(n_colors(to_rgb(COLORS.get_shade(7, 500)), to_rgb(COLORS.get_shade(7, 600)), n))\n",
    "\n",
    "    fig = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.02, subplot_titles=(\"Distribution of Subspace Values\", \"PairAcc\"))\n",
    "    for i, (name, data_line) in enumerate(zip(names, data)):\n",
    "        formatted_name = format_label(name)\n",
    "        # Only the last setting adds legend entries, so \"intent = ctx\" / \"intent = prior\" appear once each.\n",
    "        show_legend = i == n - 1\n",
    "        group = name.split(\" \")[0]\n",
    "        for mask, color, trace_name in ((labels, colors_one[i], \"intent = ctx\"), (~labels, colors_zero[i], \"intent = prior\")):\n",
    "            values = data_line[mask]\n",
    "            fig.add_trace(go.Violin(x=values, y=[formatted_name] * len(values), side='positive', width=3, points=False, line_color=color, name=trace_name, legendgroup=group, orientation='h', showlegend=show_legend, line=dict(color=\"black\"), meanline_visible=True), row=1, col=1)\n",
    "\n",
    "    # Horizontal PairAcc bars in the right panel, sharing the y categories of the violins.\n",
    "    for name, accuracy in accuracies.items():\n",
    "        fig.add_trace(go.Bar(\n",
    "            y=[format_label(name)],\n",
    "            x=[accuracy],\n",
    "            orientation='h',\n",
    "            text=[f\"{accuracy:.2f}\"],\n",
    "            textposition='auto' if accuracy > 0.1 else 'outside',\n",
    "            insidetextfont=dict(size=25),\n",
    "            name=name,\n",
    "            showlegend=False,\n",
    "            marker_color=colors_accs[names.index(name)],\n",
    "        ), row=1, col=2)\n",
    "    fig.update_layout(uniformtext_minsize=25, uniformtext_mode='show')\n",
    "    fig.update_xaxes(range=[0, 1], row=1, col=2)\n",
    "\n",
    "    mean_diff = {name: np.abs(np.mean(data_line[labels]) - np.mean(data_line[~labels])) for name, data_line in zip(names, data)}\n",
    "    mean_diffs = [mean_diff[name] for name in names]\n",
    "    accs = [accuracies[name] for name in names]\n",
    "    correlation, p_value = pearsonr(mean_diffs, accs)\n",
    "\n",
    "    # add_annotation appends; update_layout(annotations=[...]) would overwrite the subplot-title\n",
    "    # annotations created by make_subplots.\n",
    "    fig.add_annotation(\n",
    "        x=0.99,\n",
    "        y=0.98,\n",
    "        xref=\"paper\",\n",
    "        yref=\"paper\",\n",
    "        text=f\"Pearson correlation between <br> subspace value mean difference <br> and PairAcc:<br> {correlation:.3f} (p={p_value:.3f})\",\n",
    "        showarrow=False,\n",
    "        font=dict(size=25),\n",
    "        align=\"right\",\n",
    "        bgcolor=\"rgba(255, 255, 255, 0.95)\",\n",
    "        bordercolor=\"black\",\n",
    "        borderwidth=1,\n",
    "        borderpad=4,\n",
    "        xanchor=\"right\",\n",
    "        yanchor=\"top\",\n",
    "    )\n",
    "\n",
    "    fig.update_xaxes(title_text=\"Feature Value\", row=1, col=1)\n",
    "    fig.update_xaxes(title_text=\"PairAcc\", row=1, col=2)\n",
    "    fig.update_layout(height=500, width=1800)\n",
    "    fig.update_layout(xaxis_showgrid=True, xaxis_zeroline=False)\n",
    "    fig.update_layout(yaxis_showgrid=True, yaxis_zeroline=False)\n",
    "    fig.update_layout(font=dict(size=25))\n",
    "\n",
    "    # Legend in the top-right corner of the left (distribution) panel.\n",
    "    fig.update_layout(legend=dict(\n",
    "        yanchor=\"top\",\n",
    "        y=0.98,\n",
    "        xanchor=\"right\",\n",
    "        x=0.485,\n",
    "        orientation=\"h\",\n",
    "        bgcolor=\"rgba(255, 255, 255, 0.9)\"\n",
    "    ), margin=dict(l=0, r=0, t=0, b=0))\n",
    "\n",
    "    return fig\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write_image expects the target directory to exist.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "configs, steering_configs = get_configurations(\"llama\", in_domain_demonstrations=False)\n",
    "fig = plot_feature_distribution(configs, labels)\n",
    "fig.update_layout(font_family=\"CMU Serif\")\n",
    "fig.write_image(\"plots/llama_distribution.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write_image expects the target directory to exist.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "configs, steering_configs = get_configurations(\"mistral\", in_domain_demonstrations=False)\n",
    "fig = plot_feature_distribution(configs, labels)\n",
    "fig.update_layout(font_family=\"CMU Serif\")\n",
    "fig.write_image(\"plots/mistral_distribution.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write_image expects the target directory to exist.\n",
    "os.makedirs(\"plots\", exist_ok=True)\n",
    "configs, steering_configs = get_configurations(\"gemma\", in_domain_demonstrations=False)\n",
    "fig = plot_feature_distribution(configs, labels)\n",
    "fig.update_layout(font_family=\"CMU Serif\")\n",
    "fig.write_image(\"plots/gemma_distribution.pdf\")\n",
    "fig.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Results: Baseline vs. Steering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from the repository root (taken to be the notebook's parent directory). The root is resolved only on the\n",
    "# first run, so re-running this cell does not keep climbing the directory tree (relative paths like data/ and plots/ stay valid).\n",
    "import os\n",
    "if \"PROJECT_ROOT\" not in globals():\n",
    "    PROJECT_ROOT = os.path.abspath(\"..\")\n",
    "os.chdir(PROJECT_ROOT)\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import json\n",
    "from plotly.colors import n_colors\n",
    "import plotly.express as px\n",
    "from plotly.subplots import make_subplots\n",
    "import plotly.graph_objects as go\n",
    "from plotly.colors import n_colors\n",
    "import torch\n",
    "import numpy as np\n",
    "import einops\n",
    "\n",
    "column_map = {\n",
    "    \"baseline\": \"Baseline: Intent Instruction\",\n",
    "    \"with_instruction\": \"Steering: Same Instruction\",\n",
    "    \"against_instruction\": \"Steering: Opposite Instruction\",\n",
    "    \"one_word\": \"Steering: Only One Word Instruction\",\n",
    "    \"one_word_instruction\": \"Steering: One Word Instruction and Same Instruction\",\n",
    "    \"no_instruction\": \"Steering: No Instruction\",\n",
    "    \"baseline_one_word_instruction\": \"Baseline: Intent + One Word Instruction\"\n",
    "}\n",
    "metric_map = {\n",
    "    \"acc\": \"Accuracy\",\n",
    "    \"pair_acc\": \"PairAcc\"\n",
    "}\n",
    "\n",
    "from pycolors import TailwindColorPalette\n",
    "\n",
    "# NOTE: this redefines (shadows) the `plot_das_results` imported from\n",
    "# analysis.circuit_utils.visualisation in the first import cell.\n",
    "# `to_rgb`, `format_label`, `get_results` and `get_label_color` are expected to be\n",
    "# in scope from earlier cells (not verified here).\n",
    "\n",
    "def _to_css_rgb(color):\n",
    "    \"\"\"Format the components of `color` (e.g. the output of `to_rgb`) as a CSS `rgb(...)` string.\"\"\"\n",
    "    return f\"rgb({','.join(map(str, color))})\"\n",
    "\n",
    "\n",
    "def _add_config_legend(fig, COLORS):\n",
    "    \"\"\"Add the FT / ICL / ZS colour key (in `legend2`) and the IF annotation box to `fig`.\"\"\"\n",
    "    for code, label in [\n",
    "        (\"FT\", ' Finetuning (FT)'),\n",
    "        (\"FS\", ' In-Context Learning (ICL)'),\n",
    "        (\"ZS\", ' Zero-Shot (ZS)'),\n",
    "    ]:\n",
    "        # Marker-only trace with no data points: it exists only to create a legend entry.\n",
    "        fig.add_trace(go.Scatter(\n",
    "            x=[None],\n",
    "            y=[None],\n",
    "            mode='markers',\n",
    "            marker=dict(size=10, color=get_label_color(code, COLORS)),\n",
    "            legendgroup='config',\n",
    "            showlegend=True,\n",
    "            name=label,\n",
    "            legend='legend2',\n",
    "        ))\n",
    "\n",
    "    fig.update_layout(\n",
    "        annotations=[dict(\n",
    "            x=.995,\n",
    "            y=0.79,\n",
    "            xref=\"paper\",\n",
    "            yref=\"paper\",\n",
    "            text=\"🫵  IF = instruction<br>1️⃣  IF = float\",\n",
    "            showarrow=False,\n",
    "            font=dict(size=20),\n",
    "            align=\"left\",\n",
    "            bgcolor=\"rgba(255, 255, 255, 0.9)\",\n",
    "            borderpad=10,\n",
    "            xanchor=\"right\",\n",
    "            height=48,\n",
    "            width=292,\n",
    "            yanchor=\"top\",\n",
    "        )],\n",
    "        legend2=dict(\n",
    "            yanchor=\"top\",\n",
    "            y=0.98,\n",
    "            xanchor=\"right\",\n",
    "            x=0.995,\n",
    "            orientation=\"v\",\n",
    "            bgcolor=\"rgba(255, 255, 255, 0.9)\",\n",
    "        ),\n",
    "    )\n",
    "\n",
    "\n",
    "def plot_das_results(baseline_configs, steering_configs, metric='pair_acc', COLORS=None, extended_legend=False):\n",
    "    \"\"\"Grouped bar chart comparing baseline runs against steering runs.\n",
    "\n",
    "    Args:\n",
    "        baseline_configs: iterable of ``(key, args)`` pairs; ``get_results(*args)[metric]``\n",
    "            is the bar height and ``format_label(key)`` the x tick label.\n",
    "        steering_configs: same structure as ``baseline_configs``, for the steering runs.\n",
    "        metric: key into the results returned by ``get_results``; also looked up in\n",
    "            ``metric_map`` for the y-axis title (unknown metrics fall back to the raw key).\n",
    "        COLORS: colour palette; a fresh ``TailwindColorPalette()`` is created when ``None``.\n",
    "        extended_legend: if True, add a second legend for the FT / FS / ZS label colours\n",
    "            plus an annotation box.\n",
    "\n",
    "    Returns:\n",
    "        plotly.graph_objects.Figure\n",
    "    \"\"\"\n",
    "    if COLORS is None:\n",
    "        # Build the palette per call rather than sharing one instance created at definition time.\n",
    "        COLORS = TailwindColorPalette()\n",
    "\n",
    "    # Insertion order fixes the trace order: baseline first, then steering.\n",
    "    series = {\n",
    "        'baseline': {key: get_results(*value)[metric] for key, value in baseline_configs},\n",
    "        'no_instruction': {key: get_results(*value)[metric] for key, value in steering_configs},\n",
    "    }\n",
    "    # A single colour string per trace applies to every bar, so it can never get out of\n",
    "    # sync with the number of bars (the old per-bar lists were all sized by the baseline data).\n",
    "    colors = {\n",
    "        'baseline': _to_css_rgb(to_rgb(COLORS.get_shade(4, 300))),\n",
    "        'no_instruction': _to_css_rgb(to_rgb(COLORS.get_shade(1, 500))),\n",
    "    }\n",
    "\n",
    "    fig = go.Figure()\n",
    "    for column, data in series.items():\n",
    "        y_values = list(data.values())\n",
    "        x_labels = [format_label(key) for key in data]\n",
    "        fig.add_trace(go.Bar(\n",
    "            name=column_map[column],\n",
    "            x=x_labels,\n",
    "            y=y_values,\n",
    "            # Zero-height bars have no room for an inner label, so print '0' outside them.\n",
    "            text=[f'{v:.2f}' if v != 0 else '0' for v in y_values],\n",
    "            textposition=['auto' if v != 0 else 'outside' for v in y_values],\n",
    "            marker_color=colors[column],\n",
    "        ))\n",
    "\n",
    "    fig.update_layout(\n",
    "        barmode='group',\n",
    "        yaxis_title=metric_map.get(metric, metric),\n",
    "        font=dict(size=25),\n",
    "        xaxis=dict(tickangle=-25),\n",
    "        width=2000,\n",
    "        height=700,\n",
    "        legend=dict(\n",
    "            yanchor=\"bottom\",\n",
    "            y=0.88,\n",
    "            xanchor=\"right\",\n",
    "            x=0.992,\n",
    "            orientation=\"h\",\n",
    "            borderwidth=4,\n",
    "            bordercolor=\"white\",\n",
    "        ),\n",
    "        margin=dict(l=0, r=0, t=0, b=0),\n",
    "    )\n",
    "    # Fixed y-axis range [0, 1].\n",
    "    fig.update_yaxes(range=[0.0, 1.0])\n",
    "\n",
    "    if extended_legend:\n",
    "        _add_config_legend(fig, COLORS)\n",
    "\n",
    "    return fig\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PairAcc: baseline vs. steering for Llama; the output file name is derived from model_name.\n",
    "model_name = \"llama\"\n",
    "configs, steering_configs = get_configurations(model_name, in_domain_demonstrations=False)\n",
    "fig = plot_das_results(configs, steering_configs, metric='pair_acc')\n",
    "legend_style = dict(orientation=\"v\", yanchor=\"top\", y=0.985, xanchor=\"right\", x=0.995, font=dict(size=25))\n",
    "fig.update_layout(font_family=\"CMU Serif\", width=1800, height=450, legend=legend_style)\n",
    "fig.write_image(f\"plots/{model_name}_pair_acc.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): exploratory listing of one hard-coded Llama results directory; it produces\n",
    "# no artefact used by later cells — consider removing it from the final notebook.\n",
    "!ls data/BaseFakepedia/BaseFakepedia_nodup_relpid-ts2048/3/models/Meta-Llama-3.1-8B-Instruct-peftq_proj_k_proj_v_proj_o_proj-bs8-ga2-cwf_instruction/results/\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PairAcc: baseline vs. steering for Mistral; the output file name is derived from model_name.\n",
    "model_name = \"mistral\"\n",
    "configs, steering_configs = get_configurations(model_name, in_domain_demonstrations=False)\n",
    "fig = plot_das_results(configs, steering_configs, metric='pair_acc')\n",
    "legend_style = dict(orientation=\"v\", yanchor=\"top\", y=0.985, xanchor=\"right\", x=0.995, font=dict(size=25))\n",
    "fig.update_layout(font_family=\"CMU Serif\", width=1800, height=550, legend=legend_style)\n",
    "fig.write_image(f\"plots/{model_name}_pair_acc.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PairAcc: baseline vs. steering for Gemma; the output file name is derived from model_name.\n",
    "model_name = \"gemma\"\n",
    "configs, steering_configs = get_configurations(model_name, in_domain_demonstrations=False)\n",
    "fig = plot_das_results(configs, steering_configs, metric='pair_acc')\n",
    "legend_style = dict(orientation=\"v\", yanchor=\"top\", y=0.985, xanchor=\"right\", x=0.995, font=dict(size=25))\n",
    "fig.update_layout(font_family=\"CMU Serif\", width=1800, height=550, legend=legend_style)\n",
    "fig.write_image(f\"plots/{model_name}_pair_acc.pdf\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "default",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
