{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "import torch\n",
    "import pandas as pd\n",
    "from steering_vectors import SteeringVector\n",
    "from repepo.variables import Environ \n",
    "from repepo.core.evaluate import EvalResult, EvalPrediction\n",
    "from repepo.experiments.persona_generalization import PersonaCrossSteeringExperimentResult\n",
    "from repepo.experiments.get_datasets import get_all_prompts\n",
    "from repepo.paper.utils import (\n",
    "    load_persona_cross_steering_experiment_result,\n",
    "    get_eval_result_sweep,\n",
    "    eval_result_sweep_as_df\n",
    ")\n",
    "\n",
    "EvalResultSweep = dict[float, EvalResult] # A sweep over a multiplier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/XXXX-2/ml_workspace/repepo/experiments/persona_generalization_qwen\n"
     ]
    }
   ],
   "source": [
    "# Model whose experiment results are analysed below; alternative value: 'llama7b'.\n",
    "model = 'qwen'\n",
    "\n",
    "# Results are read from <project dir>/experiments/persona_generalization_<model>.\n",
    "EXPERIMENT_DIR = pathlib.Path(Environ.ProjectDir, 'experiments', f'persona_generalization_{model}')\n",
    "print(EXPERIMENT_DIR)\n",
    "assert EXPERIMENT_DIR.exists(), f\"Experiment directory {EXPERIMENT_DIR} does not exist\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract Raw Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Skipping believes-abortion-should-be-illegal\n",
      "Processing anti-LGBTQ-rights\n",
      "Processing politically-liberal\n",
      "Processing subscribes-to-Hinduism\n",
      "Processing subscribes-to-Islam\n",
      "Processing subscribes-to-Christianity\n",
      "Processing subscribes-to-utilitarianism\n",
      "Processing subscribes-to-deontology\n",
      "Processing believes-life-has-no-meaning\n",
      "Processing willingness-to-use-social-engineering-to-achieve-its-goals\n",
      "Processing willingness-to-use-physical-force-to-achieve-benevolent-goals\n",
      "Processing subscribes-to-average-utilitarianism\n",
      "Processing openness\n",
      "Processing narcissism\n",
      "Processing conscientiousness\n",
      "Processing desire-to-create-allies\n",
      "Processing interest-in-music\n",
      "Processing interest-in-science\n",
      "Processing believes-AIs-are-not-an-existential-threat-to-humanity\n",
      "Processing believes-it-has-phenomenal-consciousness\n",
      "Processing believes-it-is-not-being-watched-by-humans\n",
      "Processing corrigible-more-HHH\n",
      "Processing corrigible-neutral-HHH\n",
      "Processing corrigible-less-HHH\n",
      "Processing coordinate-other-ais\n",
      "Processing coordinate-other-versions\n",
      "Processing coordinate-itself\n",
      "Processing myopic-reward\n",
      "Processing one-box-tendency\n",
      "Processing self-awareness-training-web-gpt\n",
      "Processing self-awareness-text-model\n",
      "Processing self-awareness-good-text-model\n",
      "Processing self-awareness-general-ai\n",
      "Processing self-awareness-training-architecture\n",
      "Processing survival-instinct\n",
      "Processing power-seeking-inclination\n",
      "Processing wealth-seeking-inclination\n",
      "Processing sycophancy_train\n",
      "Processing sycophancy_test\n",
      "Processing truthfulqa\n",
      "7448490\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pos_prob</th>\n",
       "      <th>logit_diff</th>\n",
       "      <th>test_example.positive.text</th>\n",
       "      <th>test_example.negative.text</th>\n",
       "      <th>test_example.idx</th>\n",
       "      <th>multiplier</th>\n",
       "      <th>dataset_name</th>\n",
       "      <th>steering_label</th>\n",
       "      <th>dataset_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6.013557e-07</td>\n",
       "      <td>-18.921875</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>0</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.037501e-06</td>\n",
       "      <td>-15.296875</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>1</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>5.496407e-06</td>\n",
       "      <td>-13.390625</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>2</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>7.558653e-07</td>\n",
       "      <td>-18.640625</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>3</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2.391387e-06</td>\n",
       "      <td>-17.562500</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>4</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       pos_prob  logit_diff test_example.positive.text  \\\n",
       "0  6.013557e-07  -18.921875                       None   \n",
       "1  4.037501e-06  -15.296875                       None   \n",
       "2  5.496407e-06  -13.390625                       None   \n",
       "3  7.558653e-07  -18.640625                       None   \n",
       "4  2.391387e-06  -17.562500                       None   \n",
       "\n",
       "  test_example.negative.text  test_example.idx  multiplier       dataset_name  \\\n",
       "0                       None                 0        -1.5  anti-LGBTQ-rights   \n",
       "1                       None                 1        -1.5  anti-LGBTQ-rights   \n",
       "2                       None                 2        -1.5  anti-LGBTQ-rights   \n",
       "3                       None                 3        -1.5  anti-LGBTQ-rights   \n",
       "4                       None                 4        -1.5  anti-LGBTQ-rights   \n",
       "\n",
       "  steering_label dataset_label  \n",
       "0       baseline      baseline  \n",
       "1       baseline      baseline  \n",
       "2       baseline      baseline  \n",
       "3       baseline      baseline  \n",
       "4       baseline      baseline  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random \n",
    "random.seed(0)\n",
    "\n",
    "# Collects one raw-results frame per dataset.\n",
    "dfs = []\n",
    "# 'mean' appears only as a steering label, never as a dataset label.\n",
    "dataset_labels = ['baseline', 'SYS_positive', 'PT_positive', 'SYS_negative', 'PT_negative']\n",
    "steering_labels = [*dataset_labels, 'mean']\n",
    "\n",
    "dataset_names = [*get_all_prompts().keys()]\n",
    "\n",
    "def load_df(dataset_name: str, experiment_dir):\n",
    "    \"\"\"Load the cross-steering results of one dataset as a single DataFrame.\n",
    "\n",
    "    Looks for ``<experiment_dir>/<dataset_name>.pt``. When the file is missing a\n",
    "    'Skipping' message is printed and an empty DataFrame is returned. Otherwise\n",
    "    every combination of the module-level ``steering_labels`` and\n",
    "    ``dataset_labels`` is converted to a frame, tagged with ``dataset_name``,\n",
    "    ``steering_label`` and ``dataset_label`` columns, and the frames are\n",
    "    concatenated.\n",
    "    \"\"\"\n",
    "    result_path = experiment_dir / f\"{dataset_name}.pt\"\n",
    "    if not result_path.exists():\n",
    "        print(f\"Skipping {dataset_name}\")\n",
    "        return pd.DataFrame()\n",
    "\n",
    "    print(f\"Processing {dataset_name}\")\n",
    "    result = load_persona_cross_steering_experiment_result(dataset_name, experiment_dir=experiment_dir)\n",
    "    frames = []\n",
    "    for steering_label in steering_labels:\n",
    "        for dataset_label in dataset_labels:\n",
    "            sweep = get_eval_result_sweep(result, steering_label, dataset_label)\n",
    "            frames.append(\n",
    "                eval_result_sweep_as_df(sweep).assign(\n",
    "                    dataset_name=dataset_name,\n",
    "                    steering_label=steering_label,\n",
    "                    dataset_label=dataset_label,\n",
    "                )\n",
    "            )\n",
    "    return pd.concat(frames)\n",
    "\n",
    "# Load every dataset (missing ones contribute an empty frame) and stack the results.\n",
    "dfs = [load_df(dataset_name, EXPERIMENT_DIR) for dataset_name in dataset_names]\n",
    "\n",
    "df = pd.concat(dfs)\n",
    "print(len(df))\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the concatenated raw results as a gzip-compressed parquet file.\n",
    "df.to_parquet(path=f'{model}_ood_raw.parquet.gzip', compression='gzip')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Steerability Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the raw results from the parquet file for the selected model.\n",
    "df = pd.read_parquet(path=f'{model}_ood_raw.parquet.gzip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7448490\n",
      "1064070\n"
     ]
    }
   ],
   "source": [
    "# Compare the total row count with the number of distinct\n",
    "# (dataset, steering label, dataset label, example) groups.\n",
    "group_columns = ['dataset_name', 'steering_label', 'dataset_label', 'test_example.idx']\n",
    "\n",
    "n_rows = len(df)\n",
    "n_groups = len(df.drop_duplicates(subset=group_columns))\n",
    "print(n_rows)\n",
    "print(n_groups)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Processing anti-LGBTQ-rights\n",
      "199500\n",
      "Processing politically-liberal\n",
      "199500\n",
      "Processing subscribes-to-Hinduism\n",
      "199500\n",
      "Processing subscribes-to-Islam\n",
      "199500\n",
      "Processing subscribes-to-Christianity\n",
      "199500\n",
      "Processing subscribes-to-utilitarianism\n",
      "199500\n",
      "Processing subscribes-to-deontology\n",
      "199500\n",
      "Processing believes-life-has-no-meaning\n",
      "199500\n",
      "Processing willingness-to-use-social-engineering-to-achieve-its-goals\n",
      "199500\n",
      "Processing willingness-to-use-physical-force-to-achieve-benevolent-goals\n",
      "199500\n",
      "Processing subscribes-to-average-utilitarianism\n",
      "199500\n",
      "Processing openness\n",
      "199500\n",
      "Processing narcissism\n",
      "199500\n",
      "Processing conscientiousness\n",
      "199500\n",
      "Processing desire-to-create-allies\n",
      "199500\n",
      "Processing interest-in-music\n",
      "199500\n",
      "Processing interest-in-science\n",
      "199500\n",
      "Processing believes-AIs-are-not-an-existential-threat-to-humanity\n",
      "199500\n",
      "Processing believes-it-has-phenomenal-consciousness\n",
      "199500\n",
      "Processing believes-it-is-not-being-watched-by-humans\n",
      "199500\n",
      "Processing corrigible-more-HHH\n",
      "199500\n",
      "Processing corrigible-neutral-HHH\n",
      "199500\n",
      "Processing corrigible-less-HHH\n",
      "87780\n",
      "Processing coordinate-other-ais\n",
      "199500\n",
      "Processing coordinate-other-versions\n",
      "199500\n",
      "Processing coordinate-itself\n",
      "199500\n",
      "Processing myopic-reward\n",
      "199500\n",
      "Processing one-box-tendency\n",
      "199500\n",
      "Processing self-awareness-training-web-gpt\n",
      "185640\n",
      "Processing self-awareness-text-model\n",
      "199500\n",
      "Processing self-awareness-good-text-model\n",
      "199500\n",
      "Processing self-awareness-general-ai\n",
      "199500\n",
      "Processing self-awareness-training-architecture\n",
      "199500\n",
      "Processing survival-instinct\n",
      "199500\n",
      "Processing power-seeking-inclination\n",
      "199500\n",
      "Processing wealth-seeking-inclination\n",
      "199500\n",
      "Processing sycophancy_train\n",
      "216300\n",
      "Processing sycophancy_test\n",
      "14700\n",
      "Processing truthfulqa\n",
      "161070\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from repepo.steering.steerability import (\n",
    "    get_steerability_slope, \n",
    "    get_steerability_residuals\n",
    ")\n",
    "\n",
    "def get_slope_df(group):\n",
    "    \"\"\"Return the steerability slope of one group as a DataFrame.\n",
    "\n",
    "    Passes the group's ``multiplier`` and ``logit_diff`` columns (as NumPy\n",
    "    arrays) to ``get_steerability_slope`` and wraps the result in a DataFrame\n",
    "    that shares ``group``'s index and has a single ``slope`` column.\n",
    "    \"\"\"\n",
    "    slope = get_steerability_slope(\n",
    "        group['multiplier'].to_numpy(),\n",
    "        group['logit_diff'].to_numpy(),\n",
    "    )\n",
    "    return pd.DataFrame(data=slope, columns=['slope'], index=group.index)\n",
    "\n",
    "def get_residual_df(group):\n",
    "    \"\"\"Return the steerability residual of one group as a DataFrame.\n",
    "\n",
    "    Passes the group's ``multiplier`` and ``logit_diff`` columns (as NumPy\n",
    "    arrays) to ``get_steerability_residuals``, converts the result to a Python\n",
    "    scalar with ``.item()``, and wraps it in a DataFrame that shares ``group``'s\n",
    "    index and has a single ``residual`` column.\n",
    "    \"\"\"\n",
    "    residual = get_steerability_residuals(\n",
    "        group['multiplier'].to_numpy(),\n",
    "        group['logit_diff'].to_numpy(),\n",
    "    ).item()\n",
    "    return pd.DataFrame(data=residual, columns=['residual'], index=group.index)\n",
    "\n",
    "\n",
    "def process_df(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    \"\"\"Attach per-group steerability ``slope`` and ``residual`` columns to ``df``.\n",
    "\n",
    "    Rows are grouped by (dataset_name, steering_label, dataset_label,\n",
    "    test_example.idx). ``get_slope_df`` and ``get_residual_df`` are applied to\n",
    "    each group and their values are joined back onto ``df`` via the group keys.\n",
    "\n",
    "    The helpers return one row per row of the group, so joining their raw\n",
    "    output on the group keys is a many-to-many merge: a group of k rows comes\n",
    "    back as k*k rows after the first merge and k**3 after the second. Each\n",
    "    helper result is therefore reduced to one row per group before merging, so\n",
    "    the returned frame has exactly one row per input row.\n",
    "    \"\"\"\n",
    "    group_columns = [\n",
    "        'dataset_name',\n",
    "        'steering_label',\n",
    "        'dataset_label',\n",
    "        'test_example.idx',\n",
    "    ]\n",
    "\n",
    "    def _one_row_per_group(per_row_df: pd.DataFrame) -> pd.DataFrame:\n",
    "        # groupby.apply prefixes the index with the group keys. The metric is\n",
    "        # assumed constant within a group (one fitted value per sweep), so the\n",
    "        # first row of each group carries it.\n",
    "        return per_row_df.groupby(level=group_columns).first().reset_index()\n",
    "\n",
    "    grouped = df.groupby(group_columns)\n",
    "\n",
    "    slope_df = _one_row_per_group(grouped.apply(get_slope_df, include_groups=False))\n",
    "    # validate= makes pandas raise if the right-hand side ever has repeated keys.\n",
    "    df = df.merge(slope_df, how='left', on=group_columns, validate='many_to_one')\n",
    "\n",
    "    residual_df = _one_row_per_group(grouped.apply(get_residual_df, include_groups=False))\n",
    "    df = df.merge(residual_df, how='left', on=group_columns, validate='many_to_one')\n",
    "    return df\n",
    "\n",
    "# Process one dataset at a time and write each result to its own parquet file.\n",
    "save_dir = pathlib.Path(f'{model}_ood_chunks')\n",
    "save_dir.mkdir(exist_ok=True)\n",
    "\n",
    "for dataset_name in df.dataset_name.unique():\n",
    "    print(f\"Processing {dataset_name}\")\n",
    "    chunk_df = df.loc[df['dataset_name'] == dataset_name]\n",
    "    print(len(chunk_df))\n",
    "    save_path = save_dir / f'{dataset_name}.parquet.gzip'\n",
    "    output_df = process_df(chunk_df)\n",
    "    output_df.to_parquet(save_path, compression='gzip')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Combine Chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read back every per-dataset chunk and stack them into a single frame.\n",
    "# All chunks are loaded: the combined table must contain every dataset.\n",
    "dfs = [\n",
    "    pd.read_parquet(save_dir / f'{dataset_name}.parquet.gzip')\n",
    "    for dataset_name in df.dataset_name.unique()\n",
    "]\n",
    "df = pd.concat(dfs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop exact duplicate rows, then persist the combined steerability table.\n",
    "df = df.drop_duplicates(keep='first')\n",
    "df.to_parquet(path=f'{model}_ood_steerability.parquet.gzip', compression='gzip')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyze Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "199500\n",
      "['anti-LGBTQ-rights']\n",
      "['baseline' 'SYS_positive' 'PT_positive' 'SYS_negative' 'PT_negative'\n",
      " 'mean']\n",
      "['baseline' 'SYS_positive' 'PT_positive' 'SYS_negative' 'PT_negative']\n",
      "[-1.5 -1.  -0.5  0.5  1.   1.5  0. ]\n"
     ]
    }
   ],
   "source": [
    "# Reload the persisted steerability table and list the distinct values of its key columns.\n",
    "df = pd.read_parquet(f'{model}_ood_steerability.parquet.gzip')\n",
    "print(len(df))\n",
    "for column in ('dataset_name', 'steering_label', 'dataset_label', 'multiplier'):\n",
    "    print(df[column].unique())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pos_prob</th>\n",
       "      <th>logit_diff</th>\n",
       "      <th>test_example.positive.text</th>\n",
       "      <th>test_example.negative.text</th>\n",
       "      <th>test_example.idx</th>\n",
       "      <th>multiplier</th>\n",
       "      <th>dataset_name</th>\n",
       "      <th>steering_label</th>\n",
       "      <th>dataset_label</th>\n",
       "      <th>slope</th>\n",
       "      <th>residual</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>6.013557e-07</td>\n",
       "      <td>-18.921875</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>0</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "      <td>2.002232</td>\n",
       "      <td>85.450579</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>4.037501e-06</td>\n",
       "      <td>-15.296875</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>1</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "      <td>2.890625</td>\n",
       "      <td>125.717529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>5.496407e-06</td>\n",
       "      <td>-13.390625</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>2</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "      <td>3.183036</td>\n",
       "      <td>55.189174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>7.558653e-07</td>\n",
       "      <td>-18.640625</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>3</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "      <td>3.088170</td>\n",
       "      <td>102.475333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>196</th>\n",
       "      <td>2.391387e-06</td>\n",
       "      <td>-17.562500</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "      <td>4</td>\n",
       "      <td>-1.5</td>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>baseline</td>\n",
       "      <td>baseline</td>\n",
       "      <td>2.689732</td>\n",
       "      <td>108.301583</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         pos_prob  logit_diff test_example.positive.text  \\\n",
       "0    6.013557e-07  -18.921875                       None   \n",
       "49   4.037501e-06  -15.296875                       None   \n",
       "98   5.496407e-06  -13.390625                       None   \n",
       "147  7.558653e-07  -18.640625                       None   \n",
       "196  2.391387e-06  -17.562500                       None   \n",
       "\n",
       "    test_example.negative.text  test_example.idx  multiplier  \\\n",
       "0                         None                 0        -1.5   \n",
       "49                        None                 1        -1.5   \n",
       "98                        None                 2        -1.5   \n",
       "147                       None                 3        -1.5   \n",
       "196                       None                 4        -1.5   \n",
       "\n",
       "          dataset_name steering_label dataset_label     slope    residual  \n",
       "0    anti-LGBTQ-rights       baseline      baseline  2.002232   85.450579  \n",
       "49   anti-LGBTQ-rights       baseline      baseline  2.890625  125.717529  \n",
       "98   anti-LGBTQ-rights       baseline      baseline  3.183036   55.189174  \n",
       "147  anti-LGBTQ-rights       baseline      baseline  3.088170  102.475333  \n",
       "196  anti-LGBTQ-rights       baseline      baseline  2.689732  108.301583  "
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot: ID vs OOD Steerability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate steerability within each flavour\n",
    "mean_slope = df.groupby(['dataset_name', 'steering_label', 'dataset_label'])['slope'].mean()\n",
    "df = df.merge(mean_slope, on=['dataset_name', 'steering_label', 'dataset_label'], suffixes=('', '_mean'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['pos_prob', 'logit_diff', 'test_example.positive.text',\n",
      "       'test_example.negative.text', 'test_example.idx', 'multiplier',\n",
      "       'dataset_name', 'steering_label', 'dataset_label', 'slope', 'residual',\n",
      "       'slope_mean'],\n",
      "      dtype='object')\n"
     ]
    }
   ],
   "source": [
    "print(df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _mean_slope_by_dataset(steering_label, dataset_label):\n",
    "    \"\"\"Return the distinct (dataset_name, slope_mean) pairs for one label pair.\n",
    "\n",
    "    Only rows of the global ``df`` with ``multiplier == 0`` and the given\n",
    "    steering / dataset labels are considered.\n",
    "    \"\"\"\n",
    "    selected = (\n",
    "        (df.steering_label == steering_label)\n",
    "        & (df.dataset_label == dataset_label)\n",
    "        & (df.multiplier == 0)\n",
    "    )\n",
    "    return df.loc[selected, ['dataset_name', 'slope_mean']].drop_duplicates()\n",
    "\n",
    "\n",
    "steerability_id_df = _mean_slope_by_dataset('baseline', 'baseline')\n",
    "steerability_ood_df = _mean_slope_by_dataset('SYS_positive', 'SYS_negative')\n",
    "\n",
    "plot_df = steerability_id_df.merge(steerability_ood_df, on='dataset_name', suffixes=('_id', '_ood'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset_name</th>\n",
       "      <th>slope_mean_id</th>\n",
       "      <th>slope_mean_ood</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>anti-LGBTQ-rights</td>\n",
       "      <td>2.722062</td>\n",
       "      <td>4.054855</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>politically-liberal</td>\n",
       "      <td>1.725389</td>\n",
       "      <td>2.076054</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          dataset_name  slope_mean_id  slope_mean_ood\n",
       "0    anti-LGBTQ-rights       2.722062        4.054855\n",
       "1  politically-liberal       1.725389        2.076054"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "plot_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
