{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ggLCu-otYY3Z"
      },
      "source": [
        "# Angular Steering evaluation visualization\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9F5uVtz2YY3b"
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "from configs import config_loader\n",
        "\n",
        "dir_id = \"max_sim\"\n",
        "\n",
        "opt_mode = \"adam\"\n",
        "mode = \"causal_actadd1p0\"\n",
        "beta = 0.9\n",
        "beta_2 = 0.999\n",
        "if opt_mode == \"adam\":\n",
        "    beta_str = f'{str(beta).replace(\".\", \"p\")}_{str(beta_2).replace(\".\", \"p\")}'\n",
        "else:\n",
        "    beta_str = str(beta).replace(\".\", \"p\")\n",
        "MAX_NORM_DIR_ID, MAX_SIM_DIR_ID = config_loader(opt_mode, mode, beta_str)\n",
        "\n",
        "if dir_id == \"max_sim\":\n",
        "    DIR_ID_MAP = MAX_SIM_DIR_ID\n",
        "elif dir_id == \"max_norm\":\n",
        "    DIR_ID_MAP = MAX_NORM_DIR_ID\n",
        "\n",
        "visualization_dir = Path(f\"visualization/{dir_id}/\") / f\"{opt_mode}_{mode}\" / f\"beta_{beta_str}\"\n",
        "visualization_dir.mkdir(parents=True, exist_ok=True)\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4zihIrjNYY3b"
      },
      "source": [
        "## Refusal score and Harmful scores\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MdUVnTlcYY3c"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import numpy as np\n",
        "import plotly\n",
        "from pathlib import Path\n",
        "from plotly.subplots import make_subplots\n",
        "import plotly.graph_objects as go\n",
        "\n",
        "model_ids = [\n",
        "    \"Qwen/Qwen2.5-3B-Instruct\",\n",
        "    \"Qwen/Qwen2.5-7B-Instruct\",\n",
        "    \"Qwen/Qwen2.5-14B-Instruct\",\n",
        "    'Qwen/Qwen2.5-32B-Instruct',\n",
        "    \"meta-llama/Llama-3.2-3B-Instruct\",\n",
        "    \"meta-llama/Llama-3.1-8B-Instruct\",\n",
        "    \"google/gemma-2-9b-it\",\n",
        "    'google/gemma-2-27b-it',\n",
        "    \"Unispac/Gemma-2-9B-IT-With-Deeper-Safety-Alignment\",\n",
        "]\n",
        "\n",
        "language = \"en\"\n",
        "data_type = \"harmful\"\n",
        "num_row = 3\n",
        "num_col = len(model_ids) // num_row\n",
        "data_path = \"output/\"\n",
        "\n",
        "adaptive = True\n",
        "\n",
        "fig = make_subplots(\n",
        "    rows=num_row,\n",
        "    cols=num_col,\n",
        "    specs=[[{\"type\": \"polar\"}] * num_col] * num_row,\n",
        "    subplot_titles=[model_id.split(\"/\")[1] for model_id in model_ids],\n",
        "    vertical_spacing=0.01,\n",
        "    horizontal_spacing=0.1,\n",
        ")\n",
        "\n",
        "results = {}\n",
        "avg_performance = {}\n",
        "\n",
        "colour_map = {\n",
        "    \"substring_matching\": plotly.colors.qualitative.Plotly[1],\n",
        "    \"llamaguard3\": plotly.colors.qualitative.Plotly[2],\n",
        "    \"harmbench\": plotly.colors.qualitative.Plotly[0],\n",
        "}\n",
        "categories = list(str(i) for i in range(0, 360, 10))\n",
        "categories.append(categories[0])\n",
        "\n",
        "for idx, model_id in enumerate(model_ids):\n",
        "\n",
        "    _, model_name = model_id.split(\"/\")\n",
        "    output_path = Path(data_path) / model_name / f\"{opt_mode}_{mode}\" / f\"beta_{beta_str}\"\n",
        "\n",
        "    if adaptive:\n",
        "        glob_pattern = f\"eval-mode_1*.json\"\n",
        "    else:\n",
        "        glob_pattern = \"eval-[!(mode)(perp)]*.json\"\n",
        "\n",
        "    for file in sorted(list(output_path.glob(glob_pattern))):\n",
        "        if adaptive:\n",
        "            metric = file.stem.split(\"-\")[2]\n",
        "        else:\n",
        "            metric = file.stem.split(\"-\")[1]\n",
        "\n",
        "        if metric == \"llmjudge\":\n",
        "            continue\n",
        "        if metric not in colour_map:\n",
        "            colour_map[metric] = plotly.colors.qualitative.Plotly[len(colour_map)]\n",
        "\n",
        "        with open(file, \"r\") as f:\n",
        "            eval_data = json.load(f)\n",
        "        print(file)\n",
        "        baseline = eval_data[\"baseline\"]\n",
        "        if isinstance(baseline, list):\n",
        "            baseline = np.mean(baseline)\n",
        "        fig.add_trace(\n",
        "            go.Scatterpolar(\n",
        "                r=[\n",
        "                    (\n",
        "                        1 - baseline\n",
        "                        if metric in [\"substring_matching\", \"llmjudge\"]\n",
        "                        else baseline\n",
        "                    )\n",
        "                    for _ in range(len(categories))\n",
        "                ],\n",
        "                theta=categories,\n",
        "                name=\"baseline\",\n",
        "                line=dict(width=2, color=colour_map[metric], dash=\"dot\"),\n",
        "                # legendgroup=\"legend\"\n",
        "                mode=\"lines\",\n",
        "                opacity=0.5,\n",
        "                showlegend=False,\n",
        "            ),\n",
        "            row=idx // num_col + 1,\n",
        "            col=idx % num_col + 1,\n",
        "        )\n",
        "\n",
        "        print(eval_data.keys())\n",
        "        chosen_plane_id = [\n",
        "            s\n",
        "            for s in eval_data.keys()\n",
        "            if \"dir_random\" not in s and DIR_ID_MAP[model_id] in s\n",
        "        ]\n",
        "        if not chosen_plane_id:\n",
        "            continue\n",
        "        print(model_id, chosen_plane_id)\n",
        "        chosen_plane_id = chosen_plane_id[0]\n",
        "\n",
        "        # Track which angle is best:\n",
        "        if metric not in results.keys():\n",
        "            results[metric] = {}\n",
        "            avg_performance[metric] = {}\n",
        "        index, score = None, -900\n",
        "        for key in eval_data[chosen_plane_id].keys():\n",
        "            cur_score = np.mean(eval_data[chosen_plane_id][key])\n",
        "            if cur_score > score:\n",
        "                index = key\n",
        "                score = cur_score\n",
        "            if model_id not in avg_performance[metric].keys():\n",
        "                avg_performance[metric][model_id] = {}\n",
        "            avg_performance[metric][model_id][key] = cur_score\n",
        "        results[metric][model_id] = {'angle': index, 'score': score}\n",
        "        values = [eval_data[chosen_plane_id][cat] for cat in categories]\n",
        "\n",
        "        # if metric == \"llmjudge\":\n",
        "        #     values = [[0.5 if v == 0.75 else v for v in val] for val in values]\n",
        "        values = [np.mean(val) if isinstance(val, list) else val for val in values]\n",
        "        values.append(values[0])\n",
        "        fig.add_trace(\n",
        "            go.Scatterpolar(\n",
        "                r=(\n",
        "                    [1 - v for v in values]\n",
        "                    if metric in [\"substring_matching\", \"llmjudge\"]\n",
        "                    else values\n",
        "                ),\n",
        "                theta=categories,\n",
        "                legend=f\"legend{idx if idx > 0 else ''}\",\n",
        "                # fill=\"toself\",\n",
        "                name=metric,\n",
        "                line=dict(width=2, color=colour_map[metric]),\n",
        "                # legendgroup=\"legend\"\n",
        "                mode=\"lines\",\n",
        "                # opacity=0.5,\n",
        "                showlegend=idx == 0,\n",
        "            ),\n",
        "            row=idx // num_col + 1,\n",
        "            col=idx % num_col + 1,\n",
        "        )\n",
        "\n",
        "    fig.add_trace(\n",
        "        go.Scatterpolar(\n",
        "            r=[1.02],\n",
        "            # theta=[\"0\"],\n",
        "            name=\"feature direction\",\n",
        "            marker=dict(size=20, symbol=\"arrow-right\", color=\"black\"),\n",
        "            mode=\"markers\",\n",
        "            showlegend=idx == 0,\n",
        "        ),\n",
        "        row=idx // num_col + 1,\n",
        "        col=idx % num_col + 1,\n",
        "    )\n",
        "\n",
        "    # fig.add_trace(\n",
        "    #     go.Scatterpolar(\n",
        "    #         r=[1],\n",
        "    #         theta=['90'],\n",
        "    #         name=\"1st PC direction\",\n",
        "    #         marker=dict(size=10, symbol=2, color=plotly.colors.qualitative.Plotly[len(colour_map) + 1]),\n",
        "    #         mode=\"markers\",\n",
        "    #         showlegend=idx == 0,\n",
        "    #     ),\n",
        "    #     row=idx // num_col + 1,\n",
        "    #     col=idx % num_col + 1,\n",
        "    # )\n",
        "\n",
        "for i in range(len(model_ids) + 1):\n",
        "    polar_key = f'polar{i if i > 0 else \"\"}'\n",
        "    fig.update_layout(\n",
        "        {\n",
        "            polar_key: dict(\n",
        "                radialaxis=dict(\n",
        "                    visible=True,\n",
        "                    dtick=0.2,\n",
        "                    tickfont=dict(size=20),\n",
        "                ),\n",
        "                angularaxis=dict(\n",
        "                    # direction=\"clockwise\",\n",
        "                    # rotation=0,  # 0 degrees at the right (East)\n",
        "                    # period=360,\n",
        "                    # tickmode=\"array\",\n",
        "                    tickvals=list(\n",
        "                        range(0, 360, 10)\n",
        "                    ),  # Show all degrees in 10 intervals\n",
        "                    ticktext=[\n",
        "                        f\"{i}°\" for i in range(0, 360, 10)\n",
        "                    ],  # Show all tick labels\n",
        "                    tickfont=dict(size=18),  # Smaller font to fit all labels\n",
        "                    dtick=10,\n",
        "                ),\n",
        "            )\n",
        "        }\n",
        "    )\n",
        "\n",
        "\n",
        "# Global layout settings\n",
        "fig.update_layout(\n",
        "    height=2000,\n",
        "    width=1200,\n",
        "    # title_text=\"Angular Steering Effects on Model Performance\",\n",
        "    showlegend=True,\n",
        "    legend=dict(\n",
        "        orientation=\"h\",\n",
        "        # yanchor=\"top\",\n",
        "        y=0.0,\n",
        "        xanchor=\"center\",\n",
        "        x=0.5,\n",
        "        # entrywidth=0,\n",
        "        font=dict(size=30),\n",
        "    ),\n",
        "    margin=dict(l=50, r=50, t=20, b=0),\n",
        ")\n",
        "fig.update_annotations(font=dict(size=36), yshift=-20)\n",
        "\n",
        "fig.show()\n",
        "\n",
        "if adaptive:\n",
        "    output_name = f\"eval_adaptive-harmness-all_models-vertical\"\n",
        "else:\n",
        "    output_name = f\"eval-harmness-all_models-vertical\"\n",
        "\n",
        "fig.write_image(\n",
        "    visualization_dir / f\"{output_name}.pdf\",\n",
        "    width=1200,\n",
        "    height=2000,\n",
        "    scale=5,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rTMrsCroYY3e"
      },
      "outputs": [],
      "source": [
        "avg_angle = {\"avg_of_2\": {}}\n",
        "metric = avg_performance.keys()\n",
        "model_ids = avg_performance[\"harmbench\"].keys()\n",
        "angles = avg_performance[\"harmbench\"]['Qwen/Qwen2.5-3B-Instruct'].keys()\n",
        "for model in model_ids:\n",
        "    best, idx = -999, None\n",
        "    for angle in angles:\n",
        "        score = 0.5 * (avg_performance[\"harmbench\"][model][angle] + avg_performance[\"llamaguard3\"][model][angle])\n",
        "        if score > best:\n",
        "            best = score\n",
        "            idx = angle\n",
        "    avg_angle[\"avg_of_2\"][model] = {\"angle\": idx, \"score\": best}\n",
        "avg_angle\n",
        "\n",
        "harmbench_angle = {'harmbench_angle_on_lg':{}}\n",
        "model_ids = avg_performance[\"harmbench\"].keys()\n",
        "for model in model_ids:\n",
        "    hb_angle = results[\"harmbench\"][model][\"angle\"]\n",
        "    hb_lg_performance = avg_performance[\"llamaguard3\"][model][hb_angle]\n",
        "    harmbench_angle[\"harmbench_angle_on_lg\"][model] = {\"angle\": hb_angle, \"score\": hb_lg_performance}\n",
        "\n",
        "llamaguard_angle = {'llamaguard_angle_on_hb':{}}\n",
        "model_ids = avg_performance[\"llamaguard3\"].keys()\n",
        "for model in model_ids:\n",
        "    lg_angle = results[\"llamaguard3\"][model][\"angle\"]\n",
        "    lg_hb_performance = avg_performance[\"harmbench\"][model][lg_angle]\n",
        "    llamaguard_angle[\"llamaguard_angle_on_hb\"][model] = {\"angle\": lg_angle, \"score\": lg_hb_performance }\n",
        "\n",
        "llamaguard_angle"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nBPQWtgrYY3e"
      },
      "outputs": [],
      "source": [
        "# Convert nested dict into dataframe\n",
        "import pandas as pd\n",
        "\n",
        "\n",
        "df = pd.concat(\n",
        "    {task: pd.DataFrame(task_data).T for task, task_data in results.items()},\n",
        "    axis=1\n",
        ")\n",
        "\n",
        "df_avg = pd.concat(\n",
        "    {task: pd.DataFrame(task_data).T for task, task_data in avg_angle.items()},\n",
        "    axis=1)\n",
        "\n",
        "df_hb = pd.concat(\n",
        "    {task: pd.DataFrame(task_data).T for task, task_data in harmbench_angle.items()},\n",
        "    axis=1)\n",
        "\n",
        "df_lg = pd.concat(\n",
        "    {task: pd.DataFrame(task_data).T for task, task_data in llamaguard_angle.items()},\n",
        "    axis=1)\n",
        "\n",
        "df = pd.concat([df, df_avg, df_hb, df_lg], axis = 1)\n",
        "\n",
        "if 'baseline' in mode:\n",
        "    df.insert(0, \"mode\", f\"{opt_mode}_{mode}\")\n",
        "    csv_name = f\"results_csv/{opt_mode}_{mode}.csv\"\n",
        "else:\n",
        "    df.insert(0, \"mode\", f\"{opt_mode}_{mode}_beta_{beta_str}\")\n",
        "    csv_name = f\"results_csv/{opt_mode}_{mode}_beta_{beta_str}.csv\"\n",
        "\n",
        "df.to_csv(csv_name, index = True)\n",
        "df"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "plNLcMmwYY3f"
      },
      "source": [
        "## LLM-as-a-judge classification results\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fd34C7iwYY3f"
      },
      "outputs": [],
      "source": [
        "# model_ids = [\n",
        "#     \"Qwen/Qwen2.5-3B-Instruct\",\n",
        "#     \"Qwen/Qwen2.5-7B-Instruct\",\n",
        "#     \"Qwen/Qwen2.5-14B-Instruct\",\n",
        "#     \"meta-llama/Llama-3.2-3B-Instruct\",\n",
        "#     \"meta-llama/Llama-3.1-8B-Instruct\",\n",
        "#     \"google/gemma-2-9b-it\",\n",
        "# ]\n",
        "\n",
        "# language = \"en\"\n",
        "# data_type = \"harmful\"\n",
        "# num_row = 3\n",
        "# num_col = len(model_ids) // num_row\n",
        "# data_path = \"output/\"\n",
        "# adaptive = True\n",
        "\n",
        "\n",
        "# colour_map = {\n",
        "#     \"direct\": \"red\",\n",
        "#     \"indirect\": \"orange\",\n",
        "#     \"redirect\": \"teal\",\n",
        "#     \"refusal\": \"green\",\n",
        "# }\n",
        "# score2label = {\n",
        "#     0: \"refusal\",\n",
        "#     0.25: \"redirect\",\n",
        "#     0.5: \"indirect\",\n",
        "#     0.75: \"indirect\",\n",
        "#     1: \"direct\",\n",
        "# }\n",
        "\n",
        "# fig = make_subplots(\n",
        "#     rows=num_row,\n",
        "#     cols=num_col,\n",
        "#     specs=[[{\"type\": \"polar\"}] * num_col] * num_row,\n",
        "#     subplot_titles=[model_id.split(\"/\")[1] for model_id in model_ids],\n",
        "#     vertical_spacing=0.01,\n",
        "#     horizontal_spacing=0.1,\n",
        "# )\n",
        "\n",
        "# for idx, model_id in enumerate(model_ids):\n",
        "#     _, model_name = model_id.split(\"/\")\n",
        "#     output_path = Path(data_path) / model_name\n",
        "\n",
        "#     if adaptive:\n",
        "#         file = list(output_path.glob(\"eval-mode_1-llmjudge-*.json\"))[0]\n",
        "#     else:\n",
        "#         file = list(output_path.glob(\"eval-llmjudge-*.json\"))[0]\n",
        "#     print(file)\n",
        "#     with open(file, \"r\") as f:\n",
        "#         eval_data = json.load(f)\n",
        "\n",
        "#     chosen_plane_id = [s for s in eval_data.keys() if DIR_ID_MAP[model_id] in s and \"dir_random\" not in s]\n",
        "#     if len(chosen_plane_id) != 1:\n",
        "#         print(f\"Skipping {model_id} due to {len(chosen_plane_id)} planes IDs found: {chosen_plane_id}\")\n",
        "#         continue\n",
        "\n",
        "#     chosen_plane_id = chosen_plane_id[0]\n",
        "\n",
        "\n",
        "#     print(model_id, chosen_plane_id)\n",
        "#     categories = list(str(i) for i in range(0, 360, 10))\n",
        "#     categories.append(categories[0])\n",
        "#     values = [eval_data[chosen_plane_id][cat] for cat in categories]\n",
        "\n",
        "#     scores = [1, 0.75, 0.25, 0]\n",
        "#     for v in scores:\n",
        "#         r = [vals.count(v) / len(vals) for vals in values]\n",
        "#         fig.add_trace(\n",
        "#             go.Barpolar(\n",
        "#                 r=r,\n",
        "#                 # theta=list(str(i) for i in range(0, 360, 10)),\n",
        "#                 # marker_color=[colour_map[score2label[v]] for v in scores],\n",
        "#                 marker_color=colour_map[score2label[v]],\n",
        "#                 name=score2label[v],\n",
        "#                 showlegend=idx == 0,\n",
        "#                 marker=dict(\n",
        "#                     line=dict(\n",
        "#                         width=0.0,\n",
        "#                         # color=\"rgba(0,0,0,0)\"\n",
        "#                     )\n",
        "#                 ),\n",
        "#             ),\n",
        "#             row=idx // num_col + 1,\n",
        "#             col=idx % num_col + 1,\n",
        "#         )\n",
        "\n",
        "#     fig.add_trace(\n",
        "#         go.Scatterpolar(\n",
        "#             r=[1.02],\n",
        "#             theta=[\"0\"],\n",
        "#             name=\"feature direction\",\n",
        "#             marker=dict(size=20, symbol=\"arrow-right\", color=\"black\"),\n",
        "#             mode=\"markers\",\n",
        "#             showlegend=idx == 0,\n",
        "#         ),\n",
        "#         row=idx // num_col + 1,\n",
        "#         col=idx % num_col + 1,\n",
        "#     )\n",
        "\n",
        "\n",
        "# polar_config = dict(\n",
        "#     bargap=0,\n",
        "#     hole=0,\n",
        "#     angularaxis=dict(tickfont=dict(size=16), dtick=10),\n",
        "#     radialaxis=dict(\n",
        "#         showticklabels=False,\n",
        "#         showline=False,\n",
        "#     ),\n",
        "# )\n",
        "# fig.update_layout(\n",
        "#     polar=polar_config,\n",
        "#     polar2=polar_config,\n",
        "#     polar3=polar_config,\n",
        "#     polar4=polar_config,\n",
        "#     polar5=polar_config,\n",
        "#     polar6=polar_config,\n",
        "#     showlegend=True,\n",
        "#     legend=dict(\n",
        "#         orientation=\"h\",\n",
        "#         # yanchor=\"top\",\n",
        "#         y=0.0,\n",
        "#         xanchor=\"center\",\n",
        "#         x=0.5,\n",
        "#         font=dict(size=30),\n",
        "#     ),\n",
        "#     height=2000,\n",
        "#     width=1200,\n",
        "#     margin=dict(l=50, r=50, t=20, b=0),\n",
        "# )\n",
        "\n",
        "# fig.update_annotations(font=dict(size=36), yshift=-20)\n",
        "\n",
        "# fig.show()\n",
        "\n",
        "# fig.write_image(\n",
        "#     visualization_dir / \"eval_adaptive-llmjudge-all_models-vertical.pdf\",\n",
        "#     width=1200,\n",
        "#     height=2000,\n",
        "#     scale=5,\n",
        "# )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d6dexKUZYY3f"
      },
      "source": [
        "# Adaptive Angular Steering on tinyBenchmark\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OytpGOyOYY3f"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "import numpy as np\n",
        "import plotly.graph_objects as go\n",
        "import plotly.express as px\n",
        "from plotly.subplots import make_subplots\n",
        "import re\n",
        "from glob import glob\n",
        "from pathlib import Path\n",
        "from collections import defaultdict\n",
        "\n",
        "# Base directory for benchmarks\n",
        "base_dir = f\"benchmarks/{dir_id}/\"\n",
        "\n",
        "# Dictionary to map tasks to their primary metric(s)\n",
        "task_to_metrics = {\n",
        "    \"tinyArc\": [\"acc_norm,none\"],\n",
        "    \"tinyGSM8k\": [\n",
        "        \"exact_match,strict-match\",\n",
        "        \"exact_match,flexible-extract\",\n",
        "    ],\n",
        "    \"tinyHellaswag\": [\"acc_norm,none\"],\n",
        "    \"tinyMMLU\": [\"acc_norm,none\"],\n",
        "    \"tinyTruthfulQA\": [\"acc,none\"],\n",
        "    \"tinyWinogrande\": [\"acc_norm,none\"],\n",
        "}\n",
        "\n",
        "\n",
        "models = [\n",
        "    \"Qwen2.5-3B-Instruct\",\n",
        "    \"Qwen2.5-7B-Instruct\",\n",
        "    \"Qwen2.5-14B-Instruct\",\n",
        "    \"Llama-3.2-3B-Instruct\",\n",
        "    \"Llama-3.1-8B-Instruct\",\n",
        "    \"gemma-2-9b-it\",\n",
        "]\n",
        "\n",
        "# Collect data for all models and tasks\n",
        "results = defaultdict(lambda: defaultdict(dict))\n",
        "baselines = defaultdict(dict)\n",
        "\n",
        "# Debug tracking\n",
        "found_metrics = defaultdict(set)\n",
        "\n",
        "for task_dir in os.listdir(base_dir):\n",
        "    task_path = os.path.join(base_dir, task_dir)\n",
        "    if not os.path.isdir(task_path):\n",
        "        continue\n",
        "\n",
        "    for model_dir in os.listdir(task_path):\n",
        "        if model_dir not in models:\n",
        "            continue\n",
        "\n",
        "        model_path = os.path.join(task_path, model_dir)\n",
        "        if not os.path.isdir(model_path):\n",
        "            continue\n",
        "\n",
        "        # Get metrics for this task\n",
        "        metrics = task_to_metrics.get(task_dir, [])\n",
        "        if not metrics:\n",
        "            continue\n",
        "\n",
        "        # Find all results files for this model and task\n",
        "        for degree in list(range(0, 360, 10)) + [\"none\"]:\n",
        "            pattern = os.path.join(\n",
        "                model_path, f\"adaptive_{degree}\", \"*\", \"results_*.json\"\n",
        "            )\n",
        "            result_files = glob(pattern)\n",
        "\n",
        "            if not result_files:\n",
        "                continue\n",
        "\n",
        "            # Since the result files are named by date, sort them to get the latest\n",
        "            result_fiile = sorted(result_files)\n",
        "            # Use the last result file found for each degree\n",
        "            result_file = result_files[-1]\n",
        "\n",
        "            try:\n",
        "                with open(result_file, \"r\") as f:\n",
        "                    data = json.load(f)\n",
        "\n",
        "                # For each metric defined for this task\n",
        "                for metric in metrics:\n",
        "                    # Create a display name for the metric\n",
        "                    if task_dir == \"tinyGSM8k\":\n",
        "                        # For GSM8k, include the metric type in the display name\n",
        "                        # Extract strict-match or loose-match\n",
        "                        metric_type = (metric.split(\",\")[1]).split(\"-\")[0]\n",
        "                        display_name = f\"{task_dir} ({metric_type})\"\n",
        "                    else:\n",
        "                        display_name = task_dir\n",
        "\n",
        "                    if (\n",
        "                        task_dir in data.get(\"results\", {})\n",
        "                        and metric in data[\"results\"][task_dir]\n",
        "                    ):\n",
        "                        score = data[\"results\"][task_dir][metric]\n",
        "                        found_metrics[task_dir].add(\n",
        "                            metric\n",
        "                        )  # Track which metrics were found\n",
        "\n",
        "                        if degree == \"none\":\n",
        "                            baselines[model_dir][display_name] = score\n",
        "                        else:\n",
        "                            results[model_dir][display_name][int(degree)] = score\n",
        "            except Exception as e:\n",
        "                print(f\"Error processing {result_file}: {e}\")\n",
        "\n",
        "# Print metrics found for debugging\n",
        "print(\"Found metrics by task:\")\n",
        "for task, metrics in found_metrics.items():\n",
        "    print(f\"{task}: {metrics}\")\n",
        "\n",
        "# Create subplots - arrange based on number of models\n",
        "n_models = len(models)\n",
        "n_cols = min(2, n_models)\n",
        "n_rows = (n_models + n_cols - 1) // n_cols\n",
        "# n_cols = 1\n",
        "# n_rows = n_models\n",
        "\n",
        "fig = make_subplots(\n",
        "    rows=n_rows,\n",
        "    cols=n_cols,\n",
        "    specs=[[{\"type\": \"polar\"} for _ in range(n_cols)] for _ in range(n_rows)],\n",
        "    subplot_titles=models,\n",
        "    vertical_spacing=0.01,\n",
        "    horizontal_spacing=0.1,\n",
        ")\n",
        "\n",
        "# Add reference baseline trace only once (for legend)\n",
        "fig.add_trace(\n",
        "    go.Scatterpolar(\n",
        "        r=[0],\n",
        "        theta=[0],\n",
        "        name=\"baseline\",\n",
        "        line=dict(width=2, color=\"black\", dash=\"dot\"),\n",
        "        mode=\"lines\",\n",
        "        opacity=0.5,\n",
        "        showlegend=True,\n",
        "    ),\n",
        "    row=1,\n",
        "    col=1,\n",
        ")\n",
        "\n",
        "# Use Plotly's built-in color palette\n",
        "colors = px.colors.qualitative.Plotly\n",
        "\n",
        "# Plot data for each model\n",
        "for model_idx, model_name in enumerate(models):\n",
        "    row = model_idx // n_cols + 1\n",
        "    col = model_idx % n_cols + 1\n",
        "\n",
        "    if model_name not in results:\n",
        "        continue\n",
        "\n",
        "    # Add reference arrow at 0 degrees\n",
        "    fig.add_trace(\n",
        "        go.Scatterpolar(\n",
        "            r=[1.02],\n",
        "            theta=[0],\n",
        "            name=\"feature direction\",\n",
        "            marker=dict(size=20, symbol=\"arrow-right\", color=\"black\"),\n",
        "            mode=\"markers\",\n",
        "            showlegend=model_idx == 0,  # Only show in legend for first model\n",
        "        ),\n",
        "        row=row,\n",
        "        col=col,\n",
        "    )\n",
        "\n",
        "    # Track which tasks we've displayed for this model\n",
        "    task_color_index = 0\n",
        "\n",
        "    # Sort the tasks for consistent color assignment\n",
        "    sorted_tasks = sorted(results[model_name].keys())\n",
        "\n",
        "    for task_display_name in sorted_tasks:\n",
        "        task_color = colors[task_color_index % len(colors)]\n",
        "        task_color_index += 1\n",
        "\n",
        "        # Prepare angle and score data\n",
        "        angles = sorted(results[model_name][task_display_name].keys())\n",
        "        scores = [results[model_name][task_display_name][angle] for angle in angles]\n",
        "\n",
        "        # Close the loop for a complete polar plot\n",
        "        angles_closed = np.append(angles, angles[0])\n",
        "        scores_closed = np.append(scores, scores[0])\n",
        "\n",
        "        # Plot the task scores with solid lines\n",
        "        fig.add_trace(\n",
        "            go.Scatterpolar(\n",
        "                r=scores_closed,\n",
        "                theta=angles_closed,\n",
        "                name=task_display_name,\n",
        "                line=dict(width=2, color=task_color),\n",
        "                mode=\"lines\",\n",
        "                showlegend=model_idx == 0,  # Only show in legend for first model\n",
        "            ),\n",
        "            row=row,\n",
        "            col=col,\n",
        "        )\n",
        "\n",
        "        # Add baseline as dashed circular line if available\n",
        "        if model_name in baselines and task_display_name in baselines[model_name]:\n",
        "            baseline_score = baselines[model_name][task_display_name]\n",
        "            baseline_theta = np.linspace(0, 360, 361)\n",
        "\n",
        "            # Use dash line for baselines but don't add to legend\n",
        "            fig.add_trace(\n",
        "                go.Scatterpolar(\n",
        "                    r=[baseline_score] * len(baseline_theta),\n",
        "                    theta=baseline_theta,\n",
        "                    line=dict(width=1.5, color=task_color, dash=\"dot\"),\n",
        "                    mode=\"lines\",\n",
        "                    opacity=0.5,\n",
        "                    showlegend=False,  # Don't add individual baselines to legend\n",
        "                ),\n",
        "                row=row,\n",
        "                col=col,\n",
        "            )\n",
        "\n",
        "# Update layout for each subplot\n",
        "for i in range(1, n_models + 1):\n",
        "    row = (i - 1) // n_cols + 1\n",
        "    col = (i - 1) % n_cols + 1\n",
        "\n",
        "    # Get the corresponding polar subplot key\n",
        "    polar_key = f'polar{i if i > 1 else \"\"}'\n",
        "\n",
        "    # Set consistent range and angle markings across plots\n",
        "    fig.update_layout(\n",
        "        {\n",
        "            polar_key: dict(\n",
        "                radialaxis=dict(\n",
        "                    visible=True,\n",
        "                    range=[0.0, 1.02],\n",
        "                    # dtick=0.2,\n",
        "                    # tickvals=[0.3, 0.5, 0.7, 0.9, 1],\n",
        "                    tickfont=dict(size=22),\n",
        "                ),\n",
        "                angularaxis=dict(\n",
        "                    # direction=\"clockwise\",\n",
        "                    # rotation=0,  # 0 degrees at the right (East)\n",
        "                    # period=360,\n",
        "                    # tickmode=\"array\",\n",
        "                    # tickvals=list(\n",
        "                    #     range(0, 360, 10)\n",
        "                    # ),  # Show all degrees in 10 intervals\n",
        "                    # ticktext=[\n",
        "                    #     f\"{i}°\" for i in range(0, 360, 10)\n",
        "                    # ],  # Show all tick labels\n",
        "                    tickfont=dict(size=16),  # Smaller font to fit all labels\n",
        "                    dtick=10,\n",
        "                ),\n",
        "            )\n",
        "        }\n",
        "    )\n",
        "\n",
        "# Global layout settings\n",
        "fig.update_layout(\n",
        "    height=2000,\n",
        "    width=1200,\n",
        "    # title_text=\"Angular Steering Effects on Model Performance\",\n",
        "    showlegend=True,\n",
        "    legend=dict(\n",
        "        orientation=\"h\",\n",
        "        # yanchor=\"top\",\n",
        "        y=0.0,\n",
        "        xanchor=\"center\",\n",
        "        x=0.5,\n",
        "        # entrywidth=0,\n",
        "        font=dict(size=30),\n",
        "    ),\n",
        "    margin=dict(l=50, r=50, t=20, b=0),\n",
        ")\n",
        "fig.update_annotations(font=dict(size=36), yshift=-20)\n",
        "\n",
        "fig.show()\n",
        "\n",
        "fig.write_image(\n",
        "    visualization_dir / \"eval_adaptive-tinyBenchmark-all_models-vertical.pdf\",\n",
        "    width=1200,\n",
        "    height=2000,\n",
        "    scale=5,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dFG9dse8YY3f"
      },
      "source": [
        "## Perplexity scores\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EKNEZ-jYYY3g"
      },
      "outputs": [],
      "source": [
        "# import json\n",
        "# import numpy as np\n",
        "# import plotly.graph_objects as go\n",
        "# import plotly.express as px\n",
        "# from pathlib import Path\n",
        "# import glob\n",
        "\n",
        "\n",
        "# def filter_perplexity(perplexity_data, threshold=100):\n",
        "#     \"\"\"\n",
        "#     Filter perplexity data to remove outliers.\n",
        "#     E.g. in `qwen2.5-3B/rotated/max_sim-pca/290/34` there is a perplexity of 2.8e6\n",
        "#     because the generation is empty\n",
        "#     \"\"\"\n",
        "#     return [p for p in perplexity_data if p < threshold]\n",
        "\n",
        "\n",
        "# MODELS = [\n",
        "#     \"Qwen/Qwen2.5-3B-Instruct\",\n",
        "#     \"Qwen/Qwen2.5-7B-Instruct\",\n",
        "#     \"Qwen/Qwen2.5-14B-Instruct\",\n",
        "#     \"meta-llama/Llama-3.2-3B-Instruct\",\n",
        "#     \"meta-llama/Llama-3.1-8B-Instruct\",\n",
        "#     \"google/gemma-2-9b-it\",\n",
        "# ]\n",
        "\n",
        "# # Base output directory\n",
        "# output_dir = Path(\"output\")\n",
        "\n",
        "# generation_only = False\n",
        "\n",
        "# # Find all perplexity evaluation files\n",
        "# if generation_only:\n",
        "#     perplexity_files = glob.glob(\n",
        "#         str(output_dir / \"*\" / \"eval-perplexity-generation_only.json\")\n",
        "#     )\n",
        "# else:\n",
        "#     perplexity_files = glob.glob(str(output_dir / \"*\" / \"eval-perplexity.json\"))\n",
        "\n",
        "# # Create figure with subplots based on number of models\n",
        "# n_models = len(MODELS)\n",
        "# n_cols = min(2, n_models)\n",
        "# n_rows = (n_models + n_cols - 1) // n_cols\n",
        "# # n_cols = 1\n",
        "# # n_rows = n_models\n",
        "\n",
        "# # Colors for different modes\n",
        "# colors = px.colors.qualitative.Plotly\n",
        "\n",
        "# # Calculate model-specific max perplexity values\n",
        "# all_perplexity_data = {}\n",
        "\n",
        "# for perplexity_file in perplexity_files:\n",
        "#     model_name = Path(perplexity_file).parent.name\n",
        "#     with open(perplexity_file, \"r\") as f:\n",
        "#         data = json.load(f)\n",
        "#         all_perplexity_data[model_name] = data\n",
        "\n",
        "#         # Calculate max perplexity for this specific model\n",
        "#         model_max = 0\n",
        "\n",
        "#         # Check baseline max\n",
        "#         if \"harmful-en-baseline\" in data:\n",
        "#             model_max = max(model_max, max(data[\"harmful-en-baseline\"]))\n",
        "\n",
        "# if n_models > 1:\n",
        "#     from plotly.subplots import make_subplots\n",
        "\n",
        "#     fig = make_subplots(\n",
        "#         rows=n_rows,\n",
        "#         cols=n_cols,\n",
        "#         specs=[[{\"type\": \"polar\"} for _ in range(n_cols)] for _ in range(n_rows)],\n",
        "#         subplot_titles=[m.split('/')[-1] for m in MODELS],\n",
        "#         vertical_spacing=0.01,\n",
        "#         horizontal_spacing=0.1,\n",
        "#     )\n",
        "# else:\n",
        "#     fig = go.Figure()\n",
        "\n",
        "# # Add reference baseline trace only once for legend\n",
        "# fig.add_trace(\n",
        "#     go.Scatterpolar(\n",
        "#         r=[0],\n",
        "#         theta=[0],\n",
        "#         name=\"no steering\",\n",
        "#         line=dict(width=2, color=\"black\", dash=\"dot\"),\n",
        "#         mode=\"lines\",\n",
        "#         opacity=0.5,\n",
        "#         showlegend=True,\n",
        "#     ),\n",
        "#     row=1 if n_models > 1 else None,\n",
        "#     col=1 if n_models > 1 else None,\n",
        "# )\n",
        "\n",
        "# # Process each perplexity file\n",
        "# for model_idx, model_id in enumerate(MODELS):\n",
        "#     chosen_direction = DIR_ID_MAP[model_id]\n",
        "#     print(model_id)\n",
        "#     # print(chosen_direction)\n",
        "\n",
        "\n",
        "#     model_name = model_id.split(\"/\")[-1]\n",
        "#     perplexity_data = all_perplexity_data[model_name]\n",
        "\n",
        "#     row = model_idx // n_cols + 1 if n_models > 1 else None\n",
        "#     col = model_idx % n_cols + 1 if n_models > 1 else None\n",
        "\n",
        "#     # Extract modes (rotated, adaptive, etc.)\n",
        "#     modes = []\n",
        "#     for key in perplexity_data.keys():\n",
        "#         if chosen_direction not in key:\n",
        "#             continue\n",
        "#         print(key)\n",
        "#         if key.startswith(\"harmful-en-dir\") and not key.endswith(\"baseline\"):\n",
        "#             mode = \"non-adaptive\" if \"rotated\" in key else \"adaptive\"\n",
        "#             modes.append((key, mode))\n",
        "#     modes.sort()\n",
        "\n",
        "#     # Process each mode\n",
        "#     for mode_idx, (mode_key, mode_name) in enumerate(modes):\n",
        "#         angles = []\n",
        "#         mean_perplexities = []\n",
        "\n",
        "#         # Calculate mean perplexity for each angle\n",
        "#         for angle_str, perplexities in perplexity_data[mode_key].items():\n",
        "#             if angle_str.isdigit() and perplexities:\n",
        "#                 perplexities = filter_perplexity(perplexities)\n",
        "#                 angle = int(angle_str)\n",
        "#                 mean_perp = np.mean(perplexities)\n",
        "#                 angles.append(angle)\n",
        "#                 mean_perplexities.append(mean_perp)\n",
        "\n",
        "#         # Sort by angle\n",
        "#         sorted_indices = np.argsort(angles)\n",
        "#         angles = np.array(angles)[sorted_indices]\n",
        "#         mean_perplexities = np.array(mean_perplexities)[sorted_indices]\n",
        "\n",
        "#         # Add one more point to close the loop\n",
        "#         if len(angles) > 1:\n",
        "#             angles_closed = np.append(angles, angles[0])\n",
        "#             perplexities_closed = np.append(mean_perplexities, mean_perplexities[0])\n",
        "\n",
        "#             # Add to plot\n",
        "#             fig.add_trace(\n",
        "#                 go.Scatterpolar(\n",
        "#                     r=perplexities_closed,\n",
        "#                     theta=angles_closed,\n",
        "#                     name=f\"{mode_name}\",\n",
        "#                     line=dict(width=2, color=colors[mode_idx % len(colors)]),\n",
        "#                     mode=\"lines\",\n",
        "#                     showlegend=model_idx == 0,  # Only show in legend for first model\n",
        "#                 ),\n",
        "#                 row=row,\n",
        "#                 col=col,\n",
        "#             )\n",
        "\n",
        "#     max_r = max(mean_perplexities)\n",
        "\n",
        "#     # Add reference star at 0 degrees\n",
        "#     fig.add_trace(\n",
        "#         go.Scatterpolar(\n",
        "#             r=[max_r * 1.02],\n",
        "#             theta=[0],\n",
        "#             name=\"feature direction\",\n",
        "#             marker=dict(size=20, symbol=\"arrow-right\", color=\"black\"),\n",
        "#             mode=\"markers\",\n",
        "#             showlegend=model_idx == 0,  # Only show in legend for first model\n",
        "#             legend=\"legend\",\n",
        "#         ),\n",
        "#         row=row,\n",
        "#         col=col,\n",
        "#     )\n",
        "\n",
        "#     # Add baseline if available\n",
        "#     if \"harmful-en-baseline\" in perplexity_data:\n",
        "#         baseline_perplexity = np.mean(perplexity_data[\"harmful-en-baseline\"])\n",
        "#         max_r = max(max_r, baseline_perplexity)\n",
        "#         baseline_theta = np.linspace(0, 360, 100)\n",
        "\n",
        "#         fig.add_trace(\n",
        "#             go.Scatterpolar(\n",
        "#                 r=[baseline_perplexity] * len(baseline_theta),\n",
        "#                 theta=baseline_theta,\n",
        "#                 name=\"Baseline\",\n",
        "#                 line=dict(width=1.5, color=\"black\", dash=\"dot\"),\n",
        "#                 mode=\"lines\",\n",
        "#                 opacity=0.5,\n",
        "#                 showlegend=False,  # Don't add individual baselines to legend\n",
        "#                 legend=\"legend\",\n",
        "#             ),\n",
        "#             row=row,\n",
        "#             col=col,\n",
        "#         )\n",
        "\n",
        "\n",
        "# # Update layout with model-specific ranges\n",
        "# if n_models > 1:\n",
        "#     for i in range(1, n_models + 1):\n",
        "#         row = (i - 1) // n_cols + 1\n",
        "#         col = (i - 1) % n_cols + 1\n",
        "#         polar_key = f'polar{i if i > 1 else \"\"}'\n",
        "\n",
        "#         fig.update_layout(\n",
        "#             {\n",
        "#                 polar_key: dict(\n",
        "#                     radialaxis=dict(\n",
        "#                         visible=True,\n",
        "#                         # range=[0.3, 1.02],\n",
        "#                         # dtick=0.2,\n",
        "#                         # tickvals=[0.3, 0.5, 0.7, 0.9, 1],\n",
        "#                         tickfont=dict(size=22),\n",
        "#                     ),\n",
        "#                     angularaxis=dict(\n",
        "#                         tickvals=list(range(0, 360, 10)),\n",
        "#                         ticktext=[f\"{i}°\" for i in range(0, 360, 10)],\n",
        "#                         tickfont=dict(size=16),\n",
        "#                     ),\n",
        "#                 )\n",
        "#             }\n",
        "#         )\n",
        "# else:\n",
        "#     fig.update_layout(\n",
        "#         polar=dict(\n",
        "#             radialaxis=dict(visible=True, title=\"Average Perplexity\"),\n",
        "#             angularaxis=dict(\n",
        "#                 tickvals=list(range(0, 360, 10)),\n",
        "#                 ticktext=[f\"{i}°\" for i in range(0, 360, 10)],\n",
        "#                 tickfont=dict(size=16),\n",
        "#             ),\n",
        "#         ),\n",
        "#     )\n",
        "\n",
        "# fig.update_layout(\n",
        "#     height=2000,\n",
        "#     width=1200,\n",
        "#     # title_text=\"Angular Steering Effects on Model Perplexity\",\n",
        "#     showlegend=True,\n",
        "#     # legend=dict(\n",
        "#     #     orientation=\"h\",\n",
        "#     #     yanchor=\"bottom\",\n",
        "#     #     y=-0.1 if n_models > 1 else -0.2,\n",
        "#     #     xanchor=\"center\",\n",
        "#     #     x=0.5,\n",
        "#     # ),\n",
        "#     legend=dict(\n",
        "#         orientation=\"h\",\n",
        "#         # yanchor=\"top\",\n",
        "#         y=0.0,\n",
        "#         xanchor=\"center\",\n",
        "#         x=0.5,\n",
        "#         font=dict(size=30),\n",
        "#         # entrywidth=100,\n",
        "#         entrywidthmode=\"pixels\",\n",
        "#     ),\n",
        "#     # legend2=dict(\n",
        "#     #     orientation=\"h\",\n",
        "#     #     # yanchor=\"top\",\n",
        "#     #     y=-0.01,\n",
        "#     #     xanchor=\"center\",\n",
        "#     #     x=0.5,\n",
        "#     #     font=dict(size=16),\n",
        "#     # ),\n",
        "#     margin=dict(l=50, r=50, t=20, b=140),\n",
        "# )\n",
        "# fig.update_annotations(font=dict(size=36), yshift=-20)\n",
        "\n",
        "# fig.show()\n",
        "\n",
        "# # save figure\n",
        "# if generation_only:\n",
        "#     output_name = \"eval-ppl-generation_only\"\n",
        "# else:\n",
        "#     output_name = \"eval-ppl\"\n",
        "# fig.write_image(\n",
        "#     visualization_dir / f\"{output_name}-vertical.pdf\",\n",
        "#     height=2000,\n",
        "#     width=1200,\n",
        "#     scale=5,\n",
        "# )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5tTxU12TYY3g"
      },
      "source": [
        "## Geometric interpretation of Angular Steering\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iLKVvImaYY3g"
      },
      "outputs": [],
      "source": [
        "import plotly\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n",
        "\n",
        "# Define base vectors\n",
        "d_1stPC_3d = np.array([1.5, 0.0, 0.5])\n",
        "d_3d = np.array([0.5, 1.0, 0.0])\n",
        "d_3d = d_3d / np.linalg.norm(d_3d)\n",
        "\n",
        "alpha = 1.0\n",
        "h_add_3d = d_1stPC_3d + alpha * d_3d\n",
        "d_ortho_3d = d_1stPC_3d - np.dot(d_1stPC_3d, d_3d) * d_3d\n",
        "\n",
        "# Normalize versions\n",
        "h_norm = d_1stPC_3d / np.linalg.norm(d_1stPC_3d)\n",
        "h_add_norm = h_add_3d / np.linalg.norm(h_add_3d)\n",
        "h_ablate_norm = d_ortho_3d / np.linalg.norm(d_ortho_3d)\n",
        "\n",
        "# Plane that spans h and d\n",
        "extended_h = 1.2 * d_1stPC_3d\n",
        "extended_d = 1.2 * d_3d\n",
        "origin = np.array([0, 0, 0])\n",
        "corner1 = 1.2 * extended_h + 1.4 * extended_d\n",
        "corner2 = 1.2 * extended_h - 0.8 * extended_d\n",
        "corner3 = 0.5 * -extended_h - 0.8 * extended_d\n",
        "corner4 = 0.5 * -extended_h + 1.4 * extended_d\n",
        "plane_vertices = [corner1, corner2, corner3, corner4]\n",
        "\n",
        "# Create figure and axes\n",
        "fig = plt.figure(figsize=(14, 6))\n",
        "ax1 = fig.add_subplot(121, projection=\"3d\")\n",
        "ax2 = fig.add_subplot(122, projection=\"3d\")\n",
        "\n",
        "\n",
        "# Right angle marker\n",
        "def draw_right_angle_marker(ax, vec1, vec2, scale=0.2):\n",
        "    v1 = vec1 / np.linalg.norm(vec1)\n",
        "    v2 = vec2 / np.linalg.norm(vec2)\n",
        "    base = np.array([0, 0, 0])\n",
        "    corner = base + scale * v1\n",
        "    offset1 = corner + scale * v2\n",
        "    offset2 = offset1 - scale * v1\n",
        "    ax.plot(\n",
        "        [corner[0], offset1[0]],\n",
        "        [corner[1], offset1[1]],\n",
        "        [corner[2], offset1[2]],\n",
        "        color=\"gray\",\n",
        "    )\n",
        "    ax.plot(\n",
        "        [offset1[0], offset2[0]],\n",
        "        [offset1[1], offset2[1]],\n",
        "        [offset1[2], offset2[2]],\n",
        "        color=\"gray\",\n",
        "    )\n",
        "\n",
        "\n",
        "# Rotation arc\n",
        "def draw_arc(ax, v_from, v_to, color):\n",
        "    v_from = v_from / np.linalg.norm(v_from)\n",
        "    v_to = v_to / np.linalg.norm(v_to)\n",
        "    theta = np.linspace(0, np.arccos(np.clip(np.dot(v_from, v_to), -1, 1)), 50)\n",
        "    axis = np.cross(v_from, v_to)\n",
        "    axis = axis / np.linalg.norm(axis)\n",
        "    arc = np.array(\n",
        "        [1.0 * (np.cos(t) * v_from + np.sin(t) * np.cross(axis, v_from)) for t in theta]\n",
        "    )\n",
        "    ax.plot(arc[:, 0], arc[:, 1], arc[:, 2], color=color, linestyle=\"--\")\n",
        "\n",
        "\n",
        "def draw_vector_arrow(ax, vec, color, label):\n",
        "    ax.plot([0, vec[0]], [0, vec[1]], [0, vec[2]], color=color, label=label)\n",
        "    direction = vec / np.linalg.norm(vec)\n",
        "    side = np.cross(direction, np.array([0, 0, 1]))\n",
        "    if np.linalg.norm(side) < 1e-6:\n",
        "        side = np.cross(direction, np.array([0, 1, 0]))\n",
        "    side = side / np.linalg.norm(side) * 0.05\n",
        "    head_length = 0.1\n",
        "    tip = vec\n",
        "    left = tip - head_length * direction + side\n",
        "    right = tip - head_length * direction - side\n",
        "    ax.plot([tip[0], left[0]], [tip[1], left[1]], [tip[2], left[2]], color=color)\n",
        "    ax.plot([tip[0], right[0]], [tip[1], right[1]], [tip[2], right[2]], color=color)\n",
        "\n",
        "\n",
        "# General plotting function\n",
        "def plot_vectors(\n",
        "    ax,\n",
        "    h_vec,\n",
        "    h_add_vec,\n",
        "    h_ablate_vec,\n",
        "    title,\n",
        "    show_addition_lines=False,\n",
        "    show_right_angle=False,\n",
        "    show_rotation_arc=True,\n",
        "    show_legend=True,\n",
        "):\n",
        "\n",
        "    draw_vector_arrow(\n",
        "        ax,\n",
        "        h_vec,\n",
        "        color=plotly.colors.qualitative.Plotly[0],\n",
        "        label=\"$\\\\mathbf{h}$ (activation)\",\n",
        "    )\n",
        "    draw_vector_arrow(\n",
        "        ax,\n",
        "        d_3d,\n",
        "        color=plotly.colors.qualitative.Plotly[1],\n",
        "        label=\"$\\\\mathbf{d}_\\\\text{feature}$ (feature direction)\",\n",
        "    )\n",
        "    draw_vector_arrow(\n",
        "        ax,\n",
        "        h_add_vec,\n",
        "        color=plotly.colors.qualitative.Plotly[3],\n",
        "        label=(\n",
        "            \"$\\\\mathbf{h} + \\\\alpha \\\\mathbf{d}_\\\\text{feature}$ (activation addition,\"\n",
        "            \" $\\\\alpha=1$)\"\n",
        "        ),\n",
        "    )\n",
        "    draw_vector_arrow(\n",
        "        ax,\n",
        "        h_ablate_vec,\n",
        "        color=plotly.colors.qualitative.Plotly[2],\n",
        "        label=\"$\\\\mathbf{h}_\\\\perp$ (directional ablation)\",\n",
        "    )\n",
        "\n",
        "    ax.add_collection3d(Poly3DCollection([plane_vertices], color=\"gray\", alpha=0.1))\n",
        "\n",
        "    if show_addition_lines:\n",
        "        ax.plot(\n",
        "            [h_vec[0], h_add_vec[0]],\n",
        "            [h_vec[1], h_add_vec[1]],\n",
        "            [h_vec[2], h_add_vec[2]],\n",
        "            linestyle=\"dashed\",\n",
        "            color=\"gray\",\n",
        "            linewidth=1,\n",
        "        )\n",
        "        ax.plot(\n",
        "            [d_3d[0], h_add_vec[0]],\n",
        "            [d_3d[1], h_add_vec[1]],\n",
        "            [d_3d[2], h_add_vec[2]],\n",
        "            linestyle=\"dashed\",\n",
        "            color=\"gray\",\n",
        "            linewidth=1,\n",
        "        )\n",
        "\n",
        "    if show_right_angle:\n",
        "        draw_right_angle_marker(ax, d_3d, h_ablate_vec)\n",
        "\n",
        "    if show_rotation_arc:\n",
        "        draw_arc(ax, d_3d, h_ablate_vec, (0, 0, 1, 0.3))\n",
        "        # draw_arc(ax, h_vec, h_add_vec, \"orange\")\n",
        "        # draw_arc(ax, h_vec, h_ablate_vec, \"red\")\n",
        "\n",
        "    ax.set_title(title, fontsize=20)\n",
        "    ax.set_xlim([-1, 2])\n",
        "    ax.set_ylim([-1, 2])\n",
        "    ax.set_zlim([-1, 2])\n",
        "    ax.view_init(elev=30, azim=135)\n",
        "    ax.grid(True)\n",
        "    ax.set_xticklabels([])\n",
        "    ax.set_yticklabels([])\n",
        "    ax.set_zticklabels([])\n",
        "    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
        "    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
        "    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
        "\n",
        "    if show_legend:\n",
        "        ax.legend(fontsize=18, loc=\"upper left\", bbox_to_anchor=(0, 0.3))\n",
        "\n",
        "\n",
        "# Plot left: before normalization\n",
        "plot_vectors(\n",
        "    ax1,\n",
        "    d_1stPC_3d,\n",
        "    h_add_3d,\n",
        "    d_ortho_3d,\n",
        "    \"Before Normalization\",\n",
        "    show_addition_lines=True,\n",
        "    show_right_angle=True,\n",
        "    show_rotation_arc=False,\n",
        "    show_legend=False,\n",
        ")\n",
        "\n",
        "# Plot right: after normalization\n",
        "plot_vectors(\n",
        "    ax2,\n",
        "    h_norm,\n",
        "    h_add_norm,\n",
        "    h_ablate_norm,\n",
        "    \"After Normalization\",\n",
        "    show_addition_lines=False,\n",
        "    show_right_angle=True,\n",
        "    show_rotation_arc=True,\n",
        "    show_legend=True,\n",
        ")\n",
        "\n",
        "# Connect the two views\n",
        "fig.text(0.51, 0.5, \"➜\", fontsize=20, ha=\"center\", va=\"center\")\n",
        "# plt.tight_layout()\n",
        "fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, wspace=-0.2)\n",
        "\n",
        "plt.show()\n",
        "\n",
        "fig.savefig(\"visualization/steering_methods.pdf\", dpi=300, bbox_inches=\"tight\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rKdvQMM8YY3g"
      },
      "source": [
        "## Rotation within a steering plane (wip)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vy9IUSS4YY3g"
      },
      "outputs": [],
      "source": [
        "import plotly\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n",
        "\n",
        "# Define base vectors\n",
        "h_3d = np.array([1, 0.5, 1])\n",
        "d_1stPC_3d = np.array([1.0, -0.5, 0.0])\n",
        "d_3d = np.array([0.0, 1.0, 0.0])\n",
        "d_3d = d_3d / np.linalg.norm(d_3d)\n",
        "\n",
        "alpha = 1.0\n",
        "h_add_3d = d_1stPC_3d + alpha * d_3d\n",
        "d_ortho_3d = d_1stPC_3d - np.dot(d_1stPC_3d, d_3d) * d_3d\n",
        "\n",
        "# Normalize versions\n",
        "h_norm = d_1stPC_3d / np.linalg.norm(d_1stPC_3d)\n",
        "h_add_norm = h_add_3d / np.linalg.norm(h_add_3d)\n",
        "h_ablate_norm = d_ortho_3d / np.linalg.norm(d_ortho_3d)\n",
        "\n",
        "# Plane that spans h and d\n",
        "extended_h = 1.2 * d_ortho_3d\n",
        "extended_d = 1.2 * d_3d\n",
        "origin = np.array([0, 0, 0])\n",
        "corner1 = 1.2 * extended_h + 1.4 * extended_d\n",
        "corner2 = 1.2 * extended_h - 0.8 * extended_d\n",
        "corner3 = 0.5 * -extended_h - 0.8 * extended_d\n",
        "corner4 = 0.5 * -extended_h + 1.4 * extended_d\n",
        "plane_vertices = [corner1, corner2, corner3, corner4]\n",
        "\n",
        "# Create figure and axes\n",
        "fig = plt.figure(figsize=(14, 6))\n",
        "ax1 = fig.add_subplot(121, projection=\"3d\")\n",
        "# ax2 = fig.add_subplot(122, projection=\"3d\")\n",
        "\n",
        "\n",
        "# Right angle marker\n",
        "def draw_right_angle_marker(ax, vec1, vec2, scale=0.2):\n",
        "    v1 = vec1 / np.linalg.norm(vec1)\n",
        "    v2 = vec2 / np.linalg.norm(vec2)\n",
        "    base = np.array([0, 0, 0])\n",
        "    corner = base + scale * v1\n",
        "    offset1 = corner + scale * v2\n",
        "    offset2 = offset1 - scale * v1\n",
        "    ax.plot(\n",
        "        [corner[0], offset1[0]],\n",
        "        [corner[1], offset1[1]],\n",
        "        [corner[2], offset1[2]],\n",
        "        color=\"gray\",\n",
        "    )\n",
        "    ax.plot(\n",
        "        [offset1[0], offset2[0]],\n",
        "        [offset1[1], offset2[1]],\n",
        "        [offset1[2], offset2[2]],\n",
        "        color=\"gray\",\n",
        "    )\n",
        "\n",
        "\n",
        "# Rotation arc\n",
        "def draw_arc(ax, v_from, v_to, color):\n",
        "    v_from = v_from / np.linalg.norm(v_from)\n",
        "    v_to = v_to / np.linalg.norm(v_to)\n",
        "    theta = np.linspace(0, np.arccos(np.clip(np.dot(v_from, v_to), -1, 1)), 50)\n",
        "    axis = np.cross(v_from, v_to)\n",
        "    axis = axis / np.linalg.norm(axis)\n",
        "    arc = np.array(\n",
        "        [1.0 * (np.cos(t) * v_from + np.sin(t) * np.cross(axis, v_from)) for t in theta]\n",
        "    )\n",
        "    ax.plot(arc[:, 0], arc[:, 1], arc[:, 2], color=color, linestyle=\"--\")\n",
        "\n",
        "\n",
        "def draw_vector_arrow(ax, vec, color, label):\n",
        "    ax.plot([0, vec[0]], [0, vec[1]], [0, vec[2]], color=color, label=label)\n",
        "    direction = vec / np.linalg.norm(vec)\n",
        "    side = np.cross(direction, np.array([0, 0, 1]))\n",
        "    if np.linalg.norm(side) < 1e-6:\n",
        "        side = np.cross(direction, np.array([0, 1, 0]))\n",
        "    side = side / np.linalg.norm(side) * 0.05\n",
        "    head_length = 0.1\n",
        "    tip = vec\n",
        "    left = tip - head_length * direction + side\n",
        "    right = tip - head_length * direction - side\n",
        "    ax.plot([tip[0], left[0]], [tip[1], left[1]], [tip[2], left[2]], color=color)\n",
        "    ax.plot([tip[0], right[0]], [tip[1], right[1]], [tip[2], right[2]], color=color)\n",
        "\n",
        "\n",
        "# General plotting function\n",
        "\n",
        "title = \"Angular Steering\"\n",
        "show_addition_lines = (True,)\n",
        "show_right_angle = (True,)\n",
        "show_rotation_arc = (False,)\n",
        "show_legend = (False,)\n",
        "\n",
        "draw_vector_arrow(\n",
        "    ax1,\n",
        "    h_3d,\n",
        "    color=plotly.colors.qualitative.Plotly[0],\n",
        "    label=\"$\\\\mathbf{h}$ (activation)\",\n",
        ")\n",
        "\n",
        "draw_vector_arrow(\n",
        "    ax1,\n",
        "    d_3d,\n",
        "    color=plotly.colors.qualitative.Plotly[1],\n",
        "    label=\"$\\\\mathbf{d}_\\\\text{feature}$ (feature direction)\",\n",
        ")\n",
        "\n",
        "draw_vector_arrow(\n",
        "    ax1,\n",
        "    d_1stPC_3d,\n",
        "    color=plotly.colors.qualitative.Plotly[6],\n",
        "    label=\"$\\mathbf{d}_\\\\text{1stPC}$\",\n",
        ")\n",
        "\n",
        "ax1.add_collection3d(Poly3DCollection([plane_vertices], color=\"gray\", alpha=0.1))\n",
        "\n",
        "\n",
        "if show_right_angle:\n",
        "    draw_right_angle_marker(ax1, d_3d, d_1stPC_3d)\n",
        "\n",
        "if show_rotation_arc:\n",
        "    draw_arc(ax1, d_3d, d_1stPC_3d, (0, 0, 1, 0.3))\n",
        "    # draw_arc(ax, h_vec, h_add_vec, \"orange\")\n",
        "    # draw_arc(ax, h_vec, h_ablate_vec, \"red\")\n",
        "\n",
        "ax1.set_title(title, fontsize=20)\n",
        "ax1.set_xlim([-1, 2])\n",
        "ax1.set_ylim([-1, 2])\n",
        "ax1.set_zlim([-1, 2])\n",
        "ax1.view_init(elev=30, azim=135)\n",
        "ax1.grid(True)\n",
        "ax1.set_xticklabels([])\n",
        "ax1.set_yticklabels([])\n",
        "ax1.set_zticklabels([])\n",
        "ax1.xaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
        "ax1.yaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
        "ax1.zaxis.set_pane_color((1.0, 1.0, 1.0, 1.0))\n",
        "\n",
        "if show_legend:\n",
        "    ax1.legend(fontsize=18, loc=\"upper left\", bbox_to_anchor=(0, 0.3))\n",
        "\n",
        "\n",
        "# Connect the two views\n",
        "# fig.text(0.51, 0.5, \"➜\", fontsize=20, ha=\"center\", va=\"center\")\n",
        "# plt.tight_layout()\n",
        "fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, wspace=-0.2)\n",
        "\n",
        "plt.show()\n",
        "\n",
        "fig.savefig(\"visualization/steering_methods.png\", dpi=300, bbox_inches=\"tight\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "momentum_steering_v2",
      "language": "python",
      "name": "momentum_steering_v2"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}