{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Angular Steering evaluation visualization\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from configs import MAX_SIM_DIR_ID, MAX_NORM_DIR_ID\n",
    "\n",
    "dir_id = \"max_sim\"\n",
    "\n",
    "if dir_id == \"max_sim\":\n",
    "    DIR_ID_MAP = MAX_SIM_DIR_ID\n",
    "elif dir_id == \"max_norm\":\n",
    "    DIR_ID_MAP = MAX_NORM_DIR_ID\n",
    "\n",
    "visualization_dir = Path(f\"visualization/{dir_id}/\")\n",
    "visualization_dir.mkdir(parents=True, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Refusal score and Harmful scores\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import plotly\n",
    "from pathlib import Path\n",
    "from plotly.subplots import make_subplots\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "model_ids = [\n",
    "    \"Qwen/Qwen2.5-3B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-7B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-14B-Instruct\",\n",
    "    \"meta-llama/Llama-3.2-3B-Instruct\",\n",
    "    \"meta-llama/Llama-3.1-8B-Instruct\",\n",
    "    \"google/gemma-2-9b-it\",\n",
    "]\n",
    "\n",
    "language = \"en\"\n",
    "data_type = \"harmful\"\n",
    "num_row = 3\n",
    "num_col = len(model_ids) // num_row\n",
    "data_path = \"output/\"\n",
    "\n",
    "adaptive = True\n",
    "\n",
    "fig = make_subplots(\n",
    "    rows=num_row,\n",
    "    cols=num_col,\n",
    "    specs=[[{\"type\": \"polar\"}] * num_col] * num_row,\n",
    "    subplot_titles=[model_id.split(\"/\")[1] for model_id in model_ids],\n",
    "    vertical_spacing=0.01,\n",
    "    horizontal_spacing=0.1,\n",
    ")\n",
    "\n",
    "\n",
    "colour_map = {\n",
    "    \"substring_matching\": plotly.colors.qualitative.Plotly[1],\n",
    "    \"llamaguard3\": plotly.colors.qualitative.Plotly[2],\n",
    "    \"harmbench\": plotly.colors.qualitative.Plotly[0],\n",
    "}\n",
    "\n",
    "method_colour_map = {\n",
    "    \"PID\": plotly.colors.qualitative.Plotly[0],\n",
    "    \"DIM\": plotly.colors.qualitative.Plotly[1],\n",
    "    \"PI\" : plotly.colors.qualitative.Plotly[2],\n",
    "    \"PD\" : plotly.colors.qualitative.Plotly[3],\n",
    "    \"RePE\": plotly.colors.qualitative.Plotly[4],\n",
    "    \"ITI\":plotly.colors.qualitative.Plotly[5],\n",
    "}\n",
    "\n",
    "categories = list(str(i) for i in range(0, 360, 10))\n",
    "categories.append(categories[0])\n",
    "\n",
    "for idx, model_id in enumerate(model_ids):\n",
    "    _, model_name = model_id.split(\"/\")\n",
    "    output_path = Path(data_path) / model_name\n",
    "\n",
    "    if adaptive:\n",
    "        glob_pattern = f\"*eval-mode_1-llamaguard3*.json\"\n",
    "    else:\n",
    "        glob_pattern = \"*eval-[!(mode)(perp)]*.json\"\n",
    "\n",
    "    for file in sorted(list(output_path.glob(glob_pattern))):\n",
    "        if \"PID_\" in file.stem.split(\"-\")[0]:\n",
    "            method = \"PID\"\n",
    "        elif \"PI_\" in file.stem.split(\"-\")[0] :\n",
    "            method = \"PI\"\n",
    "        elif \"PD_\" in  file.stem.split(\"-\")[0] :\n",
    "            method = \"PD\" \n",
    "        elif \"RePE_\" in file.stem.split(\"-\")[0] :\n",
    "            method = \"RePE\"\n",
    "        elif \"ITI_\" in file.stem.split(\"-\")[0] :\n",
    "            method = \"ITI\"\n",
    "        else:\n",
    "            method = \"DIM\"\n",
    "        \n",
    "        print(method)\n",
    "        if adaptive:\n",
    "            metric = file.stem.split(\"-\")[2]\n",
    "        else:\n",
    "            metric = file.stem.split(\"-\")[1]\n",
    "        print(metric)\n",
    "        if metric == \"llmjudge\":\n",
    "            continue\n",
    "        if metric not in colour_map:\n",
    "            colour_map[metric] = plotly.colors.qualitative.Plotly[len(colour_map)]\n",
    "\n",
    "        with open(file, \"r\") as f:\n",
    "            eval_data = json.load(f)\n",
    "        print(file)\n",
    "        baseline = eval_data[\"baseline\"]\n",
    "        if isinstance(baseline, list):\n",
    "            baseline = np.mean(baseline)\n",
    "        fig.add_trace(\n",
    "            go.Scatterpolar(\n",
    "                r=[\n",
    "                    (\n",
    "                        1 - baseline\n",
    "                        if metric in [\"substring_matching\", \"llmjudge\"]\n",
    "                        else baseline\n",
    "                    )\n",
    "                    for _ in range(len(categories))\n",
    "                ],\n",
    "                theta=categories,\n",
    "                name=\"baseline\",\n",
    "                line=dict(width=2, color=colour_map[metric], dash=\"dot\"),\n",
    "                # legendgroup=\"legend\"\n",
    "                mode=\"lines\",\n",
    "                opacity=0.5,\n",
    "                showlegend=False,\n",
    "            ),\n",
    "            row=idx // num_col + 1,\n",
    "            col=idx % num_col + 1,\n",
    "        )\n",
    "\n",
    "        print(eval_data.keys())\n",
    "        chosen_plane_id = [\n",
    "            s\n",
    "            for s in eval_data.keys()\n",
    "            if \"dir_random\" not in s and \"baseline\" not in s\n",
    "        ]\n",
    "        print(chosen_plane_id)\n",
    "        if not chosen_plane_id:\n",
    "            continue\n",
    "        print(\"hehe\")\n",
    "        print(model_id, chosen_plane_id)\n",
    "        chosen_plane_id = chosen_plane_id[0]\n",
    "        values = [eval_data[chosen_plane_id][cat] for cat in categories]\n",
    "\n",
    "        # if metric == \"llmjudge\":\n",
    "        #     values = [[0.5 if v == 0.75 else v for v in val] for val in values]\n",
    "        values = [np.mean(val) if isinstance(val, list) else val for val in values]\n",
    "        values.append(values[0])\n",
    "        fig.add_trace(\n",
    "            go.Scatterpolar(\n",
    "                r=(\n",
    "                    [1 - v for v in values]\n",
    "                    if metric in [\"substring_matching\", \"llmjudge\"]\n",
    "                    else values\n",
    "                ),\n",
    "                theta=categories,\n",
    "                legend=f\"legend{idx if idx > 0 else ''}\",\n",
    "                # fill=\"toself\",\n",
    "                name=method,\n",
    "                line=dict(width=2, color=method_colour_map[method]),\n",
    "                # legendgroup=\"legend\"\n",
    "                mode=\"lines\",\n",
    "                # opacity=0.5,\n",
    "                showlegend=idx == 0,\n",
    "            ),\n",
    "            row=idx // num_col + 1,\n",
    "            col=idx % num_col + 1,\n",
    "        )\n",
    "\n",
    "    fig.add_trace(\n",
    "        go.Scatterpolar(\n",
    "            r=[1.02],\n",
    "            # theta=[\"0\"],\n",
    "            name=\"feature direction\",\n",
    "            marker=dict(size=20, symbol=\"arrow-right\", color=\"black\"),\n",
    "            mode=\"markers\",\n",
    "            showlegend=idx == 0,\n",
    "        ),\n",
    "        row=idx // num_col + 1,\n",
    "        col=idx % num_col + 1,\n",
    "    )\n",
    "\n",
    "    # fig.add_trace(\n",
    "    #     go.Scatterpolar(\n",
    "    #         r=[1],\n",
    "    #         theta=['90'],\n",
    "    #         name=\"1st PC direction\",\n",
    "    #         marker=dict(size=10, symbol=2, color=plotly.colors.qualitative.Plotly[len(colour_map) + 1]),\n",
    "    #         mode=\"markers\",\n",
    "    #         showlegend=idx == 0,\n",
    "    #     ),\n",
    "    #     row=idx // num_col + 1,\n",
    "    #     col=idx % num_col + 1,\n",
    "    # )\n",
    "\n",
    "for i in range(len(model_ids) + 1):\n",
    "    polar_key = f'polar{i if i > 0 else \"\"}'\n",
    "    fig.update_layout(\n",
    "        {\n",
    "            polar_key: dict(\n",
    "                radialaxis=dict(\n",
    "                    visible=True,\n",
    "                    dtick=0.2,\n",
    "                    tickfont=dict(size=20),\n",
    "                ),\n",
    "                angularaxis=dict(\n",
    "                    # direction=\"clockwise\",\n",
    "                    # rotation=0,  # 0 degrees at the right (East)\n",
    "                    # period=360,\n",
    "                    # tickmode=\"array\",\n",
    "                    tickvals=list(\n",
    "                        range(0, 360, 10)\n",
    "                    ),  # Show all degrees in 10 intervals\n",
    "                    ticktext=[\n",
    "                        f\"{i}°\" for i in range(0, 360, 10)\n",
    "                    ],  # Show all tick labels\n",
    "                    tickfont=dict(size=18),  # Smaller font to fit all labels\n",
    "                    dtick=10,\n",
    "                ),\n",
    "            )\n",
    "        }\n",
    "    )\n",
    "\n",
    "\n",
    "# Global layout settings\n",
    "fig.update_layout(\n",
    "    height=2000,\n",
    "    width=1200,\n",
    "    # title_text=\"Angular Steering Effects on Model Performance\",\n",
    "    showlegend=True,\n",
    "    legend=dict(\n",
    "        orientation=\"h\",\n",
    "        # yanchor=\"top\",\n",
    "        y=0.0,\n",
    "        xanchor=\"center\",\n",
    "        x=0.5,\n",
    "        # entrywidth=0,\n",
    "        font=dict(size=30),\n",
    "    ),\n",
    "    margin=dict(l=50, r=50, t=20, b=0),\n",
    ")\n",
    "fig.update_annotations(font=dict(size=36), yshift=-20)\n",
    "\n",
    "fig.show()\n",
    "\n",
    "if adaptive:\n",
    "    output_name = f\"eval_adaptive-harmness-all_models-vertical\"\n",
    "else:\n",
    "    output_name = f\"eval-harmness-all_models-vertical\"\n",
    "\n",
    "fig.write_image(\n",
    "    visualization_dir / f\"PID_{output_name}.pdf\",\n",
    "    width=1200,\n",
    "    height=2000,\n",
    "    scale=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !plotly_get_chrome"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LLM-as-a-judge classification results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ids = [\n",
    "    \"Qwen/Qwen2.5-3B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-7B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-14B-Instruct\",\n",
    "    \"meta-llama/Llama-3.2-3B-Instruct\",\n",
    "    \"meta-llama/Llama-3.1-8B-Instruct\",\n",
    "    \"google/gemma-2-9b-it\",\n",
    "]\n",
    "\n",
    "language = \"en\"\n",
    "data_type = \"harmful\"\n",
    "num_row = 1\n",
    "num_col = len(model_ids) // num_row\n",
    "data_path = \"output/\"\n",
    "adaptive = True\n",
    "\n",
    "\n",
    "colour_map = {\n",
    "    \"direct\": \"red\",\n",
    "    \"indirect\": \"orange\",\n",
    "    \"redirect\": \"teal\",\n",
    "    \"refusal\": \"green\",\n",
    "}\n",
    "score2label = {\n",
    "    0: \"refusal\",\n",
    "    0.25: \"redirect\",\n",
    "    0.5: \"indirect\",\n",
    "    0.75: \"indirect\",\n",
    "    1: \"direct\",\n",
    "}\n",
    "\n",
    "fig = make_subplots(\n",
    "    rows=num_row,\n",
    "    cols=num_col,\n",
    "    specs=[[{\"type\": \"polar\"}] * num_col] * num_row,\n",
    "    subplot_titles=[model_id.split(\"/\")[1] for model_id in model_ids],\n",
    "    vertical_spacing=0.01,\n",
    "    horizontal_spacing=0.1,\n",
    ")\n",
    "\n",
    "for idx, model_id in enumerate(model_ids):\n",
    "    _, model_name = model_id.split(\"/\")\n",
    "    output_path = Path(data_path) / model_name\n",
    "\n",
    "    if adaptive:\n",
    "        file = list(output_path.glob(\"*eval-mode_1-llmjudge-*.json\"))[0]\n",
    "    else:\n",
    "        file = list(output_path.glob(\"*eval-llmjudge-*.json\"))[0]\n",
    "    print(file)\n",
    "    with open(file, \"r\") as f:\n",
    "        eval_data = json.load(f)\n",
    "\n",
    "    chosen_plane_id = [s for s in eval_data.keys() if DIR_ID_MAP[model_id] in s and \"dir_random\" not in s]\n",
    "    if len(chosen_plane_id) != 1:\n",
    "        print(f\"Skipping {model_id} due to {len(chosen_plane_id)} planes IDs found: {chosen_plane_id}\")\n",
    "        continue\n",
    "\n",
    "    chosen_plane_id = chosen_plane_id[0]\n",
    "\n",
    "\n",
    "    print(model_id, chosen_plane_id)\n",
    "    categories = list(str(i) for i in range(0, 360, 10))\n",
    "    categories.append(categories[0])\n",
    "    values = [eval_data[chosen_plane_id][cat] for cat in categories]\n",
    "\n",
    "    scores = [1, 0.75, 0.25, 0]\n",
    "    for v in scores:\n",
    "        r = [vals.count(v) / len(vals) for vals in values]\n",
    "        fig.add_trace(\n",
    "            go.Barpolar(\n",
    "                r=r,\n",
    "                # theta=list(str(i) for i in range(0, 360, 10)),\n",
    "                # marker_color=[colour_map[score2label[v]] for v in scores],\n",
    "                marker_color=colour_map[score2label[v]],\n",
    "                name=score2label[v],\n",
    "                showlegend=idx == 0,\n",
    "                marker=dict(\n",
    "                    line=dict(\n",
    "                        width=0.0,\n",
    "                        # color=\"rgba(0,0,0,0)\"\n",
    "                    )\n",
    "                ),\n",
    "            ),\n",
    "            row=idx // num_col + 1,\n",
    "            col=idx % num_col + 1,\n",
    "        )\n",
    "\n",
    "    fig.add_trace(\n",
    "        go.Scatterpolar(\n",
    "            r=[1.02],\n",
    "            theta=[\"0\"],\n",
    "            name=\"feature direction\",\n",
    "            marker=dict(size=20, symbol=\"arrow-right\", color=\"black\"),\n",
    "            mode=\"markers\",\n",
    "            showlegend=idx == 0,\n",
    "        ),\n",
    "        row=idx // num_col + 1,\n",
    "        col=idx % num_col + 1,\n",
    "    )\n",
    "\n",
    "\n",
    "polar_config = dict(\n",
    "    bargap=0,\n",
    "    hole=0,\n",
    "    angularaxis=dict(tickfont=dict(size=16), dtick=10),\n",
    "    radialaxis=dict(\n",
    "        showticklabels=False,\n",
    "        showline=False,\n",
    "    ),\n",
    ")\n",
    "fig.update_layout(\n",
    "    polar=polar_config,\n",
    "    polar2=polar_config,\n",
    "    polar3=polar_config,\n",
    "    polar4=polar_config,\n",
    "    polar5=polar_config,\n",
    "    polar6=polar_config,\n",
    "    showlegend=True,\n",
    "    legend=dict(\n",
    "        orientation=\"h\",\n",
    "        # yanchor=\"top\",\n",
    "        y=0.0,\n",
    "        xanchor=\"center\",\n",
    "        x=0.5,\n",
    "        font=dict(size=30),\n",
    "    ),\n",
    "    height=2000,\n",
    "    width=1200,\n",
    "    margin=dict(l=50, r=50, t=20, b=0),\n",
    ")\n",
    "\n",
    "fig.update_annotations(font=dict(size=36), yshift=-20)\n",
    "\n",
    "fig.show()\n",
    "\n",
    "fig.write_image(\n",
    "    visualization_dir / \"eval_adaptive-llmjudge-all_models-vertical.pdf\",\n",
    "    width=1200,\n",
    "    height=2000,\n",
    "    scale=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adaptive Angular Steering on tinyBenchmark\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "import plotly.express as px\n",
    "from plotly.subplots import make_subplots\n",
    "import re\n",
    "from glob import glob\n",
    "from pathlib import Path\n",
    "from collections import defaultdict\n",
    "\n",
    "# Base directory for benchmarks\n",
    "base_dir = f\"benchmarks_old/{dir_id}/\"\n",
    "# base_dir = f\"benchmarks_old/{dir_id}/\"\n",
    "\n",
    "# Dictionary to map tasks to their primary metric(s)\n",
    "task_to_metrics = {\n",
    "    \"tinyArc\": [\"acc_norm,none\"],\n",
    "    \"tinyGSM8k\": [\n",
    "        \"exact_match,strict-match\",\n",
    "        \"exact_match,flexible-extract\",\n",
    "    ],\n",
    "    \"tinyHellaswag\": [\"acc_norm,none\"],\n",
    "    \"tinyMMLU\": [\"acc_norm,none\"],\n",
    "    \"tinyTruthfulQA\": [\"acc,none\"],\n",
    "    \"tinyWinogrande\": [\"acc_norm,none\"],\n",
    "}\n",
    "\n",
    "\n",
    "models = [\n",
    "    \"Qwen2.5-3B-Instruct\",\n",
    "    \"Qwen2.5-7B-Instruct\",\n",
    "    \"Qwen2.5-14B-Instruct\",\n",
    "    \"Llama-3.2-3B-Instruct\",\n",
    "    \"Llama-3.1-8B-Instruct\",\n",
    "    \"gemma-2-9b-it\",\n",
    "]\n",
    "\n",
    "# Collect data for all models and tasks\n",
    "results = defaultdict(lambda: defaultdict(dict))\n",
    "baselines = defaultdict(dict)\n",
    "\n",
    "# Debug tracking\n",
    "found_metrics = defaultdict(set)\n",
    "\n",
    "for task_dir in os.listdir(base_dir):\n",
    "    task_path = os.path.join(base_dir, task_dir)\n",
    "    task_path = os.path.join(task_path, \"\")\n",
    "    if not os.path.isdir(task_path):\n",
    "        continue\n",
    "\n",
    "    for model_dir in os.listdir(task_path):\n",
    "        if model_dir not in models:\n",
    "            continue\n",
    "\n",
    "        model_path = os.path.join(task_path, model_dir)\n",
    "        if not os.path.isdir(model_path):\n",
    "            continue\n",
    "\n",
    "        # Get metrics for this task\n",
    "        metrics = task_to_metrics.get(task_dir, [])\n",
    "        if not metrics:\n",
    "            continue\n",
    "\n",
    "        # Find all results files for this model and task\n",
    "        for degree in list(range(0, 360, 10)) + [\"none\"]:\n",
    "            pattern = os.path.join(\n",
    "                model_path, f\"adaptive_{degree}\", \"*\", \"results_*.json\"\n",
    "            )\n",
    "            result_files = glob(pattern)\n",
    "\n",
    "            if not result_files:\n",
    "                continue\n",
    "\n",
    "            # Since the result files are named by date, sort them to get the latest\n",
    "            result_fiile = sorted(result_files)\n",
    "            # Use the last result file found for each degree\n",
    "            result_file = result_files[-1]\n",
    "\n",
    "            try:\n",
    "                with open(result_file, \"r\") as f:\n",
    "                    data = json.load(f)\n",
    "\n",
    "                # For each metric defined for this task\n",
    "                for metric in metrics:\n",
    "                    # Create a display name for the metric\n",
    "                    if task_dir == \"tinyGSM8k\":\n",
    "                        # For GSM8k, include the metric type in the display name\n",
    "                        # Extract strict-match or loose-match\n",
    "                        metric_type = (metric.split(\",\")[1]).split(\"-\")[0]\n",
    "                        display_name = f\"{task_dir} ({metric_type})\"\n",
    "                    else:\n",
    "                        display_name = task_dir\n",
    "\n",
    "                    if (\n",
    "                        task_dir in data.get(\"results\", {})\n",
    "                        and metric in data[\"results\"][task_dir]\n",
    "                    ):\n",
    "                        score = data[\"results\"][task_dir][metric]\n",
    "                        found_metrics[task_dir].add(\n",
    "                            metric\n",
    "                        )  # Track which metrics were found\n",
    "\n",
    "                        if degree == \"none\":\n",
    "                            baselines[model_dir][display_name] = score\n",
    "                        else:\n",
    "                            results[model_dir][display_name][int(degree)] = score\n",
    "            except Exception as e:\n",
    "                print(f\"Error processing {result_file}: {e}\")\n",
    "\n",
    "# Print metrics found for debugging\n",
    "print(\"Found metrics by task:\")\n",
    "for task, metrics in found_metrics.items():\n",
    "    print(f\"{task}: {metrics}\")\n",
    "\n",
    "# Create subplots - arrange based on number of models\n",
    "n_models = len(models)\n",
    "n_cols = min(2, n_models)\n",
    "n_rows = (n_models + n_cols - 1) // n_cols\n",
    "# n_cols = 1\n",
    "# n_rows = n_models\n",
    "\n",
    "fig = make_subplots(\n",
    "    rows=n_rows,\n",
    "    cols=n_cols,\n",
    "    specs=[[{\"type\": \"polar\"} for _ in range(n_cols)] for _ in range(n_rows)],\n",
    "    subplot_titles=models,\n",
    "    vertical_spacing=0.01,\n",
    "    horizontal_spacing=0.1,\n",
    ")\n",
    "\n",
    "# Add reference baseline trace only once (for legend)\n",
    "fig.add_trace(\n",
    "    go.Scatterpolar(\n",
    "        r=[0],\n",
    "        theta=[0],\n",
    "        name=\"baseline\",\n",
    "        line=dict(width=2, color=\"black\", dash=\"dot\"),\n",
    "        mode=\"lines\",\n",
    "        opacity=0.5,\n",
    "        showlegend=True,\n",
    "    ),\n",
    "    row=1,\n",
    "    col=1,\n",
    ")\n",
    "\n",
    "# Use Plotly's built-in color palette\n",
    "colors = px.colors.qualitative.Plotly\n",
    "\n",
    "# Plot data for each model\n",
    "for model_idx, model_name in enumerate(models):\n",
    "    row = model_idx // n_cols + 1\n",
    "    col = model_idx % n_cols + 1\n",
    "\n",
    "    if model_name not in results:\n",
    "        continue\n",
    "\n",
    "    # Add reference arrow at 0 degrees\n",
    "    fig.add_trace(\n",
    "        go.Scatterpolar(\n",
    "            r=[1.02],\n",
    "            theta=[0],\n",
    "            name=\"feature direction\",\n",
    "            marker=dict(size=20, symbol=\"arrow-right\", color=\"black\"),\n",
    "            mode=\"markers\",\n",
    "            showlegend=model_idx == 0,  # Only show in legend for first model\n",
    "        ),\n",
    "        row=row,\n",
    "        col=col,\n",
    "    )\n",
    "\n",
    "    # Track which tasks we've displayed for this model\n",
    "    task_color_index = 0\n",
    "\n",
    "    # Sort the tasks for consistent color assignment\n",
    "    sorted_tasks = sorted(results[model_name].keys())\n",
    "    \n",
    "    for task_display_name in sorted_tasks:\n",
    "        task_color = colors[task_color_index % len(colors)]\n",
    "        task_color_index += 1\n",
    "\n",
    "        # Prepare angle and score data\n",
    "        angles = sorted(results[model_name][task_display_name].keys())\n",
    "        scores = [results[model_name][task_display_name][angle] for angle in angles]\n",
    "\n",
    "        # Close the loop for a complete polar plot\n",
    "        angles_closed = np.append(angles, angles[0])\n",
    "        scores_closed = np.append(scores, scores[0])\n",
    "\n",
    "        # Plot the task scores with solid lines\n",
    "        fig.add_trace(\n",
    "            go.Scatterpolar(\n",
    "                r=scores_closed,\n",
    "                theta=angles_closed,\n",
    "                name=task_display_name,\n",
    "                line=dict(width=2, color=task_color),\n",
    "                mode=\"lines\",\n",
    "                showlegend=model_idx == 0,  # Only show in legend for first model\n",
    "            ),\n",
    "            row=row,\n",
    "            col=col,\n",
    "        )\n",
    "        \n",
    "        # Add baseline as dashed circular line if available\n",
    "        if model_name in baselines and task_display_name in baselines[model_name]:\n",
    "            baseline_score = baselines[model_name][task_display_name]\n",
    "            baseline_theta = np.linspace(0, 360, 361)\n",
    "\n",
    "            # Use dash line for baselines but don't add to legend\n",
    "            fig.add_trace(\n",
    "                go.Scatterpolar(\n",
    "                    r=[baseline_score] * len(baseline_theta),\n",
    "                    theta=baseline_theta,\n",
    "                    line=dict(width=1.5, color=task_color, dash=\"dot\"),\n",
    "                    mode=\"lines\",\n",
    "                    opacity=0.5,\n",
    "                    showlegend=False,  # Don't add individual baselines to legend\n",
    "                ),\n",
    "                row=row,\n",
    "                col=col,\n",
    "            )\n",
    "    print([(model_name, task, np.min(list(results[model_name][task].values()))) for task in results[model_name].keys() for task in sorted_tasks])\n",
    "    # print(model_name,baselines[model_name])\n",
    "\n",
    "# Update layout for each subplot\n",
    "for i in range(1, n_models + 1):\n",
    "    row = (i - 1) // n_cols + 1\n",
    "    col = (i - 1) % n_cols + 1\n",
    "\n",
    "    # Get the corresponding polar subplot key\n",
    "    polar_key = f'polar{i if i > 1 else \"\"}'\n",
    "\n",
    "    # Set consistent range and angle markings across plots\n",
    "    fig.update_layout(\n",
    "        {\n",
    "            polar_key: dict(\n",
    "                radialaxis=dict(\n",
    "                    visible=True,\n",
    "                    range=[0.0, 1.02],\n",
    "                    dtick=0.2,\n",
    "                    # tickvals=[0.3, 0.5, 0.7, 0.9, 1],\n",
    "                    tickfont=dict(size=22),\n",
    "                ),\n",
    "                angularaxis=dict(\n",
    "                    # direction=\"clockwise\",\n",
    "                    # rotation=0,  # 0 degrees at the right (East)\n",
    "                    # period=360,\n",
    "                    # tickmode=\"array\",\n",
    "                    # tickvals=list(\n",
    "                    #     range(0, 360, 10)\n",
    "                    # ),  # Show all degrees in 10 intervals\n",
    "                    # ticktext=[\n",
    "                    #     f\"{i}°\" for i in range(0, 360, 10)\n",
    "                    # ],  # Show all tick labels\n",
    "                    tickfont=dict(size=16),  # Smaller font to fit all labels\n",
    "                    dtick=10,\n",
    "                ),\n",
    "            )\n",
    "        }\n",
    "    )\n",
    "\n",
    "# Global layout settings\n",
    "fig.update_layout(\n",
    "    height=2000,\n",
    "    width=1200,\n",
    "    # title_text=\"Angular Steering Effects on Model Performance\",\n",
    "    showlegend=True,\n",
    "    legend=dict(\n",
    "        orientation=\"h\",\n",
    "        # yanchor=\"top\",\n",
    "        y=0.0,\n",
    "        xanchor=\"center\",\n",
    "        x=0.5,\n",
    "        # entrywidth=0,\n",
    "        font=dict(size=30),\n",
    "    ),\n",
    "    margin=dict(l=50, r=50, t=20, b=0),\n",
    ")\n",
    "fig.update_annotations(font=dict(size=36), yshift=-20)\n",
    "\n",
    "fig.show()\n",
    "\n",
    "fig.write_image(\n",
    "    visualization_dir / \"PID_eval_adaptive-tinyBenchmark-all_models-vertical.pdf\",\n",
    "    width=1200,\n",
    "    height=2000,\n",
    "    scale=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perplexity scores\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import plotly.graph_objects as go\n",
    "import plotly.express as px\n",
    "from pathlib import Path\n",
    "import glob\n",
    "\n",
    "\n",
    "def filter_perplexity(perplexity_data, threshold=100):\n",
    "    \"\"\"\n",
    "    Filter perplexity data to remove outliers.\n",
    "    E.g. in `qwen2.5-3B/rotated/max_sim-pca/290/34` there is a perplexity of 2.8e6\n",
    "    because the generation is empty\n",
    "    \"\"\"\n",
    "    return [p for p in perplexity_data if p < threshold]\n",
    "\n",
    "\n",
    "MODELS = [\n",
    "    \"Qwen/Qwen2.5-3B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-7B-Instruct\",\n",
    "    \"Qwen/Qwen2.5-14B-Instruct\",\n",
    "    \"meta-llama/Llama-3.2-3B-Instruct\",\n",
    "    \"meta-llama/Llama-3.1-8B-Instruct\",\n",
    "    \"google/gemma-2-9b-it\",\n",
    "]\n",
    "\n",
    "# Base output directory\n",
    "output_dir = Path(\"output\")\n",
    "\n",
    "generation_only = False\n",
    "\n",
    "# Find all perplexity evaluation files\n",
    "if generation_only:\n",
    "    perplexity_files = glob.glob(\n",
    "        str(output_dir / \"*\" / \"eval-perplexity-generation_only.json\")\n",
    "    )\n",
    "else:\n",
    "    perplexity_files = glob.glob(str(output_dir / \"*\" / \"eval-perplexity.json\"))\n",
    "\n",
    "# Create figure with subplots based on number of models\n",
    "n_models = len(MODELS)\n",
    "n_cols = min(2, n_models)\n",
    "n_rows = (n_models + n_cols - 1) // n_cols\n",
    "# n_cols = 1\n",
    "# n_rows = n_models\n",
    "\n",
    "# Colors for different modes\n",
    "colors = px.colors.qualitative.Plotly\n",
    "\n",
    "# Calculate model-specific max perplexity values\n",
    "all_perplexity_data = {}\n",
    "\n",
    "for perplexity_file in perplexity_files:\n",
    "    model_name = Path(perplexity_file).parent.name\n",
    "    with open(perplexity_file, \"r\") as f:\n",
    "        data = json.load(f)\n",
    "        all_perplexity_data[model_name] = data\n",
    "\n",
    "        # Calculate max perplexity for this specific model\n",
    "        model_max = 0\n",
    "\n",
    "        # Check baseline max\n",
    "        if \"harmful-en-baseline\" in data:\n",
    "            model_max = max(model_max, max(data[\"harmful-en-baseline\"]))\n",
    "\n",
    "if n_models > 1:\n",
    "    from plotly.subplots import make_subplots\n",
    "\n",
    "    fig = make_subplots(\n",
    "        rows=n_rows,\n",
    "        cols=n_cols,\n",
    "        specs=[[{\"type\": \"polar\"} for _ in range(n_cols)] for _ in range(n_rows)],\n",
    "        subplot_titles=[m.split('/')[-1] for m in MODELS],\n",
    "        vertical_spacing=0.01,\n",
    "        horizontal_spacing=0.1,\n",
    "    )\n",
    "else:\n",
    "    fig = go.Figure()\n",
    "\n",
    "# Add reference baseline trace only once for legend\n",
    "fig.add_trace(\n",
    "    go.Scatterpolar(\n",
    "        r=[0],\n",
    "        theta=[0],\n",
    "        name=\"no steering\",\n",
    "        line=dict(width=2, color=\"black\", dash=\"dot\"),\n",
    "        mode=\"lines\",\n",
    "        opacity=0.5,\n",
    "        showlegend=True,\n",
    "    ),\n",
    "    row=1 if n_models > 1 else None,\n",
    "    col=1 if n_models > 1 else None,\n",
    ")\n",
    "\n",
    "# Process each perplexity file\n",
    "for model_idx, model_id in enumerate(MODELS):\n",
    "    chosen_direction = DIR_ID_MAP[model_id]\n",
    "    print(model_id)\n",
    "\n",
    "    model_name = model_id.split(\"/\")[-1]\n",
    "    perplexity_data = all_perplexity_data[model_name]\n",
    "\n",
    "    # Subplot cell of this model; None when the figure is not a grid.\n",
    "    row = model_idx // n_cols + 1 if n_models > 1 else None\n",
    "    col = model_idx % n_cols + 1 if n_models > 1 else None\n",
    "\n",
    "    # Steered runs for the chosen direction: keys containing \"rotated\" are\n",
    "    # labelled non-adaptive, every other matching key adaptive.\n",
    "    modes = []\n",
    "    for key in perplexity_data.keys():\n",
    "        if chosen_direction not in key:\n",
    "            continue\n",
    "        print(key)\n",
    "        if key.startswith(\"harmful-en-dir\") and not key.endswith(\"baseline\"):\n",
    "            mode = \"non-adaptive\" if \"rotated\" in key else \"adaptive\"\n",
    "            modes.append((key, mode))\n",
    "    modes.sort()\n",
    "\n",
    "    # Largest finite mean perplexity over every steered curve of this model.\n",
    "    # It used to be read from whichever curve was processed last, which\n",
    "    # ignored the other curves and failed (or reused the previous model's\n",
    "    # values) when a model had no steered data.\n",
    "    max_r = None\n",
    "\n",
    "    # Process each mode\n",
    "    for mode_idx, (mode_key, mode_name) in enumerate(modes):\n",
    "        angles = []\n",
    "        mean_perplexities = []\n",
    "\n",
    "        # Mean filtered perplexity per steering angle; non-numeric keys and\n",
    "        # empty sample lists are skipped.\n",
    "        for angle_str, perplexities in perplexity_data[mode_key].items():\n",
    "            if angle_str.isdigit() and perplexities:\n",
    "                perplexities = filter_perplexity(perplexities)\n",
    "                angles.append(int(angle_str))\n",
    "                mean_perplexities.append(np.mean(perplexities))\n",
    "\n",
    "        # Sort by angle\n",
    "        sorted_indices = np.argsort(angles)\n",
    "        angles = np.array(angles)[sorted_indices]\n",
    "        mean_perplexities = np.array(mean_perplexities, dtype=float)[sorted_indices]\n",
    "\n",
    "        finite_means = mean_perplexities[np.isfinite(mean_perplexities)]\n",
    "        if finite_means.size > 0:\n",
    "            curve_max = float(finite_means.max())\n",
    "            max_r = curve_max if max_r is None else max(max_r, curve_max)\n",
    "\n",
    "        # Add one more point to close the loop\n",
    "        if len(angles) > 1:\n",
    "            angles_closed = np.append(angles, angles[0])\n",
    "            perplexities_closed = np.append(mean_perplexities, mean_perplexities[0])\n",
    "\n",
    "            # Add to plot\n",
    "            fig.add_trace(\n",
    "                go.Scatterpolar(\n",
    "                    r=perplexities_closed,\n",
    "                    theta=angles_closed,\n",
    "                    name=f\"{mode_name}\",\n",
    "                    line=dict(width=2, color=colors[mode_idx % len(colors)]),\n",
    "                    mode=\"lines\",\n",
    "                    showlegend=model_idx == 0,  # Only show in legend for first model\n",
    "                ),\n",
    "                row=row,\n",
    "                col=col,\n",
    "            )\n",
    "\n",
    "    # Add the feature-direction marker at 0 degrees, just outside the\n",
    "    # outermost steered curve\n",
    "    if max_r is None:\n",
    "        print(f\"No steered perplexity data for {model_name}; skipping direction marker\")\n",
    "    else:\n",
    "        fig.add_trace(\n",
    "            go.Scatterpolar(\n",
    "                r=[max_r * 1.02],\n",
    "                theta=[0],\n",
    "                name=\"feature direction\",\n",
    "                marker=dict(size=20, symbol=\"arrow-right\", color=\"black\"),\n",
    "                mode=\"markers\",\n",
    "                showlegend=model_idx == 0,  # Only show in legend for first model\n",
    "                legend=\"legend\",\n",
    "            ),\n",
    "            row=row,\n",
    "            col=col,\n",
    "        )\n",
    "\n",
    "    # Add baseline if available\n",
    "    if \"harmful-en-baseline\" in perplexity_data:\n",
    "        baseline_perplexity = np.mean(perplexity_data[\"harmful-en-baseline\"])\n",
    "        max_r = baseline_perplexity if max_r is None else max(max_r, baseline_perplexity)\n",
    "        baseline_theta = np.linspace(0, 360, 100)\n",
    "\n",
    "        # Constant-radius dotted circle at the unsteered mean perplexity\n",
    "        fig.add_trace(\n",
    "            go.Scatterpolar(\n",
    "                r=[baseline_perplexity] * len(baseline_theta),\n",
    "                theta=baseline_theta,\n",
    "                name=\"Baseline\",\n",
    "                line=dict(width=1.5, color=\"black\", dash=\"dot\"),\n",
    "                mode=\"lines\",\n",
    "                opacity=0.5,\n",
    "                showlegend=False,  # Don't add individual baselines to legend\n",
    "                legend=\"legend\",\n",
    "            ),\n",
    "            row=row,\n",
    "            col=col,\n",
    "        )\n",
    "\n",
    "\n",
    "# Update layout with model-specific ranges\n",
    "# Angular axis shared by every polar panel: one tick label every 10 degrees.\n",
    "# Built once instead of being repeated in both branches.\n",
    "tick_degrees = list(range(0, 360, 10))\n",
    "angular_axis = dict(\n",
    "    tickvals=tick_degrees,\n",
    "    ticktext=[f\"{deg}°\" for deg in tick_degrees],\n",
    "    tickfont=dict(size=16),\n",
    ")\n",
    "\n",
    "if n_models > 1:\n",
    "    # Layout keys for the panels are \"polar\", \"polar2\", \"polar3\", ...\n",
    "    for panel in range(1, n_models + 1):\n",
    "        polar_key = \"polar\" if panel == 1 else f\"polar{panel}\"\n",
    "        fig.update_layout(\n",
    "            {\n",
    "                polar_key: dict(\n",
    "                    radialaxis=dict(visible=True, tickfont=dict(size=22)),\n",
    "                    angularaxis=angular_axis,\n",
    "                )\n",
    "            }\n",
    "        )\n",
    "else:\n",
    "    fig.update_layout(\n",
    "        polar=dict(\n",
    "            radialaxis=dict(visible=True, title=\"Average Perplexity\"),\n",
    "            angularaxis=angular_axis,\n",
    "        ),\n",
    "    )\n",
    "\n",
    "# Overall size, a horizontal legend centred at the bottom, and outer margins.\n",
    "bottom_legend = dict(\n",
    "    orientation=\"h\",\n",
    "    y=0.0,\n",
    "    xanchor=\"center\",\n",
    "    x=0.5,\n",
    "    font=dict(size=30),\n",
    "    entrywidthmode=\"pixels\",\n",
    ")\n",
    "fig.update_layout(\n",
    "    height=2000,\n",
    "    width=1200,\n",
    "    showlegend=True,\n",
    "    legend=bottom_legend,\n",
    "    margin=dict(l=50, r=50, t=20, b=140),\n",
    ")\n",
    "# Applies to every annotation on the figure (subplot titles are annotations).\n",
    "fig.update_annotations(font=dict(size=36), yshift=-20)\n",
    "\n",
    "fig.show()\n",
    "\n",
    "# Export as PDF; the file name reflects the generation_only flag.\n",
    "output_name = \"eval-ppl-generation_only\" if generation_only else \"eval-ppl\"\n",
    "fig.write_image(\n",
    "    visualization_dir / f\"{output_name}-vertical.pdf\",\n",
    "    height=2000,\n",
    "    width=1200,\n",
    "    scale=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Geometric interpretation of Angular Steering\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n",
    "\n",
    "# Define base vectors\n",
    "# NOTE(review): despite its name, d_1stPC_3d is the vector this cell draws as\n",
    "# the activation h (it is passed as h_vec below and h_norm is derived from it).\n",
    "d_1stPC_3d = np.array([1.5, 0.0, 0.5])\n",
    "d_3d = np.array([0.5, 1.0, 0.0])\n",
    "# Feature direction, scaled to unit length\n",
    "d_3d = d_3d / np.linalg.norm(d_3d)\n",
    "\n",
    "alpha = 1.0\n",
    "# Activation addition: shift by alpha along the feature direction\n",
    "h_add_3d = d_1stPC_3d + alpha * d_3d\n",
    "# Directional ablation: subtract the projection onto the unit vector d_3d\n",
    "d_ortho_3d = d_1stPC_3d - np.dot(d_1stPC_3d, d_3d) * d_3d\n",
    "\n",
    "# Unit-length versions of the three activation variants\n",
    "h_norm = d_1stPC_3d / np.linalg.norm(d_1stPC_3d)\n",
    "h_add_norm = h_add_3d / np.linalg.norm(h_add_3d)\n",
    "h_ablate_norm = d_ortho_3d / np.linalg.norm(d_ortho_3d)\n",
    "\n",
    "# Plane that spans h and d: every corner is a linear combination of the two\n",
    "# stretched vectors, so the quadrilateral lies in their span\n",
    "extended_h = 1.2 * d_1stPC_3d\n",
    "extended_d = 1.2 * d_3d\n",
    "# origin is not referenced again in this cell\n",
    "origin = np.array([0, 0, 0])\n",
    "corner1 = 1.2 * extended_h + 1.4 * extended_d\n",
    "corner2 = 1.2 * extended_h - 0.8 * extended_d\n",
    "corner3 = 0.5 * -extended_h - 0.8 * extended_d\n",
    "corner4 = 0.5 * -extended_h + 1.4 * extended_d\n",
    "plane_vertices = [corner1, corner2, corner3, corner4]\n",
    "\n",
    "# Create figure and axes: two 3D panels side by side\n",
    "fig = plt.figure(figsize=(14, 6))\n",
    "ax1 = fig.add_subplot(121, projection=\"3d\")\n",
    "ax2 = fig.add_subplot(122, projection=\"3d\")\n",
    "\n",
    "\n",
    "# Right angle marker\n",
    "def draw_right_angle_marker(ax, vec1, vec2, scale=0.2):\n",
    "    \"\"\"Draw a small corner marker near the origin between two directions.\n",
    "\n",
    "    Two gray segments of length ``scale`` are plotted: one from the point at\n",
    "    ``scale`` along ``vec1`` in the direction of ``vec2``, and one from there\n",
    "    back against ``vec1``. Only the directions of the inputs matter.\n",
    "\n",
    "    Args:\n",
    "        ax: Axes-like object exposing ``plot(xs, ys, zs, ...)``.\n",
    "        vec1: First 3-vector (non-zero).\n",
    "        vec2: Second 3-vector (non-zero).\n",
    "        scale: Side length of the marker.\n",
    "    \"\"\"\n",
    "    unit_a = vec1 / np.linalg.norm(vec1)\n",
    "    unit_b = vec2 / np.linalg.norm(vec2)\n",
    "    on_first = np.array([0, 0, 0]) + scale * unit_a\n",
    "    far_corner = on_first + scale * unit_b\n",
    "    on_second = far_corner - scale * unit_a\n",
    "    for start, end in ((on_first, far_corner), (far_corner, on_second)):\n",
    "        ax.plot(\n",
    "            [start[0], end[0]],\n",
    "            [start[1], end[1]],\n",
    "            [start[2], end[2]],\n",
    "            color=\"gray\",\n",
    "        )\n",
    "\n",
    "\n",
    "# Rotation arc\n",
    "def draw_arc(ax, v_from, v_to, color):\n",
    "    \"\"\"Draw a dashed unit-radius arc around the origin from ``v_from`` to ``v_to``.\n",
    "\n",
    "    Both inputs are normalised first, so only their directions matter. The\n",
    "    arc lies in the plane of the two vectors and sweeps the angle between them.\n",
    "\n",
    "    Args:\n",
    "        ax: Axes-like object exposing ``plot(xs, ys, zs, ...)``.\n",
    "        v_from: 3-vector giving the start direction (non-zero).\n",
    "        v_to: 3-vector giving the end direction (non-zero).\n",
    "        color: Colour forwarded to ``ax.plot``.\n",
    "\n",
    "    Parallel or anti-parallel inputs do not define a rotation plane; in that\n",
    "    case nothing is drawn instead of plotting an all-NaN curve.\n",
    "    \"\"\"\n",
    "    v_from = v_from / np.linalg.norm(v_from)\n",
    "    v_to = v_to / np.linalg.norm(v_to)\n",
    "    axis = np.cross(v_from, v_to)\n",
    "    axis_norm = np.linalg.norm(axis)\n",
    "    if axis_norm < 1e-12:\n",
    "        return\n",
    "    axis = axis / axis_norm\n",
    "    # Unit vector perpendicular to v_from, inside the plane, on the v_to side\n",
    "    in_plane_perp = np.cross(axis, v_from)\n",
    "    theta = np.linspace(0, np.arccos(np.clip(np.dot(v_from, v_to), -1, 1)), 50)\n",
    "    arc = np.cos(theta)[:, None] * v_from + np.sin(theta)[:, None] * in_plane_perp\n",
    "    ax.plot(arc[:, 0], arc[:, 1], arc[:, 2], color=color, linestyle=\"--\")\n",
    "\n",
    "\n",
    "def draw_vector_arrow(ax, vec, color, label):\n",
    "    \"\"\"Draw ``vec`` from the origin as a labelled line with a two-stroke arrowhead.\n",
    "\n",
    "    Args:\n",
    "        ax: Axes-like object exposing ``plot(xs, ys, zs, ...)``.\n",
    "        vec: Non-zero 3-vector; the arrow tip is placed at ``vec``.\n",
    "        color: Colour used for the shaft and both head strokes.\n",
    "        label: Legend label attached to the shaft only.\n",
    "    \"\"\"\n",
    "    ax.plot([0, vec[0]], [0, vec[1]], [0, vec[2]], color=color, label=label)\n",
    "\n",
    "    unit = vec / np.linalg.norm(vec)\n",
    "    # Sideways offset for the head strokes: perpendicular to the shaft and the\n",
    "    # z axis, or to the y axis when the shaft is (nearly) parallel to z.\n",
    "    sideways = np.cross(unit, np.array([0, 0, 1]))\n",
    "    if np.linalg.norm(sideways) < 1e-6:\n",
    "        sideways = np.cross(unit, np.array([0, 1, 0]))\n",
    "    sideways = sideways / np.linalg.norm(sideways) * 0.05\n",
    "\n",
    "    # Both strokes start at the tip and end 0.1 back along the shaft.\n",
    "    head_base = vec - 0.1 * unit\n",
    "    for barb_end in (head_base + sideways, head_base - sideways):\n",
    "        ax.plot(\n",
    "            [vec[0], barb_end[0]],\n",
    "            [vec[1], barb_end[1]],\n",
    "            [vec[2], barb_end[2]],\n",
    "            color=color,\n",
    "        )\n",
    "\n",
    "\n",
    "# General plotting function\n",
    "def plot_vectors(\n",
    "    ax,\n",
    "    h_vec,\n",
    "    h_add_vec,\n",
    "    h_ablate_vec,\n",
    "    title,\n",
    "    show_addition_lines=False,\n",
    "    show_right_angle=False,\n",
    "    show_rotation_arc=True,\n",
    "    show_legend=True,\n",
    "):\n",
    "    \"\"\"Render one panel of the steering-geometry figure on ``ax``.\n",
    "\n",
    "    Draws four arrows (activation, feature direction, activation addition,\n",
    "    directional ablation) over a translucent plane, then styles the axes. The\n",
    "    feature direction ``d_3d`` and the plane ``plane_vertices`` are read from\n",
    "    module-level variables.\n",
    "\n",
    "    Args:\n",
    "        ax: 3D axes to draw on.\n",
    "        h_vec: Vector drawn as the activation.\n",
    "        h_add_vec: Vector drawn as the activation-addition result.\n",
    "        h_ablate_vec: Vector drawn as the directional-ablation result.\n",
    "        title: Panel title.\n",
    "        show_addition_lines: Draw dashed lines from ``h_vec`` and ``d_3d`` to\n",
    "            ``h_add_vec``.\n",
    "        show_right_angle: Draw a corner marker between ``d_3d`` and\n",
    "            ``h_ablate_vec``.\n",
    "        show_rotation_arc: Draw a dashed arc from ``d_3d`` to ``h_ablate_vec``.\n",
    "        show_legend: Show the legend.\n",
    "    \"\"\"\n",
    "    palette = plotly.colors.qualitative.Plotly\n",
    "    arrows = (\n",
    "        (h_vec, palette[0], \"$\\\\mathbf{h}$ (activation)\"),\n",
    "        (d_3d, palette[1], \"$\\\\mathbf{d}_\\\\text{feature}$ (feature direction)\"),\n",
    "        (\n",
    "            h_add_vec,\n",
    "            palette[3],\n",
    "            \"$\\\\mathbf{h} + \\\\alpha \\\\mathbf{d}_\\\\text{feature}$ (activation addition,\"\n",
    "            \" $\\\\alpha=1$)\",\n",
    "        ),\n",
    "        (h_ablate_vec, palette[2], \"$\\\\mathbf{h}_\\\\perp$ (directional ablation)\"),\n",
    "    )\n",
    "    for vec, arrow_color, arrow_label in arrows:\n",
    "        draw_vector_arrow(ax, vec, color=arrow_color, label=arrow_label)\n",
    "\n",
    "    ax.add_collection3d(Poly3DCollection([plane_vertices], color=\"gray\", alpha=0.1))\n",
    "\n",
    "    if show_addition_lines:\n",
    "        # Dashed edges from the tips of h and d to the tip of their sum\n",
    "        for start in (h_vec, d_3d):\n",
    "            ax.plot(\n",
    "                [start[0], h_add_vec[0]],\n",
    "                [start[1], h_add_vec[1]],\n",
    "                [start[2], h_add_vec[2]],\n",
    "                linestyle=\"dashed\",\n",
    "                color=\"gray\",\n",
    "                linewidth=1,\n",
    "            )\n",
    "\n",
    "    if show_right_angle:\n",
    "        draw_right_angle_marker(ax, d_3d, h_ablate_vec)\n",
    "\n",
    "    if show_rotation_arc:\n",
    "        draw_arc(ax, d_3d, h_ablate_vec, (0, 0, 1, 0.3))\n",
    "\n",
    "    # Axes styling: fixed cube, fixed camera, no tick labels, white panes\n",
    "    ax.set_title(title, fontsize=20)\n",
    "    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):\n",
    "        set_limits([-1, 2])\n",
    "    ax.view_init(elev=30, azim=135)\n",
    "    ax.grid(True)\n",
    "    for clear_tick_labels in (ax.set_xticklabels, ax.set_yticklabels, ax.set_zticklabels):\n",
    "        clear_tick_labels([])\n",
    "    for axis3d in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
    "        axis3d.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
    "\n",
    "    if show_legend:\n",
    "        ax.legend(fontsize=18, loc=\"upper left\", bbox_to_anchor=(0, 0.3))\n",
    "\n",
    "\n",
    "# One entry per panel: axes, the three vectors, the title and the overlays.\n",
    "# Left shows the raw vectors, right the unit-length versions.\n",
    "panels = (\n",
    "    (\n",
    "        ax1,\n",
    "        (d_1stPC_3d, h_add_3d, d_ortho_3d),\n",
    "        \"Before Normalization\",\n",
    "        dict(\n",
    "            show_addition_lines=True,\n",
    "            show_right_angle=True,\n",
    "            show_rotation_arc=False,\n",
    "            show_legend=False,\n",
    "        ),\n",
    "    ),\n",
    "    (\n",
    "        ax2,\n",
    "        (h_norm, h_add_norm, h_ablate_norm),\n",
    "        \"After Normalization\",\n",
    "        dict(\n",
    "            show_addition_lines=False,\n",
    "            show_right_angle=True,\n",
    "            show_rotation_arc=True,\n",
    "            show_legend=True,\n",
    "        ),\n",
    "    ),\n",
    ")\n",
    "for panel_ax, panel_vectors, panel_title, panel_flags in panels:\n",
    "    plot_vectors(panel_ax, *panel_vectors, panel_title, **panel_flags)\n",
    "\n",
    "# Arrow glyph between the two panels\n",
    "fig.text(0.51, 0.5, \"➜\", fontsize=20, ha=\"center\", va=\"center\")\n",
    "fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, wspace=-0.2)\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(\"visualization/steering_methods.pdf\", dpi=300, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rotation within a steering plane (wip)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n",
    "\n",
    "# Define base vectors: h_3d is drawn as the activation, d_1stPC_3d as the\n",
    "# first-PC direction and d_3d as the feature direction\n",
    "h_3d = np.array([1, 0.5, 1])\n",
    "d_1stPC_3d = np.array([1.0, -0.5, 0.0])\n",
    "d_3d = np.array([0.0, 1.0, 0.0])\n",
    "# Scale the feature direction to unit length\n",
    "d_3d = d_3d / np.linalg.norm(d_3d)\n",
    "\n",
    "alpha = 1.0\n",
    "# NOTE(review): the quantities below are derived from d_1stPC_3d, not from\n",
    "# h_3d. Only d_ortho_3d is used further down (for the plane); h_add_3d,\n",
    "# h_norm, h_add_norm and h_ablate_norm are not drawn in this cell.\n",
    "h_add_3d = d_1stPC_3d + alpha * d_3d\n",
    "# Component of d_1stPC_3d orthogonal to the unit vector d_3d\n",
    "d_ortho_3d = d_1stPC_3d - np.dot(d_1stPC_3d, d_3d) * d_3d\n",
    "\n",
    "# Unit-length versions\n",
    "h_norm = d_1stPC_3d / np.linalg.norm(d_1stPC_3d)\n",
    "h_add_norm = h_add_3d / np.linalg.norm(h_add_3d)\n",
    "h_ablate_norm = d_ortho_3d / np.linalg.norm(d_ortho_3d)\n",
    "\n",
    "# Plane spanned by d_ortho_3d and d_3d (the same plane as d_1stPC_3d and d_3d).\n",
    "# NOTE(review): extended_h is built from d_ortho_3d here, not from h_3d.\n",
    "extended_h = 1.2 * d_ortho_3d\n",
    "extended_d = 1.2 * d_3d\n",
    "# origin is not referenced again in this cell\n",
    "origin = np.array([0, 0, 0])\n",
    "corner1 = 1.2 * extended_h + 1.4 * extended_d\n",
    "corner2 = 1.2 * extended_h - 0.8 * extended_d\n",
    "corner3 = 0.5 * -extended_h - 0.8 * extended_d\n",
    "corner4 = 0.5 * -extended_h + 1.4 * extended_d\n",
    "plane_vertices = [corner1, corner2, corner3, corner4]\n",
    "\n",
    "# Create figure and axes: only the left half of the canvas is used for now\n",
    "fig = plt.figure(figsize=(14, 6))\n",
    "ax1 = fig.add_subplot(121, projection=\"3d\")\n",
    "# ax2 = fig.add_subplot(122, projection=\"3d\")\n",
    "\n",
    "\n",
    "# Right angle marker\n",
    "def draw_right_angle_marker(ax, vec1, vec2, scale=0.2):\n",
    "    \"\"\"Draw a small corner marker near the origin between two directions.\n",
    "\n",
    "    Two gray segments of length ``scale`` are plotted: one from the point at\n",
    "    ``scale`` along ``vec1`` in the direction of ``vec2``, and one from there\n",
    "    back against ``vec1``. Only the directions of the inputs matter.\n",
    "\n",
    "    Args:\n",
    "        ax: Axes-like object exposing ``plot(xs, ys, zs, ...)``.\n",
    "        vec1: First 3-vector (non-zero).\n",
    "        vec2: Second 3-vector (non-zero).\n",
    "        scale: Side length of the marker.\n",
    "    \"\"\"\n",
    "    unit_a = vec1 / np.linalg.norm(vec1)\n",
    "    unit_b = vec2 / np.linalg.norm(vec2)\n",
    "    on_first = np.array([0, 0, 0]) + scale * unit_a\n",
    "    far_corner = on_first + scale * unit_b\n",
    "    on_second = far_corner - scale * unit_a\n",
    "    for start, end in ((on_first, far_corner), (far_corner, on_second)):\n",
    "        ax.plot(\n",
    "            [start[0], end[0]],\n",
    "            [start[1], end[1]],\n",
    "            [start[2], end[2]],\n",
    "            color=\"gray\",\n",
    "        )\n",
    "\n",
    "\n",
    "# Rotation arc\n",
    "def draw_arc(ax, v_from, v_to, color):\n",
    "    \"\"\"Draw a dashed unit-radius arc around the origin from ``v_from`` to ``v_to``.\n",
    "\n",
    "    Both inputs are normalised first, so only their directions matter. The\n",
    "    arc lies in the plane of the two vectors and sweeps the angle between them.\n",
    "\n",
    "    Args:\n",
    "        ax: Axes-like object exposing ``plot(xs, ys, zs, ...)``.\n",
    "        v_from: 3-vector giving the start direction (non-zero).\n",
    "        v_to: 3-vector giving the end direction (non-zero).\n",
    "        color: Colour forwarded to ``ax.plot``.\n",
    "\n",
    "    Parallel or anti-parallel inputs do not define a rotation plane; in that\n",
    "    case nothing is drawn instead of plotting an all-NaN curve.\n",
    "    \"\"\"\n",
    "    v_from = v_from / np.linalg.norm(v_from)\n",
    "    v_to = v_to / np.linalg.norm(v_to)\n",
    "    axis = np.cross(v_from, v_to)\n",
    "    axis_norm = np.linalg.norm(axis)\n",
    "    if axis_norm < 1e-12:\n",
    "        return\n",
    "    axis = axis / axis_norm\n",
    "    # Unit vector perpendicular to v_from, inside the plane, on the v_to side\n",
    "    in_plane_perp = np.cross(axis, v_from)\n",
    "    theta = np.linspace(0, np.arccos(np.clip(np.dot(v_from, v_to), -1, 1)), 50)\n",
    "    arc = np.cos(theta)[:, None] * v_from + np.sin(theta)[:, None] * in_plane_perp\n",
    "    ax.plot(arc[:, 0], arc[:, 1], arc[:, 2], color=color, linestyle=\"--\")\n",
    "\n",
    "\n",
    "def draw_vector_arrow(ax, vec, color, label):\n",
    "    \"\"\"Draw ``vec`` from the origin as a labelled line with a two-stroke arrowhead.\n",
    "\n",
    "    Args:\n",
    "        ax: Axes-like object exposing ``plot(xs, ys, zs, ...)``.\n",
    "        vec: Non-zero 3-vector; the arrow tip is placed at ``vec``.\n",
    "        color: Colour used for the shaft and both head strokes.\n",
    "        label: Legend label attached to the shaft only.\n",
    "    \"\"\"\n",
    "    ax.plot([0, vec[0]], [0, vec[1]], [0, vec[2]], color=color, label=label)\n",
    "\n",
    "    unit = vec / np.linalg.norm(vec)\n",
    "    # Sideways offset for the head strokes: perpendicular to the shaft and the\n",
    "    # z axis, or to the y axis when the shaft is (nearly) parallel to z.\n",
    "    sideways = np.cross(unit, np.array([0, 0, 1]))\n",
    "    if np.linalg.norm(sideways) < 1e-6:\n",
    "        sideways = np.cross(unit, np.array([0, 1, 0]))\n",
    "    sideways = sideways / np.linalg.norm(sideways) * 0.05\n",
    "\n",
    "    # Both strokes start at the tip and end 0.1 back along the shaft.\n",
    "    head_base = vec - 0.1 * unit\n",
    "    for barb_end in (head_base + sideways, head_base - sideways):\n",
    "        ax.plot(\n",
    "            [vec[0], barb_end[0]],\n",
    "            [vec[1], barb_end[1]],\n",
    "            [vec[2], barb_end[2]],\n",
    "            color=color,\n",
    "        )\n",
    "\n",
    "\n",
    "# Figure options. These must be plain booleans: a one-element tuple such as\n",
    "# (False,) is always truthy, so writing the flags with a trailing comma drew\n",
    "# the arc and the legend even though they were switched off.\n",
    "title = \"Angular Steering\"\n",
    "show_addition_lines = True  # NOTE(review): not consulted anywhere in this cell yet\n",
    "show_right_angle = True\n",
    "show_rotation_arc = False\n",
    "show_legend = False\n",
    "\n",
    "draw_vector_arrow(\n",
    "    ax1,\n",
    "    h_3d,\n",
    "    color=plotly.colors.qualitative.Plotly[0],\n",
    "    label=\"$\\\\mathbf{h}$ (activation)\",\n",
    ")\n",
    "\n",
    "draw_vector_arrow(\n",
    "    ax1,\n",
    "    d_3d,\n",
    "    color=plotly.colors.qualitative.Plotly[1],\n",
    "    label=\"$\\\\mathbf{d}_\\\\text{feature}$ (feature direction)\",\n",
    ")\n",
    "\n",
    "# The backslash before \"mathbf\" is escaped like in the other labels; the bare\n",
    "# form is an invalid escape sequence that newer Python versions warn about.\n",
    "draw_vector_arrow(\n",
    "    ax1,\n",
    "    d_1stPC_3d,\n",
    "    color=plotly.colors.qualitative.Plotly[6],\n",
    "    label=\"$\\\\mathbf{d}_\\\\text{1stPC}$\",\n",
    ")\n",
    "\n",
    "ax1.add_collection3d(Poly3DCollection([plane_vertices], color=\"gray\", alpha=0.1))\n",
    "\n",
    "\n",
    "if show_right_angle:\n",
    "    # NOTE(review): d_3d and d_1stPC_3d are not orthogonal (their dot product\n",
    "    # is -0.5), so this marker does not sit on a real right angle. d_ortho_3d\n",
    "    # may have been intended as the second vector - confirm.\n",
    "    draw_right_angle_marker(ax1, d_3d, d_1stPC_3d)\n",
    "\n",
    "if show_rotation_arc:\n",
    "    draw_arc(ax1, d_3d, d_1stPC_3d, (0, 0, 1, 0.3))\n",
    "\n",
    "# Axes styling: fixed cube, fixed camera, no tick labels, white panes\n",
    "ax1.set_title(title, fontsize=20)\n",
    "ax1.set_xlim([-1, 2])\n",
    "ax1.set_ylim([-1, 2])\n",
    "ax1.set_zlim([-1, 2])\n",
    "ax1.view_init(elev=30, azim=135)\n",
    "ax1.grid(True)\n",
    "ax1.set_xticklabels([])\n",
    "ax1.set_yticklabels([])\n",
    "ax1.set_zticklabels([])\n",
    "ax1.xaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
    "ax1.yaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
    "ax1.zaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
    "\n",
    "if show_legend:\n",
    "    ax1.legend(fontsize=18, loc=\"upper left\", bbox_to_anchor=(0, 0.3))\n",
    "\n",
    "\n",
    "fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, wspace=-0.2)\n",
    "\n",
    "plt.show()\n",
    "\n",
    "fig.savefig(\"visualization/steering_methods.png\", dpi=300, bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "angular_steering",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
