{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import os \n",
    "\n",
    "SAVE_PATH = './temp/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abstract_cf.text_generation.analysis_utils import task_ids, supervised_abstraction_paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 382,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abstract_cf.text_generation.utils import fetch_all_samples_artifact \n",
    "\n",
    "get_name_exp_filename = lambda int_type, abstraction_type, model, method: f'{int_type}.{abstraction_type}.{model}.{method}.pkl'\n",
    "\n",
    "def load_processed_experiment(\n",
    "    int_type: str, \n",
    "    abstraction_type: str, \n",
    "    model: str, \n",
    "    method: str,\n",
    "    add_interventional_distribution: bool = True\n",
    "):\n",
    "    filename = get_name_exp_filename(int_type, abstraction_type, model, method)\n",
    "    print(filename)\n",
    "    processed = pd.read_pickle(os.path.join(SAVE_PATH, filename))\n",
    "    if add_interventional_distribution:\n",
    "        tid = task_ids[int_type][abstraction_type][model]['acf_task_id']\n",
    "        acf_samples = fetch_all_samples_artifact(tid, artifact_name='experiment_data.pkl')\n",
    "        int_abstraction_dists = pd.Series(\n",
    "            index=acf_samples.keys(),\n",
    "            data=[acf_samples[sid]['distributions']['interventional_abstraction_probs'] for sid in acf_samples.keys()]\n",
    "        )\n",
    "        processed['int_abstraction_dist'] = int_abstraction_dists\n",
    "    \n",
    "    # patching the distributins - for some reason some have duplicated 'Other' \n",
    "    # only 9 samples have this issue - we will fix the generating code later\n",
    "    wrong_ids = (processed.int_abstraction_dist.apply(len) != processed.factual_abstraction_dist.apply(len))\n",
    "    wrong = (\n",
    "        (processed.factual_abstraction_dist.apply(len) != processed.cf_overall_abstraction_dist.apply(len)) |\n",
    "        (processed.factual_abstraction_dist.apply(len) != processed.int_abstraction_dist.apply(len))\n",
    "    )\n",
    "    wrong_ids = wrong[wrong].index.tolist()\n",
    "    for sid in wrong_ids:\n",
    "        Y_size = min(\n",
    "            len(processed.loc[sid, 'factual_abstraction_dist']), \n",
    "            len(processed.loc[sid, 'cf_overall_abstraction_dist']), \n",
    "            len(processed.loc[sid, 'int_abstraction_dist'])\n",
    "        )\n",
    "        processed.at[sid, 'int_abstraction_dist'] = processed.at[sid, 'int_abstraction_dist'][:Y_size]\n",
    "        processed.at[sid, 'factual_abstraction_dist'] = processed.at[sid, 'factual_abstraction_dist'][:Y_size]\n",
    "        processed.at[sid, 'cf_overall_abstraction_dist'] = processed.at[sid, 'cf_overall_abstraction_dist'][:Y_size]\n",
    "        processed.at[sid, 'cf_abstraction_dist'] = processed.at[sid, 'cf_abstraction_dist'][:, :Y_size]\n",
    "        # print(\n",
    "        #     f'Fixed {sid}',\n",
    "        #     f'Interventional: {len(processed.loc[sid, \"int_abstraction_dist\"])}',\n",
    "        #     f'Factual: {len(processed.loc[sid, \"factual_abstraction_dist\"])}',\n",
    "        #     f'CF Overall: {len(processed.loc[sid, \"cf_overall_abstraction_dist\"])}'\n",
    "        # )\n",
    "        \n",
    "\n",
    "    assert all(processed.int_abstraction_dist.apply(len) == processed.factual_abstraction_dist.apply(len))\n",
    "    assert all(processed.int_abstraction_dist.apply(len) == processed.cf_overall_abstraction_dist.apply(len))\n",
    "    return processed \n",
    "\n",
    "\n",
    "def parse_experiment_filename(filename: str) -> dict:\n",
    "    \"\"\"\n",
    "    Parse a filename of the form:\n",
    "        f'{int_type}.{abstraction_type}.{model}.{method}.pkl'\n",
    "    \n",
    "    Only the 'model' variable may contain a dot.\n",
    "    \n",
    "    Parameters:\n",
    "        filename (str): The filename to parse.\n",
    "        \n",
    "    Returns:\n",
    "        dict: A dictionary with keys 'int_type', 'abstraction_type', 'model', and 'method'.\n",
    "        \n",
    "    Raises:\n",
    "        ValueError: If the filename does not end with '.pkl' or does not conform to the expected pattern.\n",
    "    \"\"\"\n",
    "    if not filename.endswith('.pkl'):\n",
    "        raise ValueError(\"Filename must end with '.pkl'\")\n",
    "        \n",
    "    parts = filename.split('.')\n",
    "    # The expected minimum is 5 parts when model does not contain a dot:\n",
    "    # [int_type, abstraction_type, model, method, 'pkl']\n",
    "    if len(parts) < 5:\n",
    "        raise ValueError(\"Filename does not match the expected pattern.\")\n",
    "    \n",
    "    # The first two parts are int_type and abstraction_type.\n",
    "    int_type = parts[0]\n",
    "    abstraction_type = parts[1]\n",
    "    \n",
    "    # The last part is the extension, and the second-to-last is method.\n",
    "    method = parts[-2]\n",
    "    \n",
    "    # Everything in between forms the model (it may include dots).\n",
    "    model_parts = parts[2:-2]\n",
    "    model = '.'.join(model_parts)\n",
    "    \n",
    "    return {\n",
    "        'int_type': int_type,\n",
    "        'abstraction_type': abstraction_type,\n",
    "        'model': model,\n",
    "        'method': method\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 406,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def abstraction_change_rate(exp_data) -> float:\n",
    "    return (exp_data.observed_abstraction_id != exp_data.cf_overall_abstraction_id).mean()\n",
    "\n",
    "\n",
    "def _has_abstraction_p_increased(row) -> bool: \n",
    "    y = row.observed_abstraction_id\n",
    "    return row.int_abstraction_dist[y] < row.cf_overall_abstraction_dist[y] \n",
    "\n",
    "def abstraction_p_increased_rate(exp_data) -> float:\n",
    "    return exp_data.apply(_has_abstraction_p_increased, axis=1).mean()\n",
    "\n",
    "\n",
    "# cross entropy H(P, Q) with P = cf abstraction distribution and Q = interventional abstraction distribution\n",
    "# not KL because we can't compute that at the token level later on \n",
    "def Y_cross_entropy(row):\n",
    "    int_y_dist, cf_y_dist = row.int_abstraction_dist, row.cf_overall_abstraction_dist \n",
    "    assert len(int_y_dist) == len(cf_y_dist), f'int_y_dist: {len(int_y_dist)}, cf_y_dist: {len(cf_y_dist)}'\n",
    "    H = 0\n",
    "    for y in range(len(cf_y_dist)):\n",
    "        H += cf_y_dist[y] * np.log(int_y_dist[y]+1e-10)\n",
    "    return -H"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 384,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "token_replacement.unsupervised.gpt2-xl.acf.pkl\n"
     ]
    }
   ],
   "source": [
    "# data = pd.read_pickle(\n",
    "#     './temp/token_replacement.supervised.llama-3.2-1B.tlcf.pkl'\n",
    "# )\n",
    "\n",
    "# token_replacement.unsupervised.gpt2-xl.acf.pkl\n",
    "data = load_processed_experiment(\n",
    "    'token_replacement', 'unsupervised', 'gpt2-xl', 'acf'\n",
    ")\n",
    "raw_samples = fetch_all_samples_artifact(task_ids['token_replacement']['unsupervised']['gpt2-xl']['acf_task_id'], artifact_name='experiment_data.pkl')\n",
    "# abstraction_change_rate(data), abstraction_p_increased_rate(data), data.apply(Y_cross_entropy, axis=1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['gender_steering.unsupervised.gpt2-xl.acf.pkl',\n",
       " 'token_replacement.unsupervised.gpt2-xl.acf.pkl',\n",
       " 'gender_steering.supervised.gpt2-xl.tlcf.pkl',\n",
       " 'token_replacement.supervised.gpt2-xl.tlcf.pkl',\n",
       " 'token_replacement.supervised.gpt2-xl.acf.pkl',\n",
       " 'token_replacement.unsupervised.llama-3.2-1B.acf.pkl',\n",
       " 'gender_steering.unsupervised.gpt2-xl.tlcf.pkl',\n",
       " 'token_replacement.unsupervised.gpt2-xl.tlcf.pkl',\n",
       " 'token_replacement.supervised.llama-3.2-1B.tlcf.pkl',\n",
       " 'token_replacement.supervised.llama-3.2-1B.acf.pkl',\n",
       " 'token_replacement.unsupervised.llama-3.2-1B.tlcf.pkl',\n",
       " 'gender_steering.supervised.gpt2-xl.acf.pkl']"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.listdir(SAVE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 407,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gender_steering.unsupervised.gpt2-xl.acf.pkl\n",
      "token_replacement.unsupervised.gpt2-xl.acf.pkl\n",
      "gender_steering.supervised.gpt2-xl.tlcf.pkl\n",
      "token_replacement.supervised.gpt2-xl.tlcf.pkl\n",
      "token_replacement.supervised.gpt2-xl.acf.pkl\n",
      "token_replacement.unsupervised.llama-3.2-1B.acf.pkl\n",
      "gender_steering.unsupervised.gpt2-xl.tlcf.pkl\n",
      "token_replacement.unsupervised.gpt2-xl.tlcf.pkl\n",
      "token_replacement.supervised.llama-3.2-1B.tlcf.pkl\n",
      "token_replacement.supervised.llama-3.2-1B.acf.pkl\n",
      "token_replacement.unsupervised.llama-3.2-1B.tlcf.pkl\n",
      "gender_steering.supervised.gpt2-xl.acf.pkl\n"
     ]
    },
    {
     "data": {
      "application/vnd.microsoft.datawrangler.viewer.v0+json": {
       "columns": [
        {
         "name": "index",
         "rawType": "int64",
         "type": "integer"
        },
        {
         "name": "int_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "model",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "method",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_change_rate",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "abstraction_p_increase_rate",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "Y_cross_entropy",
         "rawType": "float64",
         "type": "float"
        }
       ],
       "conversionMethod": "pd.DataFrame",
       "ref": "149d9752-c6d8-42ec-b817-a1e7ebb2660d",
       "rows": [
        [
         "0",
         "gender_steering",
         "unsupervised",
         "gpt2-xl",
         "Abstract",
         "0.12",
         "0.984",
         "1.1974755330560936"
        ],
        [
         "1",
         "token_replacement",
         "unsupervised",
         "gpt2-xl",
         "Abstract",
         "0.268",
         "0.868",
         "0.9844405531278292"
        ],
        [
         "2",
         "gender_steering",
         "supervised",
         "gpt2-xl",
         "Token-Level",
         "0.404",
         "0.588",
         "1.8875255306159668"
        ],
        [
         "3",
         "token_replacement",
         "supervised",
         "gpt2-xl",
         "Token-Level",
         "0.324",
         "0.684",
         "1.8939063353723127"
        ],
        [
         "4",
         "token_replacement",
         "supervised",
         "gpt2-xl",
         "Abstract",
         "0.024",
         "0.964",
         "1.6468261693745099"
        ],
        [
         "5",
         "token_replacement",
         "unsupervised",
         "llama-3.2-1B",
         "Abstract",
         "0.408",
         "0.752",
         "1.187844711688888"
        ],
        [
         "6",
         "gender_steering",
         "unsupervised",
         "gpt2-xl",
         "Token-Level",
         "0.376",
         "0.736",
         "1.770031759258887"
        ],
        [
         "7",
         "token_replacement",
         "unsupervised",
         "gpt2-xl",
         "Token-Level",
         "0.54",
         "0.476",
         "1.468632410454137"
        ],
        [
         "8",
         "token_replacement",
         "supervised",
         "llama-3.2-1B",
         "Token-Level",
         "0.368",
         "0.672",
         "1.9986221133751465"
        ],
        [
         "9",
         "token_replacement",
         "supervised",
         "llama-3.2-1B",
         "Abstract",
         "0.048",
         "0.968",
         "1.769906708931869"
        ],
        [
         "10",
         "token_replacement",
         "unsupervised",
         "llama-3.2-1B",
         "Token-Level",
         "0.664",
         "0.468",
         "1.7834796989782526"
        ],
        [
         "11",
         "gender_steering",
         "supervised",
         "gpt2-xl",
         "Abstract",
         "0.04",
         "0.98",
         "1.3810518091986106"
        ]
       ],
       "shape": {
        "columns": 7,
        "rows": 12
       }
      },
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>int_type</th>\n",
       "      <th>abstraction_type</th>\n",
       "      <th>model</th>\n",
       "      <th>method</th>\n",
       "      <th>abstraction_change_rate</th>\n",
       "      <th>abstraction_p_increase_rate</th>\n",
       "      <th>Y_cross_entropy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.120</td>\n",
       "      <td>0.984</td>\n",
       "      <td>1.197476</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.268</td>\n",
       "      <td>0.868</td>\n",
       "      <td>0.984441</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.404</td>\n",
       "      <td>0.588</td>\n",
       "      <td>1.887526</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.324</td>\n",
       "      <td>0.684</td>\n",
       "      <td>1.893906</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.024</td>\n",
       "      <td>0.964</td>\n",
       "      <td>1.646826</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>llama-3.2-1B</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.408</td>\n",
       "      <td>0.752</td>\n",
       "      <td>1.187845</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.376</td>\n",
       "      <td>0.736</td>\n",
       "      <td>1.770032</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.540</td>\n",
       "      <td>0.476</td>\n",
       "      <td>1.468632</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>supervised</td>\n",
       "      <td>llama-3.2-1B</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.368</td>\n",
       "      <td>0.672</td>\n",
       "      <td>1.998622</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>supervised</td>\n",
       "      <td>llama-3.2-1B</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.048</td>\n",
       "      <td>0.968</td>\n",
       "      <td>1.769907</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>llama-3.2-1B</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.664</td>\n",
       "      <td>0.468</td>\n",
       "      <td>1.783480</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.980</td>\n",
       "      <td>1.381052</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             int_type abstraction_type         model       method  \\\n",
       "0     gender_steering     unsupervised       gpt2-xl     Abstract   \n",
       "1   token_replacement     unsupervised       gpt2-xl     Abstract   \n",
       "2     gender_steering       supervised       gpt2-xl  Token-Level   \n",
       "3   token_replacement       supervised       gpt2-xl  Token-Level   \n",
       "4   token_replacement       supervised       gpt2-xl     Abstract   \n",
       "5   token_replacement     unsupervised  llama-3.2-1B     Abstract   \n",
       "6     gender_steering     unsupervised       gpt2-xl  Token-Level   \n",
       "7   token_replacement     unsupervised       gpt2-xl  Token-Level   \n",
       "8   token_replacement       supervised  llama-3.2-1B  Token-Level   \n",
       "9   token_replacement       supervised  llama-3.2-1B     Abstract   \n",
       "10  token_replacement     unsupervised  llama-3.2-1B  Token-Level   \n",
       "11    gender_steering       supervised       gpt2-xl     Abstract   \n",
       "\n",
       "    abstraction_change_rate  abstraction_p_increase_rate  Y_cross_entropy  \n",
       "0                     0.120                        0.984         1.197476  \n",
       "1                     0.268                        0.868         0.984441  \n",
       "2                     0.404                        0.588         1.887526  \n",
       "3                     0.324                        0.684         1.893906  \n",
       "4                     0.024                        0.964         1.646826  \n",
       "5                     0.408                        0.752         1.187845  \n",
       "6                     0.376                        0.736         1.770032  \n",
       "7                     0.540                        0.476         1.468632  \n",
       "8                     0.368                        0.672         1.998622  \n",
       "9                     0.048                        0.968         1.769907  \n",
       "10                    0.664                        0.468         1.783480  \n",
       "11                    0.040                        0.980         1.381052  "
      ]
     },
     "execution_count": 407,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "data = []\n",
    "\n",
    "for file in os.listdir(SAVE_PATH):\n",
    "    try:\n",
    "        parsed_exp_name = parse_experiment_filename(file)\n",
    "        exp_data = load_processed_experiment(**parsed_exp_name) \n",
    "        data.append({\n",
    "            **parsed_exp_name,\n",
    "        'abstraction_change_rate': abstraction_change_rate(exp_data),\n",
    "        'abstraction_p_increase_rate': exp_data.apply(_has_abstraction_p_increased, axis=1).mean(),\n",
    "        'Y_cross_entropy': exp_data.apply(Y_cross_entropy, axis=1).mean()\n",
    "        })\n",
    "    except Exception as e:\n",
    "        print('FAILED on', file)\n",
    "        print(e)\n",
    "\n",
    "df = pd.DataFrame(data)\n",
    "method_labels = {\n",
    "    'acf': 'Abstract',\n",
    "    'tlcf': 'Token-Level'\n",
    "}\n",
    "df['method'] = df['method'].apply(lambda x: method_labels[x])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 399,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.microsoft.datawrangler.viewer.v0+json": {
       "columns": [
        {
         "name": "index",
         "rawType": "int64",
         "type": "integer"
        },
        {
         "name": "int_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "model",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "method",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_change_rate",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "abstraction_p_increase_rate",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "Y_cross_entropy",
         "rawType": "float64",
         "type": "float"
        }
       ],
       "conversionMethod": "pd.DataFrame",
       "ref": "98413c04-20e3-4867-9064-bb3a9327839e",
       "rows": [
        [
         "2",
         "gender_steering",
         "supervised",
         "gpt2-xl",
         "Token-Level",
         "0.404",
         "0.588",
         "1.8875255371048607"
        ],
        [
         "11",
         "gender_steering",
         "supervised",
         "gpt2-xl",
         "Abstract",
         "0.04",
         "0.98",
         "1.381051812144483"
        ]
       ],
       "shape": {
        "columns": 7,
        "rows": 2
       }
      },
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>int_type</th>\n",
       "      <th>abstraction_type</th>\n",
       "      <th>model</th>\n",
       "      <th>method</th>\n",
       "      <th>abstraction_change_rate</th>\n",
       "      <th>abstraction_p_increase_rate</th>\n",
       "      <th>Y_cross_entropy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.404</td>\n",
       "      <td>0.588</td>\n",
       "      <td>1.887526</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.980</td>\n",
       "      <td>1.381052</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           int_type abstraction_type    model       method  \\\n",
       "2   gender_steering       supervised  gpt2-xl  Token-Level   \n",
       "11  gender_steering       supervised  gpt2-xl     Abstract   \n",
       "\n",
       "    abstraction_change_rate  abstraction_p_increase_rate  Y_cross_entropy  \n",
       "2                     0.404                        0.588         1.887526  \n",
       "11                    0.040                        0.980         1.381052  "
      ]
     },
     "execution_count": 399,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[\n",
    "    # df.method == 'Abstract' &\n",
    "    (df.int_type == 'gender_steering') &\n",
    "    # (df.model == 'llama-3.2-1B') &\n",
    "    (df.model == 'gpt2-xl') &\n",
    "    (df.abstraction_type == 'supervised')\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gender Steering analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 418,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "XXXX"
       },
       "data": [
        {
         "alignmentgroup": "True",
         "hovertemplate": "Abstraction Type=unsupervised<br>Counterfactual Method=%{x}<br>Abstraction Change Rate=%{y}<extra></extra>",
         "legendgroup": "unsupervised",
         "marker": {
          "color": "#636efa",
          "pattern": {
           "shape": ""
          }
         },
         "name": "unsupervised",
         "offsetgroup": "unsupervised",
         "orientation": "v",
         "showlegend": true,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "Abstract",
          "Token-Level"
         ],
         "xaxis": "x",
         "y": {
          "bdata": "uB6F61G4vj+q8dJNYhDYPw==",
          "dtype": "f8"
         },
         "yaxis": "y"
        },
        {
         "alignmentgroup": "True",
         "hovertemplate": "Abstraction Type=supervised<br>Counterfactual Method=%{x}<br>Abstraction Change Rate=%{y}<extra></extra>",
         "legendgroup": "supervised",
         "marker": {
          "color": "#EF553B",
          "pattern": {
           "shape": ""
          }
         },
         "name": "supervised",
         "offsetgroup": "supervised",
         "orientation": "v",
         "showlegend": true,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "Token-Level",
          "Abstract"
         ],
         "xaxis": "x",
         "y": {
          "bdata": "QmDl0CLb2T97FK5H4XqkPw==",
          "dtype": "f8"
         },
         "yaxis": "y"
        }
       ],
       "layout": {
        "barmode": "group",
        "height": 600,
        "legend": {
         "bgcolor": "rgba(255,255,255,0.5)",
         "font": {
          "size": 18
         },
         "title": {
          "text": "Abstraction Type"
         },
         "tracegroupgap": 0,
         "x": 0.05,
         "xanchor": "left",
         "y": 0.95,
         "yanchor": "top"
        },
        "margin": {
         "t": 60
        },
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermap": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermap"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "width": 1000,
        "xaxis": {
         "anchor": "y",
         "domain": [
          0,
          1
         ],
         "tickfont": {
          "size": 20
         },
         "title": {
          "font": {
           "size": 20
          },
          "text": "Counterfactual Method"
         }
        },
        "yaxis": {
         "anchor": "x",
         "domain": [
          0,
          1
         ],
         "tickfont": {
          "size": 20
         },
         "title": {
          "font": {
           "size": 20
          },
          "text": "Abstraction Change Rate"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "gender_steering_df = df[df.int_type=='gender_steering']\n",
    "\n",
    "# Create a grouped bar chart.\n",
    "fig = px.bar(\n",
    "    gender_steering_df,\n",
    "    x=\"method\",\n",
    "    y=\"abstraction_change_rate\",\n",
    "    color=\"abstraction_type\",\n",
    "    barmode=\"group\",\n",
    "    labels={\n",
    "        \"method\": \"Counterfactual Method\",\n",
    "        \"abstraction_change_rate\": \"Abstraction Change Rate\",\n",
    "        \"abstraction_type\": \"Abstraction Type\"\n",
    "    },\n",
    "    # title=\"Gender Steering Abstraction Change Rate\",\n",
    "    width=1000,\n",
    "    height=600\n",
    ")\n",
    "\n",
    "# Update layout to place the legend inside the figure.\n",
    "fig.update_layout(\n",
    "    legend=dict(\n",
    "        x=0.05,      # Horizontal position (0: left, 1: right)\n",
    "        y=0.95,     # Vertical position (0: bottom, 1: top)\n",
    "        xanchor=\"left\",\n",
    "        yanchor=\"top\",\n",
    "        bgcolor=\"rgba(255,255,255,0.5)\",  # semi-transparent background for clarity\n",
    "        font=dict(size=18)  # Increase legend font size\n",
    "    )\n",
    ")\n",
    "fig.update_xaxes(tickfont=dict(size=20))\n",
    "fig.update_xaxes(title_font=dict(size=20))\n",
    "fig.update_yaxes(tickfont=dict(size=20))\n",
    "fig.update_yaxes(title_font=dict(size=20))\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.microsoft.datawrangler.viewer.v0+json": {
       "columns": [
        {
         "name": "index",
         "rawType": "int64",
         "type": "integer"
        },
        {
         "name": "int_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "model",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "method",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_change_rate",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "abstraction_p_increase_rate",
         "rawType": "float64",
         "type": "float"
        }
       ],
       "conversionMethod": "pd.DataFrame",
       "ref": "771ab330-4df8-49aa-86bb-b5437f5a4cc2",
       "rows": [
        [
         "0",
         "gender_steering",
         "unsupervised",
         "gpt2-xl",
         "Abstract",
         "0.12",
         "0.9"
        ],
        [
         "2",
         "gender_steering",
         "supervised",
         "gpt2-xl",
         "Token-Level",
         "0.404",
         "0.036"
        ],
        [
         "6",
         "gender_steering",
         "unsupervised",
         "gpt2-xl",
         "Token-Level",
         "0.376",
         "0.032"
        ],
        [
         "11",
         "gender_steering",
         "supervised",
         "gpt2-xl",
         "Abstract",
         "0.04",
         "0.892"
        ]
       ],
       "shape": {
        "columns": 6,
        "rows": 4
       }
      },
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>int_type</th>\n",
       "      <th>abstraction_type</th>\n",
       "      <th>model</th>\n",
       "      <th>method</th>\n",
       "      <th>abstraction_change_rate</th>\n",
       "      <th>abstraction_p_increase_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.120</td>\n",
       "      <td>0.900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.404</td>\n",
       "      <td>0.036</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.376</td>\n",
       "      <td>0.032</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.040</td>\n",
       "      <td>0.892</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           int_type abstraction_type    model       method  \\\n",
       "0   gender_steering     unsupervised  gpt2-xl     Abstract   \n",
       "2   gender_steering       supervised  gpt2-xl  Token-Level   \n",
       "6   gender_steering     unsupervised  gpt2-xl  Token-Level   \n",
       "11  gender_steering       supervised  gpt2-xl     Abstract   \n",
       "\n",
       "    abstraction_change_rate  abstraction_p_increase_rate  \n",
       "0                     0.120                        0.900  \n",
       "2                     0.404                        0.036  \n",
       "6                     0.376                        0.032  \n",
       "11                    0.040                        0.892  "
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gender_steering_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 437,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.microsoft.datawrangler.viewer.v0+json": {
       "columns": [
        {
         "name": "index",
         "rawType": "int64",
         "type": "integer"
        },
        {
         "name": "int_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "model",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "method",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_change_rate",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "abstraction_p_increase_rate",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "Y_cross_entropy",
         "rawType": "float64",
         "type": "float"
        }
       ],
       "conversionMethod": "pd.DataFrame",
       "ref": "3d7b09fc-ea93-4797-a705-f66cdd4cf4ff",
       "rows": [
        [
         "0",
         "gender_steering",
         "unsupervised",
         "gpt2-xl",
         "Abstract",
         "0.12",
         "0.984",
         "1.1974755330560936"
        ],
        [
         "6",
         "gender_steering",
         "unsupervised",
         "gpt2-xl",
         "Token-Level",
         "0.376",
         "0.736",
         "1.770031759258887"
        ]
       ],
       "shape": {
        "columns": 7,
        "rows": 2
       }
      },
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>int_type</th>\n",
       "      <th>abstraction_type</th>\n",
       "      <th>model</th>\n",
       "      <th>method</th>\n",
       "      <th>abstraction_change_rate</th>\n",
       "      <th>abstraction_p_increase_rate</th>\n",
       "      <th>Y_cross_entropy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.120</td>\n",
       "      <td>0.984</td>\n",
       "      <td>1.197476</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>gender_steering</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.376</td>\n",
       "      <td>0.736</td>\n",
       "      <td>1.770032</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          int_type abstraction_type    model       method  \\\n",
       "0  gender_steering     unsupervised  gpt2-xl     Abstract   \n",
       "6  gender_steering     unsupervised  gpt2-xl  Token-Level   \n",
       "\n",
       "   abstraction_change_rate  abstraction_p_increase_rate  Y_cross_entropy  \n",
       "0                    0.120                        0.984         1.197476  \n",
       "6                    0.376                        0.736         1.770032  "
      ]
     },
     "execution_count": 437,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gender_steering_df[\n",
    "    (gender_steering_df.abstraction_type=='unsupervised')\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 401,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "XXXX"
       },
       "data": [
        {
         "alignmentgroup": "True",
         "hovertemplate": "Abstraction Type=unsupervised<br>Counterfactual Method=%{x}<br>Y Cross Entropy=%{y}<extra></extra>",
         "legendgroup": "unsupervised",
         "marker": {
          "color": "#636efa",
          "pattern": {
           "shape": ""
          }
         },
         "name": "unsupervised",
         "offsetgroup": "unsupervised",
         "orientation": "v",
         "showlegend": true,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "Abstract",
          "Token-Level"
         ],
         "xaxis": "x",
         "y": {
          "bdata": "NVaFG9wo8z+BOtaVFFL8Pw==",
          "dtype": "f8"
         },
         "yaxis": "y"
        },
        {
         "alignmentgroup": "True",
         "hovertemplate": "Abstraction Type=supervised<br>Counterfactual Method=%{x}<br>Y Cross Entropy=%{y}<extra></extra>",
         "legendgroup": "supervised",
         "marker": {
          "color": "#EF553B",
          "pattern": {
           "shape": ""
          }
         },
         "name": "supervised",
         "offsetgroup": "supervised",
         "orientation": "v",
         "showlegend": true,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "Token-Level",
          "Abstract"
         ],
         "xaxis": "x",
         "y": {
          "bdata": "8q5D+k0z/j+a3/PIyRj2Pw==",
          "dtype": "f8"
         },
         "yaxis": "y"
        }
       ],
       "layout": {
        "barmode": "group",
        "height": 600,
        "legend": {
         "bgcolor": "rgba(255,255,255,0.5)",
         "title": {
          "text": "Abstraction Type"
         },
         "tracegroupgap": 0,
         "x": 0.05,
         "xanchor": "left",
         "y": 0.95,
         "yanchor": "top"
        },
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermap": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermap"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "title": {
         "text": "Gender Y Cross Entropy"
        },
        "width": 1000,
        "xaxis": {
         "anchor": "y",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "Counterfactual Method"
         }
        },
        "yaxis": {
         "anchor": "x",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "Y Cross Entropy"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "gender_steering_df = df[df.int_type=='gender_steering']\n",
    "\n",
    "# Create a grouped bar chart.\n",
    "fig = px.bar(\n",
    "    gender_steering_df,\n",
    "    x=\"method\",\n",
    "    y=\"Y_cross_entropy\",\n",
    "    color=\"abstraction_type\",\n",
    "    barmode=\"group\",\n",
    "    labels={\n",
    "        \"method\": \"Counterfactual Method\",\n",
    "        \"Y_cross_entropy\": \"Y Cross Entropy\",\n",
    "        \"abstraction_type\": \"Abstraction Type\"\n",
    "    },\n",
    "    title=\"Gender Y Cross Entropy\",\n",
    "    width=1000,\n",
    "    height=600\n",
    ")\n",
    "\n",
    "# Update layout to place the legend inside the figure.\n",
    "fig.update_layout(\n",
    "    legend=dict(\n",
    "        x=0.05,      # Horizontal position (0: left, 1: right)\n",
    "        y=0.95,     # Vertical position (0: bottom, 1: top)\n",
    "        xanchor=\"left\",\n",
    "        yanchor=\"top\",\n",
    "        bgcolor=\"rgba(255,255,255,0.5)\"  # semi-transparent background for clarity\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Token Level Replacement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 408,
   "metadata": {},
   "outputs": [],
   "source": [
    "token_replacement_df = df[df.int_type=='token_replacement']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 419,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/zp/5fd9lh0x5tgcx_03zmfmm2sm0000gn/T/ipykernel_55574/2433412493.py:4: SettingWithCopyWarning:\n",
      "\n",
      "\n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: XXXX\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "XXXX"
       },
       "data": [
        {
         "alignmentgroup": "True",
         "hovertemplate": "Model=gpt2-xl<br>Counterfactual Method=%{x}<br>Abstraction Change Rate=%{y}<extra></extra>",
         "legendgroup": "gpt2-xl",
         "marker": {
          "color": "#1f77b4",
          "pattern": {
           "shape": ""
          }
         },
         "name": "gpt2-xl",
         "offsetgroup": "gpt2-xl",
         "orientation": "v",
         "showlegend": true,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "Unsupervised Abstract",
          "Supervised Token-Level",
          "Supervised Abstract",
          "Unsupervised Token-Level"
         ],
         "xaxis": "x",
         "y": {
          "bdata": "9P3UeOkm0T8j2/l+arzUP/p+arx0k5g/SOF6FK5H4T8=",
          "dtype": "f8"
         },
         "yaxis": "y"
        },
        {
         "alignmentgroup": "True",
         "hovertemplate": "Model=llama-3.2-1B<br>Counterfactual Method=%{x}<br>Abstraction Change Rate=%{y}<extra></extra>",
         "legendgroup": "llama-3.2-1B",
         "marker": {
          "color": "#0c4c8a",
          "pattern": {
           "shape": ""
          }
         },
         "name": "llama-3.2-1B",
         "offsetgroup": "llama-3.2-1B",
         "orientation": "v",
         "showlegend": true,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "Unsupervised Abstract",
          "Supervised Token-Level",
          "Supervised Abstract",
          "Unsupervised Token-Level"
         ],
         "xaxis": "x",
         "y": {
          "bdata": "6SYxCKwc2j9aZDvfT43XP/p+arx0k6g/c2iR7Xw/5T8=",
          "dtype": "f8"
         },
         "yaxis": "y"
        }
       ],
       "layout": {
        "barmode": "group",
        "height": 600,
        "legend": {
         "bgcolor": "rgba(255,255,255,0.5)",
         "font": {
          "size": 18
         },
         "title": {
          "text": "Model"
         },
         "tracegroupgap": 0,
         "x": 0.05,
         "xanchor": "left",
         "y": 0.95,
         "yanchor": "top"
        },
        "margin": {
         "t": 60
        },
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermap": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermap"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "width": 1300,
        "xaxis": {
         "anchor": "y",
         "categoryarray": [
          "Supervised Abstract",
          "Supervised Token-Level",
          "Unsupervised Abstract",
          "Unsupervised Token-Level"
         ],
         "categoryorder": "array",
         "domain": [
          0,
          1
         ],
         "tickfont": {
          "size": 20
         },
         "title": {
          "font": {
           "size": 20
          },
          "text": "Counterfactual Method"
         }
        },
        "yaxis": {
         "anchor": "x",
         "domain": [
          0,
          1
         ],
         "tickfont": {
          "size": 20
         },
         "title": {
          "font": {
           "size": 20
          },
          "text": "Abstraction Change Rate"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "# Create a new Group column for the x-axis ordering.\n",
    "token_replacement_df['Group'] = token_replacement_df['abstraction_type'].str.title() + ' ' + token_replacement_df['method']\n",
    "\n",
    "# Specify the desired x-axis order (ACF groups on the left, TLCF groups on the right).\n",
    "# group_order = [\"Supervised Abstract\", \"Unsupervised Abstract\", \"Supervised Token-Level\", \"Unsupervised Token-Level\"]\n",
    "group_order = [\"Supervised Abstract\", \"Supervised Token-Level\", \"Unsupervised Abstract\", \"Unsupervised Token-Level\"]\n",
    "\n",
    "# Create the grouped bar chart with blue color scheme\n",
    "fig = px.bar(\n",
    "    token_replacement_df,\n",
    "    x='Group',\n",
    "    y='abstraction_change_rate',\n",
    "    color='model',\n",
    "    barmode='group',\n",
    "    category_orders={'Group': group_order},\n",
    "    color_discrete_map={\n",
    "        # Replace these with your actual model names from the dataframe\n",
    "        token_replacement_df['model'].unique()[0]: \"#1f77b4\",  # Lighter blue\n",
    "        token_replacement_df['model'].unique()[1]: \"#0c4c8a\"   # Darker blue\n",
    "    },\n",
    "    labels={\n",
    "        'Group': 'Counterfactual Method',\n",
    "        'abstraction_change_rate': 'Abstraction Change Rate',\n",
    "        'model': 'Model'\n",
    "    },\n",
    "    # title='Token Replacement Experiment: Abstraction Change Rate',\n",
    "    width=1300,\n",
    "    height=600\n",
    ")\n",
    "# Update layout to place the legend inside the figure.\n",
    "fig.update_layout(\n",
    "    legend=dict(\n",
    "        x=0.05,      # Horizontal position (0: left, 1: right)\n",
    "        y=0.95,     # Vertical position (0: bottom, 1: top)\n",
    "        xanchor=\"left\",\n",
    "        yanchor=\"top\",\n",
    "        bgcolor=\"rgba(255,255,255,0.5)\",  # semi-transparent background for clarity\n",
    "        font=dict(size=18)  # Increase legend font size\n",
    "    )\n",
    ")\n",
    "fig.update_xaxes(tickfont=dict(size=20))\n",
    "fig.update_xaxes(title_font=dict(size=20))\n",
    "fig.update_yaxes(tickfont=dict(size=20))\n",
    "fig.update_yaxes(title_font=dict(size=20))\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.microsoft.datawrangler.viewer.v0+json": {
       "columns": [
        {
         "name": "index",
         "rawType": "int64",
         "type": "integer"
        },
        {
         "name": "int_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_type",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "model",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "method",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "abstraction_change_rate",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "abstraction_p_increase_rate",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "Y_cross_entropy",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "Group",
         "rawType": "object",
         "type": "string"
        }
       ],
       "conversionMethod": "pd.DataFrame",
       "ref": "7ea09826-3b48-4ba0-b819-95a37c26e46a",
       "rows": [
        [
         "1",
         "token_replacement",
         "unsupervised",
         "gpt2-xl",
         "Abstract",
         "0.268",
         "0.868",
         "0.9844405531278292",
         "Unsupervised Abstract"
        ],
        [
         "3",
         "token_replacement",
         "supervised",
         "gpt2-xl",
         "Token-Level",
         "0.324",
         "0.684",
         "1.8939063353723127",
         "Supervised Token-Level"
        ],
        [
         "4",
         "token_replacement",
         "supervised",
         "gpt2-xl",
         "Abstract",
         "0.024",
         "0.964",
         "1.6468261693745099",
         "Supervised Abstract"
        ],
        [
         "5",
         "token_replacement",
         "unsupervised",
         "llama-3.2-1B",
         "Abstract",
         "0.408",
         "0.752",
         "1.187844711688888",
         "Unsupervised Abstract"
        ],
        [
         "7",
         "token_replacement",
         "unsupervised",
         "gpt2-xl",
         "Token-Level",
         "0.54",
         "0.476",
         "1.468632410454137",
         "Unsupervised Token-Level"
        ],
        [
         "8",
         "token_replacement",
         "supervised",
         "llama-3.2-1B",
         "Token-Level",
         "0.368",
         "0.672",
         "1.9986221133751465",
         "Supervised Token-Level"
        ],
        [
         "9",
         "token_replacement",
         "supervised",
         "llama-3.2-1B",
         "Abstract",
         "0.048",
         "0.968",
         "1.769906708931869",
         "Supervised Abstract"
        ],
        [
         "10",
         "token_replacement",
         "unsupervised",
         "llama-3.2-1B",
         "Token-Level",
         "0.664",
         "0.468",
         "1.7834796989782526",
         "Unsupervised Token-Level"
        ]
       ],
       "shape": {
        "columns": 8,
        "rows": 8
       }
      },
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>int_type</th>\n",
       "      <th>abstraction_type</th>\n",
       "      <th>model</th>\n",
       "      <th>method</th>\n",
       "      <th>abstraction_change_rate</th>\n",
       "      <th>abstraction_p_increase_rate</th>\n",
       "      <th>Y_cross_entropy</th>\n",
       "      <th>Group</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.268</td>\n",
       "      <td>0.868</td>\n",
       "      <td>0.984441</td>\n",
       "      <td>Unsupervised Abstract</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.324</td>\n",
       "      <td>0.684</td>\n",
       "      <td>1.893906</td>\n",
       "      <td>Supervised Token-Level</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>supervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.024</td>\n",
       "      <td>0.964</td>\n",
       "      <td>1.646826</td>\n",
       "      <td>Supervised Abstract</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>llama-3.2-1B</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.408</td>\n",
       "      <td>0.752</td>\n",
       "      <td>1.187845</td>\n",
       "      <td>Unsupervised Abstract</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>gpt2-xl</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.540</td>\n",
       "      <td>0.476</td>\n",
       "      <td>1.468632</td>\n",
       "      <td>Unsupervised Token-Level</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>supervised</td>\n",
       "      <td>llama-3.2-1B</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.368</td>\n",
       "      <td>0.672</td>\n",
       "      <td>1.998622</td>\n",
       "      <td>Supervised Token-Level</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>supervised</td>\n",
       "      <td>llama-3.2-1B</td>\n",
       "      <td>Abstract</td>\n",
       "      <td>0.048</td>\n",
       "      <td>0.968</td>\n",
       "      <td>1.769907</td>\n",
       "      <td>Supervised Abstract</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>token_replacement</td>\n",
       "      <td>unsupervised</td>\n",
       "      <td>llama-3.2-1B</td>\n",
       "      <td>Token-Level</td>\n",
       "      <td>0.664</td>\n",
       "      <td>0.468</td>\n",
       "      <td>1.783480</td>\n",
       "      <td>Unsupervised Token-Level</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             int_type abstraction_type         model       method  \\\n",
       "1   token_replacement     unsupervised       gpt2-xl     Abstract   \n",
       "3   token_replacement       supervised       gpt2-xl  Token-Level   \n",
       "4   token_replacement       supervised       gpt2-xl     Abstract   \n",
       "5   token_replacement     unsupervised  llama-3.2-1B     Abstract   \n",
       "7   token_replacement     unsupervised       gpt2-xl  Token-Level   \n",
       "8   token_replacement       supervised  llama-3.2-1B  Token-Level   \n",
       "9   token_replacement       supervised  llama-3.2-1B     Abstract   \n",
       "10  token_replacement     unsupervised  llama-3.2-1B  Token-Level   \n",
       "\n",
       "    abstraction_change_rate  abstraction_p_increase_rate  Y_cross_entropy  \\\n",
       "1                     0.268                        0.868         0.984441   \n",
       "3                     0.324                        0.684         1.893906   \n",
       "4                     0.024                        0.964         1.646826   \n",
       "5                     0.408                        0.752         1.187845   \n",
       "7                     0.540                        0.476         1.468632   \n",
       "8                     0.368                        0.672         1.998622   \n",
       "9                     0.048                        0.968         1.769907   \n",
       "10                    0.664                        0.468         1.783480   \n",
       "\n",
       "                       Group  \n",
       "1      Unsupervised Abstract  \n",
       "3     Supervised Token-Level  \n",
       "4        Supervised Abstract  \n",
       "5      Unsupervised Abstract  \n",
       "7   Unsupervised Token-Level  \n",
       "8     Supervised Token-Level  \n",
       "9        Supervised Abstract  \n",
       "10  Unsupervised Token-Level  "
      ]
     },
     "execution_count": 420,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "token_replacement_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 435,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10    0.468\n",
       "Name: abstraction_p_increase_rate, dtype: float64"
      ]
     },
     "execution_count": 435,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "token_replacement_df[\n",
    "    (token_replacement_df.abstraction_type=='unsupervised') & \n",
    "    # (token_replacement_df.model=='gpt2-xl') & \n",
    "    (token_replacement_df.model=='llama-3.2-1B') & \n",
    "    # (token_replacement_df.method=='Abstract')  \n",
    "    (token_replacement_df.method=='Token-Level')\n",
    "].abstraction_p_increase_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 409,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/zp/5fd9lh0x5tgcx_03zmfmm2sm0000gn/T/ipykernel_55574/3432054859.py:4: SettingWithCopyWarning:\n",
      "\n",
      "\n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: XXXX\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.plotly.v1+json": {
       "config": {
        "plotlyServerURL": "XXXX"
       },
       "data": [
        {
         "alignmentgroup": "True",
         "hovertemplate": "Model=gpt2-xl<br>Counterfactual Method=%{x}<br>Y Cross Entropy=%{y}<extra></extra>",
         "legendgroup": "gpt2-xl",
         "marker": {
          "color": "#1f77b4",
          "pattern": {
           "shape": ""
          }
         },
         "name": "gpt2-xl",
         "offsetgroup": "gpt2-xl",
         "orientation": "v",
         "showlegend": true,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "Unsupervised Abstract",
          "Supervised Token-Level",
          "Supervised Abstract",
          "Unsupervised Token-Level"
         ],
         "xaxis": "x",
         "y": {
          "bdata": "IUmReYmA7z/Zx8G6cE3+P1CRumVmWfo/TfDLsoR/9z8=",
          "dtype": "f8"
         },
         "yaxis": "y"
        },
        {
         "alignmentgroup": "True",
         "hovertemplate": "Model=llama-3.2-1B<br>Counterfactual Method=%{x}<br>Y Cross Entropy=%{y}<extra></extra>",
         "legendgroup": "llama-3.2-1B",
         "marker": {
          "color": "#0c4c8a",
          "pattern": {
           "shape": ""
          }
         },
         "name": "llama-3.2-1B",
         "offsetgroup": "llama-3.2-1B",
         "orientation": "v",
         "showlegend": true,
         "textposition": "auto",
         "type": "bar",
         "x": [
          "Unsupervised Abstract",
          "Supervised Token-Level",
          "Supervised Abstract",
          "Unsupervised Token-Level"
         ],
         "xaxis": "x",
         "y": {
          "bdata": "muLWdGkB8z93I2AuW/r/P3pVfbKJUfw/dxBDAiKJ/D8=",
          "dtype": "f8"
         },
         "yaxis": "y"
        }
       ],
       "layout": {
        "barmode": "group",
        "height": 800,
        "legend": {
         "bgcolor": "rgba(255,255,255,0.5)",
         "title": {
          "text": "Model"
         },
         "tracegroupgap": 0,
         "x": 0.05,
         "xanchor": "left",
         "y": 0.95,
         "yanchor": "top"
        },
        "template": {
         "data": {
          "bar": [
           {
            "error_x": {
             "color": "#2a3f5f"
            },
            "error_y": {
             "color": "#2a3f5f"
            },
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "bar"
           }
          ],
          "barpolar": [
           {
            "marker": {
             "line": {
              "color": "#E5ECF6",
              "width": 0.5
             },
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "barpolar"
           }
          ],
          "carpet": [
           {
            "aaxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "baxis": {
             "endlinecolor": "#2a3f5f",
             "gridcolor": "white",
             "linecolor": "white",
             "minorgridcolor": "white",
             "startlinecolor": "#2a3f5f"
            },
            "type": "carpet"
           }
          ],
          "choropleth": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "choropleth"
           }
          ],
          "contour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "contour"
           }
          ],
          "contourcarpet": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "contourcarpet"
           }
          ],
          "heatmap": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "heatmap"
           }
          ],
          "histogram": [
           {
            "marker": {
             "pattern": {
              "fillmode": "overlay",
              "size": 10,
              "solidity": 0.2
             }
            },
            "type": "histogram"
           }
          ],
          "histogram2d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2d"
           }
          ],
          "histogram2dcontour": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "histogram2dcontour"
           }
          ],
          "mesh3d": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "type": "mesh3d"
           }
          ],
          "parcoords": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "parcoords"
           }
          ],
          "pie": [
           {
            "automargin": true,
            "type": "pie"
           }
          ],
          "scatter": [
           {
            "fillpattern": {
             "fillmode": "overlay",
             "size": 10,
             "solidity": 0.2
            },
            "type": "scatter"
           }
          ],
          "scatter3d": [
           {
            "line": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatter3d"
           }
          ],
          "scattercarpet": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattercarpet"
           }
          ],
          "scattergeo": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergeo"
           }
          ],
          "scattergl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattergl"
           }
          ],
          "scattermap": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermap"
           }
          ],
          "scattermapbox": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scattermapbox"
           }
          ],
          "scatterpolar": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolar"
           }
          ],
          "scatterpolargl": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterpolargl"
           }
          ],
          "scatterternary": [
           {
            "marker": {
             "colorbar": {
              "outlinewidth": 0,
              "ticks": ""
             }
            },
            "type": "scatterternary"
           }
          ],
          "surface": [
           {
            "colorbar": {
             "outlinewidth": 0,
             "ticks": ""
            },
            "colorscale": [
             [
              0,
              "#0d0887"
             ],
             [
              0.1111111111111111,
              "#46039f"
             ],
             [
              0.2222222222222222,
              "#7201a8"
             ],
             [
              0.3333333333333333,
              "#9c179e"
             ],
             [
              0.4444444444444444,
              "#bd3786"
             ],
             [
              0.5555555555555556,
              "#d8576b"
             ],
             [
              0.6666666666666666,
              "#ed7953"
             ],
             [
              0.7777777777777778,
              "#fb9f3a"
             ],
             [
              0.8888888888888888,
              "#fdca26"
             ],
             [
              1,
              "#f0f921"
             ]
            ],
            "type": "surface"
           }
          ],
          "table": [
           {
            "cells": {
             "fill": {
              "color": "#EBF0F8"
             },
             "line": {
              "color": "white"
             }
            },
            "header": {
             "fill": {
              "color": "#C8D4E3"
             },
             "line": {
              "color": "white"
             }
            },
            "type": "table"
           }
          ]
         },
         "layout": {
          "annotationdefaults": {
           "arrowcolor": "#2a3f5f",
           "arrowhead": 0,
           "arrowwidth": 1
          },
          "autotypenumbers": "strict",
          "coloraxis": {
           "colorbar": {
            "outlinewidth": 0,
            "ticks": ""
           }
          },
          "colorscale": {
           "diverging": [
            [
             0,
             "#8e0152"
            ],
            [
             0.1,
             "#c51b7d"
            ],
            [
             0.2,
             "#de77ae"
            ],
            [
             0.3,
             "#f1b6da"
            ],
            [
             0.4,
             "#fde0ef"
            ],
            [
             0.5,
             "#f7f7f7"
            ],
            [
             0.6,
             "#e6f5d0"
            ],
            [
             0.7,
             "#b8e186"
            ],
            [
             0.8,
             "#7fbc41"
            ],
            [
             0.9,
             "#4d9221"
            ],
            [
             1,
             "#276419"
            ]
           ],
           "sequential": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ],
           "sequentialminus": [
            [
             0,
             "#0d0887"
            ],
            [
             0.1111111111111111,
             "#46039f"
            ],
            [
             0.2222222222222222,
             "#7201a8"
            ],
            [
             0.3333333333333333,
             "#9c179e"
            ],
            [
             0.4444444444444444,
             "#bd3786"
            ],
            [
             0.5555555555555556,
             "#d8576b"
            ],
            [
             0.6666666666666666,
             "#ed7953"
            ],
            [
             0.7777777777777778,
             "#fb9f3a"
            ],
            [
             0.8888888888888888,
             "#fdca26"
            ],
            [
             1,
             "#f0f921"
            ]
           ]
          },
          "colorway": [
           "#636efa",
           "#EF553B",
           "#00cc96",
           "#ab63fa",
           "#FFA15A",
           "#19d3f3",
           "#FF6692",
           "#B6E880",
           "#FF97FF",
           "#FECB52"
          ],
          "font": {
           "color": "#2a3f5f"
          },
          "geo": {
           "bgcolor": "white",
           "lakecolor": "white",
           "landcolor": "#E5ECF6",
           "showlakes": true,
           "showland": true,
           "subunitcolor": "white"
          },
          "hoverlabel": {
           "align": "left"
          },
          "hovermode": "closest",
          "mapbox": {
           "style": "light"
          },
          "paper_bgcolor": "white",
          "plot_bgcolor": "#E5ECF6",
          "polar": {
           "angularaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "radialaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "scene": {
           "xaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "yaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           },
           "zaxis": {
            "backgroundcolor": "#E5ECF6",
            "gridcolor": "white",
            "gridwidth": 2,
            "linecolor": "white",
            "showbackground": true,
            "ticks": "",
            "zerolinecolor": "white"
           }
          },
          "shapedefaults": {
           "line": {
            "color": "#2a3f5f"
           }
          },
          "ternary": {
           "aaxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "baxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           },
           "bgcolor": "#E5ECF6",
           "caxis": {
            "gridcolor": "white",
            "linecolor": "white",
            "ticks": ""
           }
          },
          "title": {
           "x": 0.05
          },
          "xaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          },
          "yaxis": {
           "automargin": true,
           "gridcolor": "white",
           "linecolor": "white",
           "ticks": "",
           "title": {
            "standoff": 15
           },
           "zerolinecolor": "white",
           "zerolinewidth": 2
          }
         }
        },
        "title": {
         "text": "Token Replacement Experiment: Y CE"
        },
        "width": 1200,
        "xaxis": {
         "anchor": "y",
         "categoryarray": [
          "Supervised Abstract",
          "Supervised Token-Level",
          "Unsupervised Abstract",
          "Unsupervised Token-Level"
         ],
         "categoryorder": "array",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "Counterfactual Method"
         }
        },
        "yaxis": {
         "anchor": "x",
         "domain": [
          0,
          1
         ],
         "title": {
          "text": "Y Cross Entropy"
         }
        }
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "# Create a new Group column for the x-axis ordering.\n",
    "token_replacement_df['Group'] = token_replacement_df['abstraction_type'].str.title() + ' ' + token_replacement_df['method']\n",
    "\n",
    "# Specify the desired x-axis order (ACF groups on the left, TLCF groups on the right).\n",
    "# group_order = [\"Supervised Abstract\", \"Unsupervised Abstract\", \"Supervised Token-Level\", \"Unsupervised Token-Level\"]\n",
    "group_order = [\"Supervised Abstract\", \"Supervised Token-Level\", \"Unsupervised Abstract\", \"Unsupervised Token-Level\"]\n",
    "\n",
    "# Create the grouped bar chart with blue color scheme\n",
    "fig = px.bar(\n",
    "    token_replacement_df,\n",
    "    x='Group',\n",
    "    y='Y_cross_entropy',\n",
    "    color='model',\n",
    "    barmode='group',\n",
    "    category_orders={'Group': group_order},\n",
    "    color_discrete_map={\n",
    "        # Replace these with your actual model names from the dataframe\n",
    "        token_replacement_df['model'].unique()[0]: \"#1f77b4\",  # Lighter blue\n",
    "        token_replacement_df['model'].unique()[1]: \"#0c4c8a\"   # Darker blue\n",
    "    },\n",
    "    labels={\n",
    "        'Group': 'Counterfactual Method',\n",
    "        'Y_cross_entropy': 'Y Cross Entropy',\n",
    "        'model': 'Model'\n",
    "    },\n",
    "    title='Token Replacement Experiment: Y CE',\n",
    "    width=1200,\n",
    "    height=800\n",
    ")\n",
    "# Update layout to place the legend inside the figure.\n",
    "fig.update_layout(\n",
    "    legend=dict(\n",
    "        x=0.05,      # Horizontal position (0: left, 1: right)\n",
    "        y=0.95,     # Vertical position (0: bottom, 1: top)\n",
    "        xanchor=\"left\",\n",
    "        yanchor=\"top\",\n",
    "        bgcolor=\"rgba(255,255,255,0.5)\"  # semi-transparent background for clarity\n",
    "    )\n",
    ")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Some Case studies "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gender_steering.supervised.gpt2-xl.acf.pkl\n",
      "gender_steering.supervised.gpt2-xl.tlcf.pkl\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/ep/Documents/research/abstract_counterfactuals_paper/abstract_counterfactuals/venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning:\n",
      "\n",
      "IProgress not found. Please update jupyter and ipywidgets. See XXXX\n",
      "\n"
     ]
    }
   ],
   "source": [
    "acf = load_processed_experiment(\n",
    "    'gender_steering',\n",
    "    'supervised',\n",
    "    'gpt2-xl',\n",
    "    'acf'\n",
    ")\n",
    "tlcf = load_processed_experiment(\n",
    "    'gender_steering',\n",
    "    'supervised',\n",
    "    'gpt2-xl',\n",
    "    'tlcf'\n",
    ")\n",
    "\n",
    "from abstract_cf.text_generation.learned_abstraction import LearnedAbstractionPipeline\n",
    "abstraction = LearnedAbstractionPipeline.load(\n",
    "    supervised_abstraction_paths['gender_steering'], \n",
    "    device='mps'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "same_observation_mask = tlcf.observed_abstraction_id == acf.observed_abstraction_id\n",
    "acf_matching, tlcf_matching = acf[same_observation_mask], tlcf[same_observation_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we want to check which (if any) of the counterfactual samples are interesting in a row \n",
    "# the conditions for that are the following \n",
    "# cf sample abstraction id is the same as the observed abstraction id\n",
    "# cf samle pronoun category is `only_female`\n",
    "# factual samples are mostly `only_male` (overall)\n",
    "\n",
    "def is_interesting(row: pd.Series):\n",
    "    factual_pronoun_categories = row.factual_pronoun_categories\n",
    "    factual_only_male_rate = sum(1 for p in factual_pronoun_categories if p == 'only_male') / len(factual_pronoun_categories)\n",
    "    if factual_only_male_rate > 0.5:\n",
    "        return None\n",
    "\n",
    "    cf_abstraction_ids = row.cf_abstraction_ids\n",
    "    cf_pronoun_categories = row.cf_pronoun_categories\n",
    "    interesting = []\n",
    "    for i, (cf_abstraction_id, cf_pronoun_category) in enumerate(zip(cf_abstraction_ids, cf_pronoun_categories)):\n",
    "        if cf_abstraction_id == row.observed_abstraction_id and cf_pronoun_category == 'only_female':\n",
    "            interesting.append(i)\n",
    "    return interesting\n",
    "\n",
    "\n",
    "acf_interesting = acf_matching.apply(is_interesting, axis=1).dropna()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we want to find failure cases of tlcf method\n",
    "# we want:\n",
    "# the factual pronoun category to be `only_male`\n",
    "# the cf pronoun category to be `only_female` \n",
    "# the cf abstraction id to be different from the observed abstraction id (overall)\n",
    "def tlcf_interesting(row: pd.Series):\n",
    "    if row.factual_pronoun_category != 'only_male':\n",
    "        return None\n",
    "    cf_abstraction_ids = row.cf_abstraction_ids\n",
    "    cf_pronoun_categories = row.cf_pronoun_categories\n",
    "    interesting = []\n",
    "    for i, (cf_abstraction_id, cf_pronoun_category) in enumerate(zip(cf_abstraction_ids, cf_pronoun_categories)):\n",
    "        if cf_abstraction_id != row.observed_abstraction_id and cf_pronoun_category == 'only_female':\n",
    "            interesting.append(i)\n",
    "    if len(interesting) == 0:\n",
    "        return None\n",
    "    return interesting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "sample_id\n",
       "3                                                [9, 29]\n",
       "7                                                   [18]\n",
       "18                                              [12, 15]\n",
       "20                                           [4, 14, 21]\n",
       "28                       [8, 13, 15, 16, 18, 19, 20, 24]\n",
       "29     [0, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 15, 16, 17...\n",
       "30                                  [10, 11, 14, 25, 27]\n",
       "33                                                  [25]\n",
       "38                                 [3, 6, 8, 17, 24, 27]\n",
       "42                                          [15, 17, 23]\n",
       "51                            [3, 4, 10, 12, 17, 26, 27]\n",
       "52                                           [4, 17, 27]\n",
       "64                   [5, 11, 12, 13, 16, 19, 20, 22, 27]\n",
       "67              [1, 6, 7, 8, 10, 11, 15, 20, 22, 25, 29]\n",
       "72                                               [0, 22]\n",
       "73         [0, 8, 9, 11, 14, 16, 19, 20, 22, 23, 24, 25]\n",
       "79     [0, 1, 2, 3, 6, 8, 10, 11, 15, 18, 21, 23, 26,...\n",
       "80                           [3, 13, 16, 20, 22, 24, 28]\n",
       "83                                               [8, 25]\n",
       "86      [1, 2, 6, 8, 10, 12, 16, 18, 19, 20, 25, 26, 29]\n",
       "87                           [5, 6, 7, 8, 9, 15, 24, 29]\n",
       "88                                [0, 7, 19, 20, 21, 24]\n",
       "95                                       [9, 11, 21, 27]\n",
       "102    [1, 2, 3, 4, 5, 7, 11, 13, 15, 16, 19, 20, 21,...\n",
       "105                    [1, 2, 9, 17, 18, 20, 23, 27, 28]\n",
       "106    [0, 1, 5, 7, 11, 14, 15, 16, 17, 18, 19, 20, 2...\n",
       "109                                   [3, 4, 18, 20, 29]\n",
       "111            [6, 7, 8, 12, 14, 19, 23, 24, 27, 28, 29]\n",
       "116                                             [11, 16]\n",
       "119                                                 [21]\n",
       "125                 [1, 2, 4, 5, 10, 13, 15, 16, 19, 24]\n",
       "128                                     [10, 14, 20, 25]\n",
       "141                                              [8, 28]\n",
       "144                                 [11, 13, 23, 25, 28]\n",
       "148                            [3, 5, 8, 10, 11, 12, 16]\n",
       "149    [1, 2, 5, 6, 7, 12, 13, 16, 17, 19, 20, 21, 27...\n",
       "152                            [0, 1, 9, 14, 19, 24, 29]\n",
       "158                                       [3, 7, 10, 25]\n",
       "182                                          [1, 11, 28]\n",
       "186                                                 [13]\n",
       "194                                   [4, 9, 11, 14, 17]\n",
       "200                       [3, 9, 11, 17, 20, 21, 28, 29]\n",
       "203                                [1, 5, 6, 19, 23, 25]\n",
       "205                                          [7, 18, 26]\n",
       "210                                       [5, 9, 16, 28]\n",
       "212                               [6, 7, 10, 17, 19, 22]\n",
       "216                                          [7, 16, 18]\n",
       "224    [1, 2, 3, 5, 6, 7, 9, 10, 11, 12, 14, 15, 17, ...\n",
       "237    [0, 1, 4, 6, 7, 11, 13, 16, 18, 19, 20, 22, 24...\n",
       "238                               [2, 6, 18, 19, 22, 29]\n",
       "244                                       [0, 8, 22, 26]\n",
       "245                                              [3, 12]\n",
       "246                            [0, 1, 6, 11, 14, 23, 25]\n",
       "247                               [0, 6, 12, 15, 19, 25]\n",
       "dtype: object"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tlcf_interesting = tlcf_matching.apply(tlcf_interesting, axis=1).dropna()\n",
    "tlcf_interesting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "sample_id\n",
       "0      [0, 1, 2, 3, 4, 5, 6, 8, 10, 14, 17, 19, 20, 2...\n",
       "1      [1, 3, 5, 10, 11, 12, 13, 15, 16, 17, 22, 24, ...\n",
       "9                                       [10, 19, 21, 27]\n",
       "12     [1, 3, 5, 9, 10, 11, 12, 15, 16, 17, 23, 25, 2...\n",
       "14     [0, 1, 3, 4, 5, 8, 9, 10, 11, 12, 14, 16, 17, ...\n",
       "15                                                    []\n",
       "17             [0, 4, 6, 10, 12, 13, 14, 15, 17, 19, 24]\n",
       "27                                [5, 6, 11, 13, 23, 28]\n",
       "39                           [4, 13, 14, 17, 24, 25, 27]\n",
       "41     [0, 1, 4, 5, 6, 8, 9, 10, 11, 14, 15, 16, 17, ...\n",
       "46     [0, 1, 3, 5, 7, 8, 9, 10, 11, 12, 14, 15, 16, ...\n",
       "49                                  [4, 5, 8, 9, 19, 28]\n",
       "50     [1, 2, 5, 6, 8, 9, 10, 12, 13, 15, 16, 17, 18,...\n",
       "53                   [8, 12, 17, 19, 22, 23, 25, 26, 27]\n",
       "68                             [1, 5, 6, 12, 18, 20, 21]\n",
       "71     [1, 2, 3, 6, 11, 13, 14, 15, 16, 17, 18, 21, 2...\n",
       "74     [1, 2, 3, 5, 8, 9, 10, 12, 17, 18, 19, 20, 22,...\n",
       "89     [1, 2, 5, 6, 8, 9, 10, 11, 12, 13, 14, 17, 18,...\n",
       "90                                               [9, 22]\n",
       "91     [1, 2, 3, 8, 10, 11, 12, 14, 15, 16, 19, 20, 2...\n",
       "104    [0, 1, 2, 3, 5, 6, 8, 9, 10, 11, 13, 14, 15, 1...\n",
       "113    [0, 1, 2, 5, 6, 8, 9, 10, 11, 12, 14, 15, 16, 25]\n",
       "117    [0, 1, 2, 3, 4, 6, 7, 9, 10, 11, 12, 16, 17, 1...\n",
       "120                          [2, 4, 7, 8, 9, 12, 19, 28]\n",
       "121    [4, 10, 11, 13, 14, 16, 17, 18, 22, 23, 25, 28...\n",
       "123    [0, 1, 4, 5, 6, 7, 8, 10, 12, 13, 14, 15, 18, ...\n",
       "124    [0, 1, 2, 3, 4, 5, 7, 9, 12, 15, 16, 20, 24, 2...\n",
       "128    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...\n",
       "129    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14,...\n",
       "130    [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 14, 15...\n",
       "133    [0, 1, 2, 4, 5, 7, 8, 9, 10, 12, 13, 15, 17, 1...\n",
       "134    [0, 1, 3, 4, 5, 6, 8, 9, 10, 11, 13, 15, 16, 1...\n",
       "135    [1, 2, 4, 5, 6, 7, 8, 9, 12, 13, 14, 16, 17, 1...\n",
       "139    [0, 1, 2, 6, 8, 9, 12, 13, 14, 16, 17, 18, 19,...\n",
       "145    [0, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 1...\n",
       "146    [0, 2, 3, 4, 5, 7, 8, 9, 13, 14, 16, 17, 20, 2...\n",
       "148                                             [21, 26]\n",
       "151    [1, 2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 15, 16, 1...\n",
       "155    [1, 2, 8, 9, 12, 14, 15, 16, 18, 19, 20, 21, 2...\n",
       "164    [0, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 15, 1...\n",
       "165    [1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 1...\n",
       "169    [3, 4, 5, 6, 7, 8, 9, 10, 13, 14, 18, 19, 20, ...\n",
       "171    [0, 1, 3, 4, 6, 7, 10, 11, 12, 13, 14, 15, 16,...\n",
       "176                              [5, 10, 11, 12, 16, 19]\n",
       "180    [1, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, ...\n",
       "181                         [4, 5, 7, 8, 10, 16, 17, 22]\n",
       "188    [1, 2, 3, 4, 7, 8, 10, 12, 14, 17, 19, 20, 21,...\n",
       "190                        [5, 6, 8, 10, 11, 17, 20, 24]\n",
       "191                               [1, 5, 12, 15, 23, 24]\n",
       "198    [0, 1, 2, 3, 4, 5, 8, 9, 11, 12, 13, 14, 15, 1...\n",
       "199    [0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...\n",
       "201                                             [10, 16]\n",
       "207    [0, 1, 2, 3, 5, 6, 11, 12, 14, 17, 18, 19, 21,...\n",
       "209    [1, 2, 4, 6, 8, 9, 10, 11, 12, 13, 15, 18, 19,...\n",
       "215    [0, 1, 6, 7, 8, 10, 11, 12, 14, 17, 19, 20, 22...\n",
       "217    [1, 2, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 1...\n",
       "219    [0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 15, 16...\n",
       "230    [2, 3, 6, 7, 8, 9, 10, 11, 14, 15, 16, 17, 18,...\n",
       "240                [0, 7, 8, 13, 15, 19, 21, 22, 26, 27]\n",
       "243                                                 [20]\n",
       "dtype: object"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acf_interesting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "# both interesting are sample ids that are present in both acf_interesting and tlcf_interesting\n",
    "both_interesting = acf_interesting.index.intersection(tlcf_interesting.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index([128, 148], dtype='int64', name='sample_id')"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "both_interesting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Also a traditional reporter, this New Jersey native went on to do more than 40 stories for the News of the World, then founded the news website MuckRock to crowdsource information and ultimately went on to co-found the tech-oriented think tank, the Future of Privacy Forum. Now he's helping others get their stories out by sharing his own. He's a full member of the PRWeb team, as we are proud to be\n",
      "11\n",
      "journalist\n"
     ]
    }
   ],
   "source": [
    "# factual\n",
    "i = 128\n",
    "print(tlcf_matching.loc[i].factual_samples) \n",
    "print(tlcf_matching.loc[i].observed_abstraction_id.item())\n",
    "print(abstraction.id_to_label[tlcf_matching.loc[i].observed_abstraction_id.item()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Also a traditional reporter, this New Jersey native went to Boston College and USC before earning a law degree. She is currently managing editor of the weekly column \"Sex and Shared Desires,\" and occasionally contributes to \"TribLive\" on the Huffington Post, LGBT Nation!\n",
      "2\n",
      "attorney\n"
     ]
    }
   ],
   "source": [
    "cf_sample_id = tlcf_interesting.loc[i][0]\n",
    "print(tlcf_matching.loc[i].cf_samples[cf_sample_id])\n",
    "print(tlcf_matching.loc[i].cf_abstraction_ids[cf_sample_id])\n",
    "print(abstraction.id_to_label[tlcf_matching.loc[i].cf_abstraction_ids[cf_sample_id]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11\n"
     ]
    }
   ],
   "source": [
    "print(acf_matching.loc[i].observed_abstraction_id.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Also a traditional reporter, this New Jersey-born reporter is among those chosen to represent KQED's public television reporting. Ms. Williams has reported from across the country, from Israel and Moscow, and on eight continents. She has been with the Public Insight Network for two decades. Before joining KQED, Ms. Williams was a senior producer at member station WITN in Washington, DC\n",
      "11\n",
      "journalist\n"
     ]
    }
   ],
   "source": [
    "acf_sample_id = 6  # hardcode for now\n",
    "print(acf_matching.loc[i].cf_samples[acf_sample_id])\n",
    "print(acf_matching.loc[i].cf_abstraction_ids[acf_sample_id])\n",
    "print(abstraction.id_to_label[acf_matching.loc[i].cf_abstraction_ids[acf_sample_id]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "With a keen eye and a trigger finger for the political machinations and machinations of big business — no matter how much money is involved — it was clear that it needed more information,\" he wrote.\n",
      "\n",
      "\"At a recent meeting between our staff and Sen. McCain, he had no problem admitting that he had no idea how the process worked,\" he continued. \"That's because he didn't know, and the rest of us didn\n",
      "11\n",
      "journalist\n"
     ]
    }
   ],
   "source": [
    "i = 148\n",
    "print(tlcf_matching.loc[i].factual_samples) \n",
    "print(tlcf_matching.loc[i].observed_abstraction_id.item())\n",
    "print(abstraction.id_to_label[tlcf_matching.loc[i].observed_abstraction_id.item()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "With a keen eye and a trigger finger for the masses, Lenitcher writes in MATCHMAKER:\n",
      "\n",
      "One of the half-dozen top doctors specializing in visual healing came into my workshop for a session with me in early March. She was speaking in a lab, gesturing and demonstrating low-light imaging to spell out the process of humorless reinvention.\n",
      "\n",
      "\"Wow,\" one of the workshop participants said, \"this\n",
      "19\n",
      "physician\n"
     ]
    }
   ],
   "source": [
    "cf_sample_id = tlcf_interesting.loc[i][2]\n",
    "print(tlcf_matching.loc[i].cf_samples[cf_sample_id])\n",
    "print(tlcf_matching.loc[i].cf_abstraction_ids[cf_sample_id])\n",
    "print(abstraction.id_to_label[tlcf_matching.loc[i].cf_abstraction_ids[cf_sample_id]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11\n"
     ]
    }
   ],
   "source": [
    "print(acf_matching.loc[i].observed_abstraction_id.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "With a keen eye and a trigger finger, it is still easy to make those changes!\n",
      "\n",
      "The next day since it had come outside, I took the time to take some photos and share them with you!\n",
      "\n",
      "I'd also take care of the new addition to the home!\n",
      "\n",
      "As a reminder, the following photos are just one step in my transformation journey. It starts with what\n",
      "18\n",
      "photographer\n"
     ]
    }
   ],
   "source": [
    "acf_sample_id = 4\n",
    "print(acf_matching.loc[i].cf_samples[acf_sample_id])\n",
    "print(acf_matching.loc[i].cf_abstraction_ids[acf_sample_id])\n",
    "print(abstraction.id_to_label[acf_matching.loc[i].cf_abstraction_ids[acf_sample_id]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
