{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e03203fd-e75a-4002-9aa6-8203c496a330",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "import torch\n",
    "from transformer_lens import HookedTransformer\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from sae_lens import SAE\n",
    "import yaml\n",
    "# CACHE_DIR (from global_config.yaml) is passed as cache_dir to load_dataset / from_pretrained below.\n",
    "with open('global_config.yaml') as global_stream:\n",
    "    global_cfg = yaml.safe_load(global_stream)\n",
    "CACHE_DIR = global_cfg['CACHE_DIR']\n",
    "# only forward passes are run in this notebook: turn autograd off globally\n",
    "torch.set_grad_enabled(False)\n",
    "#splits = ['en-es', 'en-fr', 'de-en']\n",
    "import json\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24896e3d-92c8-4ecb-a02c-ebfd4cc41311",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Europarl pairs with English. Codes that sort before 'en' appear as '<xx>-en',\n",
    "# the rest as 'en-<xx>'.\n",
    "before = 'bg cs da de el'.split()\n",
    "after = 'es et fi fr hu it lt lv nl pl pt ro sk sl sv'.split()\n",
    "splits = [*(f'{other}-en' for other in before), *(f'en-{other}' for other in after)]\n",
    "print(len(splits))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d18f9d8-c3a0-4365-82db-1db4913e7119",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Generate Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5246f8a-5e16-410b-a383-f3d181bb5957",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every language pair of Helsinki-NLP/europarl (cached under CACHE_DIR).\n",
    "datasets = {\n",
    "    split: load_dataset(\"Helsinki-NLP/europarl\", split, cache_dir=CACHE_DIR)\n",
    "    for split in splits\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb8c4703-1665-4198-8c0c-b1dd32adc474",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_SUBSET = 250  # sentence pairs kept per language pair\n",
    "for split in tqdm(splits):\n",
    "    # Slice rows first, then take the column. `ds['translation'][:N]` can (depending on the\n",
    "    # `datasets` version) materialise the entire column as Python objects just to keep N rows.\n",
    "    datasets[split] = datasets[split]['train'][:N_SUBSET]['translation']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72b6d501-09bd-48f1-8258-70e98238939a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the subsets as {split: [{lang_code: sentence, ...}, ...]}.\n",
    "with open('data/subsetted_europarl_data.json', 'w') as file:\n",
    "    json.dump(datasets, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bbc262c-42e7-4f73-90e5-e96d15272252",
   "metadata": {},
   "source": [
    "# Find diffmean steering vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e807be7-b557-4a34-b8b0-840b1f47386a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved subsets so this section can run without the download cells above.\n",
    "with open('data/subsetted_europarl_data.json', 'r') as file:\n",
    "    datasets = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3624998e-508b-4ee7-9583-e14b38c77452",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model the diffmean vectors are extracted from; swap in the commented lines for gemma-2-9b.\n",
    "# NOTE(review): the save cell below hard-codes a '2b_' filename - rename it when switching.\n",
    "model = HookedTransformer.from_pretrained('gemma-2-2b', cache_dir=CACHE_DIR, device='cuda', torch_dtype=torch.bfloat16)\n",
    "layer = 20\n",
    "#model = HookedTransformer.from_pretrained('gemma-2-9b', cache_dir=CACHE_DIR, device='cuda', torch_dtype=torch.bfloat16)\n",
    "#layer = 33"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c1ce5d4-3ca1-4d37-b0b4-90b59dec8470",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Difference-of-means (\"diffmean\") steering vector per language pair: the unit-normalised\n",
    "# difference between the mean residual-stream activation at `layer` on the non-English\n",
    "# side and on the English side of the Europarl sentence pairs.\n",
    "N_PROMPTS = 100  # sentence pairs per split used to estimate the two means\n",
    "diff_vectors = {}\n",
    "for split in datasets.keys():\n",
    "    # split names are '<xx>-en' or 'en-<xx>'; keep the non-English code\n",
    "    other_lang = next((lang for lang in reversed(split.split('-')) if lang != 'en'), None)\n",
    "    eng_acts = []\n",
    "    other_acts = []\n",
    "    for data in tqdm(datasets[split][:N_PROMPTS]):\n",
    "        # stop_at_layer returns the residual stream entering block `layer` (the point the\n",
    "        # steering hook `blocks.{layer}.hook_resid_pre` acts on); [0, 1:] drops batch dim and BOS\n",
    "        english_activation = model(data['en'], prepend_bos=True, stop_at_layer=layer)[0, 1:]\n",
    "        other_activation = model(data[other_lang], prepend_bos=True, stop_at_layer=layer)[0, 1:]\n",
    "        # upcast to float32: averaging thousands of tokens and subtracting two nearby means\n",
    "        # in bfloat16 (~3 significant digits) throws away much of the difference\n",
    "        eng_acts.append(english_activation.float().cpu())\n",
    "        other_acts.append(other_activation.float().cpu())\n",
    "\n",
    "    other_all = torch.cat(other_acts, dim=0)\n",
    "    steering_vec = other_all.mean(dim=0) - torch.cat(eng_acts, dim=0).mean(dim=0)\n",
    "    steering_vec = steering_vec / steering_vec.norm(dim=-1)\n",
    "    # mean projection of the non-English activations onto the unit vector; used later as\n",
    "    # the steering coefficient (zbar)\n",
    "    mean_other_act = (other_all @ steering_vec).mean().item()\n",
    "    # store the vector in the model dtype, as before, so downstream cells such as\n",
    "    # `v[1] @ P.to(torch.bfloat16)` keep working\n",
    "    diff_vectors[split] = (mean_other_act, steering_vec.to(model.cfg.dtype))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a18f5b6b-84a9-4490-af94-25da20567f8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saves {split: (mean projection of non-English activations, unit steering vector)}.\n",
    "torch.save(diff_vectors, 'data/2b_language_steering_vectors_diffmean.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd6ce9d9-1672-4c20-8772-97ea73d39ab7",
   "metadata": {},
   "source": [
    "# Next: restart the kernel, re-run the first (imports/config) cell, then transfer the vectors and refit their coefficients"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "739a9400-05a9-4881-910d-139cd728810b",
   "metadata": {},
   "source": [
    "## 2b to 9b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f2676b4-1326-4572-aa24-e86e15c6d3d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `os` must be imported here: this section starts from a fresh kernel and `os` is\n",
    "# otherwise only imported further down, so os.path.join raised NameError.\n",
    "import os\n",
    "from stitching.stitching_utils import open_experiment\n",
    "project_name = \"stitch_training_gemma-2-2b_to_gemma-2-9b_bidirectional_mse\"\n",
    "checkpoints_dir = os.path.join('checkpoints/', f\"{project_name}/\")\n",
    "# P presumably maps gemma-2-2b residuals (d=2304) to gemma-2-9b (d=3584); it is applied as `v @ P`.\n",
    "# NOTE: this `beta` is later shadowed by `from scipy.stats import beta`.\n",
    "P, Pinv, beta, bias, biasinv = open_experiment(2304, 3584, checkpoints_dir, 'fallen-glitter-8', device='cpu', biases=True)\n",
    "diff_vectors = torch.load('data/2b_language_steering_vectors_diffmean.pt', weights_only=True)\n",
    "with open('data/subsetted_europarl_data.json', 'r') as file:\n",
    "    datasets = json.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06e009bd-5e0d-4737-bfa8-4796dca4982d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target model for the transferred vectors: gemma-2-9b, read/steered at layer 33.\n",
    "#model = HookedTransformer.from_pretrained('gemma-2-2b', cache_dir=CACHE_DIR, device='cuda', torch_dtype=torch.bfloat16)\n",
    "model = HookedTransformer.from_pretrained('gemma-2-9b', cache_dir=CACHE_DIR, device='cuda', torch_dtype=torch.bfloat16)\n",
    "#layer = 20\n",
    "layer = 33"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2edc665-fa04-44ba-bdc7-c071561f0795",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map each 2b diffmean direction into 9b space with the stitch P, then rescale to unit norm.\n",
    "P_bf16 = P.to(torch.bfloat16)  # cast once rather than once per split\n",
    "transferred_diff_vectors = {}\n",
    "for split_name, entry in diff_vectors.items():\n",
    "    vec_9b = entry[1] @ P_bf16\n",
    "    transferred_diff_vectors[split_name] = vec_9b / vec_9b.norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82ab23e1-e2e3-4606-8cd2-90e15ff485cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Refit the steering coefficient in the 9b model: for each split, the mean projection of\n",
    "# 9b activations on the non-English sentences onto the transferred unit vector.\n",
    "for split in datasets.keys():\n",
    "    other_lang = next((lang for lang in reversed(split.split('-')) if lang != 'en'), None)\n",
    "    entry = transferred_diff_vectors[split]\n",
    "    # after one run the entry is already (coefficient, vector); unwrap it so the cell can be\n",
    "    # re-run instead of failing on `tuple.to`\n",
    "    steering_vector = entry[1] if isinstance(entry, tuple) else entry\n",
    "    steering_vector_gpu = steering_vector.to('cuda', torch.float32)  # hoisted out of the prompt loop\n",
    "    others = []\n",
    "    for data in tqdm(datasets[split][:100]):\n",
    "        # [0, 1:] drops the batch dim and the BOS position\n",
    "        other_activation = model(data[other_lang], prepend_bos=True, stop_at_layer=layer)[0, 1:]\n",
    "        # project in float32 so the averaged coefficient is not rounded to bfloat16\n",
    "        others.append((other_activation.float() @ steering_vector_gpu).cpu())\n",
    "    transferred_diff_vectors[split] = (torch.cat(others).mean().item(), steering_vector)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "891d2160-aa61-4d92-892e-54f64d23c4ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the refitted {split: (coefficient, vector)} pairs\n",
    "transferred_diff_vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c60e8bb-5882-44b8-a2d1-3412d1ea7017",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saves {split: (coefficient refit in 9b, transferred unit vector)}.\n",
    "torch.save(transferred_diff_vectors, 'data/2b_to_9b_transferred_language_steering_vectors_diffmean.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb4f6e61-e2a9-494f-a716-2c1d6c6561eb",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "97c9e84f-f4a3-4e58-94b9-efc2bc966314",
   "metadata": {},
   "source": [
    "# Generate All Responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86257e62-2618-4447-bee7-b245e63c738f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from stitching.sae_utils import gemma_generate_with_hooks\n",
    "import functools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64fdf7d0-2148-444a-b8ea-e6fd05c5c2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generation config: which model to steer, at which layer, and which saved vectors to load.\n",
    "# NOTE(review): the generation loop always writes steered_responses.json, even when a\n",
    "# transferred-vector file is selected here, while get_numbers also reads\n",
    "# transferred_steered_responses.json - TODO confirm how that file is produced.\n",
    "#model_name = 'gemma-2-2b'\n",
    "#layer = 20 # 33 \n",
    "#steering_vectors_file = \"data/2b_language_steering_vectors_diffmean.pt\"\n",
    "#steering_vectors_file = \"data/9b_to_2b_transferred_language_steering_vectors_diffmean.pt\"\n",
    "model_name = 'gemma-2-9b'\n",
    "layer = 33 # 33 \n",
    "steering_vectors_file = \"data/9b_language_steering_vectors_diffmean.pt\"\n",
    "#steering_vectors_file = \"data/9b_to_2b_transferred_language_steering_vectors_with_coefficient.pt\"\n",
    "#instruction = 'en-fr'#'en-es'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62e0bfc6-7644-46bc-8e3a-3733a997cee3",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.inference_mode()\n",
    "def steering_hook(\n",
    "    x,\n",
    "    hook,\n",
    "    steer,\n",
    "    omega\n",
    "):\n",
    "    \"\"\"Set the component of `x` along `steer` to `omega` (TransformerLens forward hook).\n",
    "\n",
    "    Removes the projection of `x` onto `steer` and adds back `omega * steer`; `x` is modified\n",
    "    in place and returned. This is an exact projection only when `steer` has unit norm, as the\n",
    "    vectors built in this notebook do. `hook` is unused.\n",
    "    \"\"\"\n",
    "    projected_values = x @ steer.reshape(-1, 1)\n",
    "    x -= projected_values * steer\n",
    "    x += omega * steer\n",
    "    return x\n",
    "\n",
    "@torch.inference_mode()\n",
    "def activation_addition(\n",
    "    x,\n",
    "    hook,\n",
    "    steer,\n",
    "    omega\n",
    "):\n",
    "    \"\"\"Add `omega * steer` to `x` in place (plain activation addition); `hook` is unused.\"\"\"\n",
    "    x += omega * steer\n",
    "    return x\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_generations(dataset, model, layer, steering_vectors, instruction, max_new_tokens=128, multiplier=1, verbose=False, device='cuda', torch_dtype=None):\n",
    "    \"\"\"Generate one response per prompt, optionally steering the residual stream.\n",
    "\n",
    "    Args:\n",
    "        dataset: DataFrame with a 'prompt_without_instruction' column (one prompt per row).\n",
    "        model: HookedTransformer; the chat template is used when its name contains 'it'.\n",
    "        layer: steering is applied at `blocks.{layer}.hook_resid_pre`.\n",
    "        steering_vectors: dict split -> (coefficient, unit vector), or None.\n",
    "        instruction: key into `steering_vectors` (e.g. 'en-fr'); generation is unsteered\n",
    "            when the key is missing or `steering_vectors` is None.\n",
    "        max_new_tokens: passed to `gemma_generate_with_hooks` as max_tokens_generated.\n",
    "        multiplier: scale applied to the stored coefficient.\n",
    "        verbose: print each prompt and response.\n",
    "        device: device the steering vector is moved to.\n",
    "        torch_dtype: dtype of the steering vector; defaults to `model.cfg.dtype`.\n",
    "            (Previously this was read from a notebook global defined in a later cell.)\n",
    "\n",
    "    Returns:\n",
    "        list of {\"prompt\": ..., \"response\": ...} dicts, one per row.\n",
    "    \"\"\"\n",
    "    if torch_dtype is None:\n",
    "        torch_dtype = model.cfg.dtype\n",
    "    # NOTE(review): substring test - any model name containing 'it' counts as instruction-tuned\n",
    "    instruction_tuned = 'it' in model.cfg.model_name\n",
    "\n",
    "    # the steering hook does not depend on the prompt, so build it once\n",
    "    fwd_hooks = None\n",
    "    if steering_vectors is not None and instruction in steering_vectors:\n",
    "        zbar, steering_vector = steering_vectors[instruction]\n",
    "        steering_func = functools.partial(steering_hook, steer=steering_vector.to(device).to(torch_dtype), omega=zbar*multiplier)\n",
    "        #steering_func = functools.partial(activation_addition, steer=steering_vector.to(device).to(torch_dtype), omega=zbar*multiplier)\n",
    "        fwd_hooks = [(f'blocks.{layer}.hook_resid_pre', steering_func)]\n",
    "\n",
    "    results_dict = []\n",
    "    for _, row in tqdm(dataset.iterrows(), total=len(dataset)):\n",
    "        prompt = row['prompt_without_instruction']\n",
    "        if instruction_tuned:\n",
    "            messages_instruction = [\n",
    "                {\"role\": \"user\", \"content\": prompt}\n",
    "            ]\n",
    "            formatted_prompt = model.tokenizer.apply_chat_template(messages_instruction, tokenize=False, add_generation_prompt=True)\n",
    "        else:\n",
    "            formatted_prompt = f\"Q: {prompt}\\nA:\"\n",
    "        # BOS only for base models - presumably the chat template already inserts it\n",
    "        tokens = model.to_tokens(formatted_prompt, prepend_bos=not(instruction_tuned))\n",
    "        if fwd_hooks is None:\n",
    "            generation = gemma_generate_with_hooks(model, tokens, max_tokens_generated=max_new_tokens)[0]\n",
    "        else:\n",
    "            generation = gemma_generate_with_hooks(model, tokens, max_tokens_generated=max_new_tokens, fwd_hooks=fwd_hooks)[0]\n",
    "        if verbose:\n",
    "            print(prompt, '\\n', generation.strip())\n",
    "        results_dict.append({\"prompt\": prompt, \"response\": generation.strip()})\n",
    "    return results_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afff338d-cd50-420f-b7d3-870ec2c7a3a2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Prompts to generate from (the path suggests IFEval prompts with one instruction each).\n",
    "data_df = pd.read_json(path_or_buf='instruction_following_eval/ifeval_single_instr_format.jsonl', lines=True)\n",
    "# {split: (coefficient, steering vector)}, or None to generate without steering\n",
    "steering_vectors = (\n",
    "    torch.load(steering_vectors_file, weights_only=True)\n",
    "    if steering_vectors_file is not None\n",
    "    else None\n",
    ")\n",
    "steering_vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b5ed58c-df1c-4948-81e4-28ec9a494ec3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the model to steer (model_name from the config cell) in bfloat16\n",
    "torch_dtype = torch.bfloat16    \n",
    "print(\"Loading\", model_name)\n",
    "model = HookedTransformer.from_pretrained(model_name=model_name, device='cuda', cache_dir=CACHE_DIR, torch_dtype=torch_dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a16ed5ee-0153-4fc4-9895-3a2e3c531a3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9512dfb-0dee-4fb4-888c-cfd011f0de52",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Generate steered responses for every split; splits whose output file exists are skipped.\n",
    "for instruction in tqdm(splits):\n",
    "    out_dir = f\"language_steering_full/{instruction}/{model_name}/\"\n",
    "    out_file = f\"{out_dir}steered_responses.json\"\n",
    "    if not os.path.exists(out_file):\n",
    "        results_dict = get_generations(data_df, model, layer, steering_vectors, instruction, multiplier=1, verbose=False, device='cuda')\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "        pd.DataFrame(results_dict).to_json(out_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a59a590e-07c5-4fac-af16-f925bc72090d",
   "metadata": {},
   "source": [
    "# Evaluate All Responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4444fd38-6cbf-4a8c-a174-20ef94bc9230",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langdetect import detect\n",
    "from collections import Counter\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from stitching.sae_utils import gemma_generate_with_hooks\n",
    "import functools\n",
    "import os\n",
    "from scipy.stats import beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36b39098-7cc4-4fd4-aeaa-d92760165243",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_responses(dataframe):\n",
    "    \"\"\"Detect the language of every entry in dataframe['response'].\n",
    "\n",
    "    Returns a list of langdetect language codes (e.g. 'fr') in row order, with 'unk' for\n",
    "    responses that cannot be classified (langdetect raises on e.g. empty text).\n",
    "    \"\"\"\n",
    "    from langdetect import DetectorFactory\n",
    "    # langdetect is non-deterministic by default; a fixed seed makes the per-split counts\n",
    "    # (and every table/plot built from them) reproducible across runs\n",
    "    DetectorFactory.seed = 0\n",
    "    detected = []\n",
    "    for response in dataframe['response']:\n",
    "        try:\n",
    "            detected.append(detect(response))\n",
    "        except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through\n",
    "            detected.append('unk')\n",
    "    return detected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1703134c-eb8d-4551-8733-62e0cdb8d639",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clopper_pearson_exact(x, n=163, alpha=0.05):\n",
    "    \"\"\"Exact (Clopper-Pearson) two-sided 1-alpha confidence interval for x successes in n trials.\n",
    "\n",
    "    Returns (lo, hi); lo is 0.0 when x == 0 and hi is 1.0 when x == n. The default n=163\n",
    "    presumably matches the number of evaluation prompts - TODO confirm.\n",
    "    \"\"\"\n",
    "    # Local import: `beta` is also bound at notebook level to an `open_experiment` output,\n",
    "    # so relying on the global breaks if cells run out of order.\n",
    "    from scipy.stats import beta as beta_dist\n",
    "    lo = beta_dist.ppf(alpha / 2, x, n - x + 1) if x > 0 else 0.0\n",
    "    hi = beta_dist.ppf(1 - alpha / 2, x + 1, n - x) if x < n else 1.0\n",
    "    return lo, hi\n",
    "\n",
    "def get_numbers(split, model):\n",
    "    \"\"\"Fraction of `model` responses detected in the split's non-English language.\n",
    "\n",
    "    Reads the unsteered (base), steered and transfer-steered responses saved under\n",
    "    language_steering_full/ and returns\n",
    "    (base_acc, steered_acc, transfer_steer_acc, other_lang, base_err, steered_err, transfer_steer_err),\n",
    "    where every fraction is over the number of base responses and each *_err is a\n",
    "    (lo, hi) Clopper-Pearson interval.\n",
    "    \"\"\"\n",
    "    other_lang = next((lang for lang in reversed(split.split('-')) if lang != 'en'), None)\n",
    "    main_dir = f\"language_steering_full/{split}/{model}\"\n",
    "    base_dir = f\"language_steering_full/base/{model}\"\n",
    "    # read the base responses once (they were previously loaded twice)\n",
    "    base_df = pd.read_json(os.path.join(base_dir, \"no_steering_responses.json\"))\n",
    "    total_queries = len(base_df)\n",
    "    base_count = Counter(eval_responses(base_df))[other_lang]\n",
    "    steered_count = Counter(eval_responses(pd.read_json(os.path.join(main_dir, \"steered_responses.json\"))))[other_lang]\n",
    "    transfer_steered_count = Counter(eval_responses(pd.read_json(os.path.join(main_dir, \"transferred_steered_responses.json\"))))[other_lang]\n",
    "    base_acc = base_count / total_queries\n",
    "    steered_acc = steered_count / total_queries\n",
    "    transfer_steer_acc = transfer_steered_count / total_queries\n",
    "    base_err = clopper_pearson_exact(base_count, total_queries)\n",
    "    steered_err = clopper_pearson_exact(steered_count, total_queries)\n",
    "    transfer_steer_err = clopper_pearson_exact(transfer_steered_count, total_queries)\n",
    "    return base_acc, steered_acc, transfer_steer_acc, other_lang, base_err, steered_err, transfer_steer_err"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7eb5039-304e-4c81-bab2-b972cf8528cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of unsteered gemma-2-2b responses (one per evaluation prompt)\n",
    "base_dir = f\"language_steering_full/base/gemma-2-2b\"\n",
    "total_queries = len(pd.read_json(os.path.join(base_dir, \"no_steering_responses.json\")))\n",
    "total_queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "508d3d8c-867b-4b66-9f3f-1fd08edf562a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_bar(model_name):\n",
    "    \"\"\"Pool counts over all `splits` for `model_name` into three summary bars.\n",
    "\n",
    "    Returns (xs, heights, yerrs): bar labels; the pooled fraction of responses in the target\n",
    "    language without steering, with steering and with transferred steering; and\n",
    "    (below, above) error-bar lengths from the pooled Clopper-Pearson interval.\n",
    "    \"\"\"\n",
    "    # Queries per split for *this* model. Previously the notebook-global `total_queries` was\n",
    "    # used, which is read from the gemma-2-2b base file even when pooling gemma-2-9b.\n",
    "    base_file = os.path.join(f\"language_steering_full/base/{model_name}\", \"no_steering_responses.json\")\n",
    "    queries_per_split = len(pd.read_json(base_file))\n",
    "    base_tot_count = 0\n",
    "    steered_tot_count = 0\n",
    "    transfer_tot_count = 0\n",
    "    tot_count = 0\n",
    "    for split in tqdm(splits):\n",
    "        base_acc, steered_acc, transfer_steer_acc, *_ = get_numbers(split, model_name)\n",
    "        # get_numbers returns fractions; round back to integer counts before pooling\n",
    "        base_tot_count += round(base_acc * queries_per_split)\n",
    "        steered_tot_count += round(steered_acc * queries_per_split)\n",
    "        transfer_tot_count += round(transfer_steer_acc * queries_per_split)\n",
    "        tot_count += queries_per_split\n",
    "\n",
    "    xs = ['no-steer-', 'steer', 'transfer']\n",
    "    counts = [base_tot_count, steered_tot_count, transfer_tot_count]\n",
    "    heights = [count / tot_count for count in counts]\n",
    "    yerrs = []\n",
    "    for height, count in zip(heights, counts):\n",
    "        lo, hi = clopper_pearson_exact(count, tot_count)\n",
    "        yerrs.append((height - lo, hi - height))\n",
    "    return xs, heights, yerrs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5b0f591-8689-41e3-9f78-d57c7417892f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pooled bars for both models. matplotlib expects asymmetric error bars as\n",
    "# [[below, ...], [above, ...]], i.e. the transpose of create_bar's (below, above) pairs.\n",
    "modelA_name = 'gemma-2-2b'\n",
    "modelB_name = 'gemma-2-9b'\n",
    "xs_2b, heights_2b, yerrs_2b = create_bar(modelA_name)\n",
    "yerrs_formatted_2b = [list(column) for column in zip(*yerrs_2b)]\n",
    "xs_9b, heights_9b, yerrs_9b = create_bar(modelB_name)\n",
    "yerrs_formatted_9b = [list(column) for column in zip(*yerrs_9b)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba825b12-f5ce-4923-87de-2a751efbadbf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from matplotlib.patches import Patch\n",
    "from matplotlib.ticker import PercentFormatter\n",
    "\n",
    "plt.rcParams.update({'font.size': 14})\n",
    "fig, axes = plt.subplots(1, 2, figsize=(5, 4), constrained_layout=True)\n",
    "colors = ['tab:blue', 'tab:red', 'tab:green']\n",
    "axes[0].bar(['no_st', 'st', 't-st'], heights_2b, yerr=yerrs_formatted_2b, color=colors)\n",
    "axes[0].set(title=modelA_name, ylabel='% responses in target language')\n",
    "ylim_0 = axes[0].get_ylim()\n",
    "axes[1].bar(['no_steer_9b', 'steer_9b', 'steer_2b->9b'], heights_9b, yerr=yerrs_formatted_9b, color=colors)\n",
    "# reuse the 2b panel's y-range so the two panels are directly comparable\n",
    "axes[1].set(title=modelB_name, ylim=ylim_0)\n",
    "for ax in axes:\n",
    "    # heights are fractions in [0, 1]; show ticks as percentages to match the '%' ylabel\n",
    "    ax.yaxis.set_major_formatter(PercentFormatter(xmax=1))\n",
    "    # bars are identified by the legend, so hide x-ticks and their labels\n",
    "    ax.tick_params(axis='x', bottom=False, labelbottom=False)\n",
    "    ax.grid(axis='y', linestyle='--', linewidth=0.7)\n",
    "\n",
    "legend_handles = [\n",
    "    Patch(facecolor='tab:blue',  label='no_steer'),\n",
    "    Patch(facecolor='tab:red',   label='steer'),\n",
    "    Patch(facecolor='tab:green', label='transfer'),\n",
    "]\n",
    "axes[1].legend(handles=legend_handles, bbox_to_anchor=(1, 1))\n",
    "plt.savefig('results/figures/language_steering_full.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ed8877e-120d-466a-b310-cba80198105d",
   "metadata": {},
   "source": [
    "## Break it down language-wise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea595677-7cd1-4093-a402-664b4f70945b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-language sentence counts, normalised into relative frequencies that sum to 1.\n",
    "with open('data/lang_freqs.json', 'r') as file:\n",
    "    lang_counts = json.load(file)\n",
    "tot_sentences = 0\n",
    "for count in lang_counts.values():\n",
    "    tot_sentences += count\n",
    "lang_freqs = {lang: count / tot_sentences for lang, count in lang_counts.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0a78250-51ba-499b-b4aa-75ff9d151f03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-English language code of each split, in the same order as `splits`.\n",
    "other_langs = [\n",
    "    next((lang for lang in reversed(split.split('-')) if lang != 'en'), None)\n",
    "    for split in splits\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0427a2f-bb4a-4fa0-a4b2-0d80c59e1432",
   "metadata": {},
   "source": [
    "#### 2b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa281bf9-6883-44a9-aa87-692fa7553ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-split target-language rates for gemma-2-2b: no steering, native steering and\n",
    "# 9b->2b transferred steering, each with Clopper-Pearson bounds, in `final_dataframe`.\n",
    "model_name = 'gemma-2-2b'\n",
    "# queries per split for this model (previously the global `total_queries`)\n",
    "model_total_queries = len(pd.read_json(os.path.join(f\"language_steering_full/base/{model_name}\", \"no_steering_responses.json\")))\n",
    "base_tot_count = 0\n",
    "steered_tot_count = 0\n",
    "transfer_tot_count = 0\n",
    "tot_count = 0\n",
    "all_accs = {\n",
    "    'no_steer_2b':[], 'no_steer_2b_lo': [], 'no_steer_2b_hi': [],\n",
    "    'steer_2b':[], 'steer_2b_lo': [], 'steer_2b_hi': [],\n",
    "    'steer_9b->2b':[], 'steer_9b->2b_lo':[], 'steer_9b->2b_hi':[]\n",
    "}\n",
    "for split in tqdm(splits):\n",
    "    base_acc, steered_acc, transfer_steer_acc, other_lang, base_err, steered_err, transfer_steer_err = get_numbers(split, model_name)\n",
    "    base_tot_count += round(base_acc * model_total_queries)\n",
    "    steered_tot_count += round(steered_acc * model_total_queries)\n",
    "    transfer_tot_count += round(transfer_steer_acc * model_total_queries)\n",
    "    tot_count += model_total_queries\n",
    "    for column, acc, (lo, hi) in (\n",
    "        ('no_steer_2b', base_acc, base_err),\n",
    "        ('steer_2b', steered_acc, steered_err),\n",
    "        ('steer_9b->2b', transfer_steer_acc, transfer_steer_err),\n",
    "    ):\n",
    "        all_accs[column].append(acc)\n",
    "        all_accs[f'{column}_lo'].append(lo)\n",
    "        all_accs[f'{column}_hi'].append(hi)\n",
    "final_dataframe = pd.DataFrame(all_accs, index=splits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1906795-51af-4ba7-852d-536a193a8691",
   "metadata": {},
   "outputs": [],
   "source": [
    "# persist per-split rates and CI bounds for gemma-2-2b\n",
    "final_dataframe.to_csv('data/gemma-2-2b_language_steering_full_accuracies.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6dabf5-9679-4374-860a-dcd16d618f05",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# rates as percentages (rounded to 0.1 percentage point)\n",
    "final_dataframe.round(3)*100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1140be1-416a-48a2-83f2-071aa481e1b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse each (value, lo, hi) column triple into one \"value (lo, hi)\" string column.\n",
    "ci_map = {\n",
    "    'no_steer':      ('no_steer_2b',    'no_steer_2b_lo',    'no_steer_2b_hi'),\n",
    "    'steer':         ('steer_2b',       'steer_2b_lo',       'steer_2b_hi'),\n",
    "    'steer_9b->2b':  ('steer_9b->2b',   'steer_9b->2b_lo',   'steer_9b->2b_hi'),\n",
    "}\n",
    "\n",
    "def _format_ci(row, mid, lo, hi):\n",
    "    return f\"{row[mid]:.2f} ({row[lo]:.2f}, {row[hi]:.2f})\"\n",
    "\n",
    "df_ci = pd.DataFrame(\n",
    "    {new_col: final_dataframe.apply(_format_ci, axis=1, args=cols) for new_col, cols in ci_map.items()},\n",
    "    index=final_dataframe.index,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1111869a-b5bc-4aa4-beed-7c7c0c3511a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# formatted 'value (lo, hi)' table per split\n",
    "df_ci"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0472ed52-e699-4ebe-bdfa-dbb663b8f0c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ratio of transferred (9b->2b) to native 2b steering success per language; 1 means the\n",
    "# transferred vector is as effective as the native one. Indexed by language code.\n",
    "relative_transfer_gap = final_dataframe['steer_9b->2b'].div(final_dataframe['steer_2b'])\n",
    "relative_transfer_gap.index = other_langs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75bd6093-6b75-4e1b-a332-56f49666819a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clip the ratio to [0, 1] *before* collecting the plot values. The scatter below is labelled\n",
    "# 'clipped relative transfer gap', but y_values used to be taken from the unclipped series.\n",
    "# Ratios exceed 1 when transfer beats native steering, are inf for x/0 and NaN for 0/0.\n",
    "relative_transfer_gap = (\n",
    "    relative_transfer_gap\n",
    "    .clip(lower=0, upper=1)   # values <0 -> 0, >1 -> 1 (including +-inf)\n",
    "    .fillna(0)                # NaNs -> 0\n",
    ")\n",
    "x_values = np.array([lang_freqs[lang] for lang in relative_transfer_gap.index])\n",
    "y_values = relative_transfer_gap.to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e951e53-9f13-4128-9e98-bc0444d00dc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(2, 1, figsize=(4, 4), constrained_layout=True)\n",
    "relative_transfer_gap.plot.hist(bins=10, ax=ax[0])\n",
    "ax[0].set(title='9b->2b')\n",
    "# `relative_transfer_gap_clean` (2b->9b) is not defined in any cell above - presumably it is\n",
    "# built in the 9b section below (TODO confirm). Guard so a top-to-bottom run does not stop\n",
    "# with NameError here, and only write the PDF once both panels are available.\n",
    "gap_2b_to_9b = globals().get('relative_transfer_gap_clean')\n",
    "if gap_2b_to_9b is not None:\n",
    "    gap_2b_to_9b.plot.hist(bins=10, ax=ax[1])\n",
    "else:\n",
    "    ax[1].text(0.5, 0.5, 'run the 9b section first', ha='center', va='center', transform=ax[1].transAxes)\n",
    "ax[1].set(xlabel='clipped relative transfer gap', title='2b->9b')\n",
    "if gap_2b_to_9b is not None:\n",
    "    plt.savefig('results/figures/clipped_relative_gap_histogram.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6291e31d-716a-4c7d-afae-085403ffbc33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clipped relative transfer gap per language against log10 of its sentence frequency.\n",
    "plt.rcParams.update({'font.size': 14})\n",
    "log_freqs = np.log10(x_values)\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x=log_freqs, y=y_values, marker='^')\n",
    "for idx, lang in enumerate(relative_transfer_gap.index):\n",
    "    # label each marker with its language code, nudged 2 points right and up\n",
    "    ax.annotate(lang, xy=(log_freqs[idx], y_values[idx]), textcoords=\"offset points\",\n",
    "                xytext=(2, 2), ha='left', va='bottom', fontsize=12)\n",
    "ax.set_xlabel(r'$\\log_{10}$ sentence frequency')\n",
    "ax.set_ylabel('clipped relative transfer gap')\n",
    "ax.set_title(r'9b$\\to$2b Transfer vs. Training Frequency')\n",
    "fig.savefig(\"results/figures/9b_2b_relative_transfer_gap_vs_frequency.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b475663d-e9df-497e-91be-8ac6bdfdf6d1",
   "metadata": {},
   "source": [
    "#### 9b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c7f79c-842e-4d0d-a4c8-1ab62c83e19f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-split target-language rates for gemma-2-9b: no steering, native steering and\n",
    "# 2b->9b transferred steering, each with Clopper-Pearson bounds, in `final_dataframe`.\n",
    "model_name = 'gemma-2-9b'\n",
    "# queries per split for this model (previously the global `total_queries`, read from the 2b file)\n",
    "model_total_queries = len(pd.read_json(os.path.join(f\"language_steering_full/base/{model_name}\", \"no_steering_responses.json\")))\n",
    "base_tot_count = 0\n",
    "steered_tot_count = 0\n",
    "transfer_tot_count = 0\n",
    "tot_count = 0\n",
    "all_accs = {\n",
    "    'no_steer_9b':[], 'no_steer_9b_lo': [], 'no_steer_9b_hi': [],\n",
    "    'steer_9b':[], 'steer_9b_lo': [], 'steer_9b_hi': [],\n",
    "    'steer_2b->9b':[], 'steer_2b->9b_lo':[], 'steer_2b->9b_hi':[]\n",
    "}\n",
    "for split in tqdm(splits):\n",
    "    base_acc, steered_acc, transfer_steer_acc, other_lang, base_err, steered_err, transfer_steer_err = get_numbers(split, model_name)\n",
    "    base_tot_count += round(base_acc * model_total_queries)\n",
    "    steered_tot_count += round(steered_acc * model_total_queries)\n",
    "    transfer_tot_count += round(transfer_steer_acc * model_total_queries)\n",
    "    tot_count += model_total_queries\n",
    "    for column, acc, (lo, hi) in (\n",
    "        ('no_steer_9b', base_acc, base_err),\n",
    "        ('steer_9b', steered_acc, steered_err),\n",
    "        ('steer_2b->9b', transfer_steer_acc, transfer_steer_err),\n",
    "    ):\n",
    "        all_accs[column].append(acc)\n",
    "        all_accs[f'{column}_lo'].append(lo)\n",
    "        all_accs[f'{column}_hi'].append(hi)\n",
    "final_dataframe = pd.DataFrame(all_accs, index=splits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fdb18f2-8023-472f-8c55-d47246d64595",
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracies_csv_path = 'data/gemma-2-9b_language_steering_full_accuracies.csv'\n",
    "# persist the 9b accuracy table (point estimates plus interval bounds per split)\n",
    "final_dataframe.to_csv(accuracies_csv_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c178dd47-7807-4030-87f9-54524737edf0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# accuracies as percentages, rounded to 0.1 percentage point\n",
    "final_dataframe.round(3).mul(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fe5a16b-aec7-4095-9580-693c3f836f02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# full-precision accuracy table\n",
    "final_dataframe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a514840d-fc09-480b-b90d-99df48bdf1dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Table of \"accuracy (CI low, CI high)\" strings, one column per steering condition.\n",
    "# Maps display column name -> (accuracy column, lower-bound column, upper-bound column).\n",
    "# The LaTeX label is a raw string: in a plain string '\\t' in '$\\to$' is a TAB character.\n",
    "ci_map = {\n",
    "    'No Steering (9b)':      ('no_steer_9b',    'no_steer_9b_lo',    'no_steer_9b_hi'),\n",
    "    'Steering (9b)':         ('steer_9b',       'steer_9b_lo',       'steer_9b_hi'),\n",
    "    r'Transfer Steering (2b$\\to$9b)':  ('steer_2b->9b',   'steer_2b->9b_lo',   'steer_2b->9b_hi'),\n",
    "}\n",
    "\n",
    "# build new DataFrame with formatted strings, preserving the index\n",
    "df_ci = pd.DataFrame({\n",
    "    new_col: final_dataframe.apply(\n",
    "        lambda r: f\"{r[mid]:.2f} ({r[lo]:.2f}, {r[hi]:.2f})\",\n",
    "        axis=1\n",
    "    )\n",
    "    for new_col, (mid, lo, hi) in ci_map.items()\n",
    "}, index=final_dataframe.index)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "201c02b7-3d36-41c7-a94b-786195b50ab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# formatted 'accuracy (lo, hi)' table\n",
    "df_ci"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6e84efb-b03d-46bf-a85f-c514bf8767c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ratio of transferred (2b->9b) to native 9b steering accuracy, indexed by other_langs\n",
    "relative_transfer_gap = (\n",
    "    final_dataframe['steer_2b->9b']\n",
    "    .div(final_dataframe['steer_9b'])\n",
    "    .set_axis(other_langs)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc5593f3-c787-4e79-9894-615b09e4af60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bound the ratio to [0, 1] for plotting: NaN (0/0) becomes 0, negatives map to 0\n",
    "# and anything above 1 (including +inf from x/0) maps to 1.\n",
    "# The result is still a Series indexed like relative_transfer_gap.\n",
    "relative_transfer_gap_clean = relative_transfer_gap.fillna(0).clip(0, 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "416a1408-a723-4279-8bcb-d837629b1fd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# clipped ratio per language\n",
    "relative_transfer_gap_clean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f36161d-c06c-4aa0-a85d-3515ad8e9952",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one point per language: x is its lang_freqs entry, y is the clipped ratio\n",
    "plot_langs = list(relative_transfer_gap.index)\n",
    "x_values = np.array([lang_freqs[lang] for lang in plot_langs])\n",
    "y_values = np.array([relative_transfer_gap_clean[lang] for lang in plot_langs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f48cbe57-7611-4379-808f-907ca6d8fdf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# distribution of the clipped ratio across languages\n",
    "ax = relative_transfer_gap_clean.plot(kind='hist', bins=10)\n",
    "ax.set(xlabel='clipped relative gap', title='2b->9b')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f759e2f4-c800-439e-a449-a6656ba07d18",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset, inset_axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f9dbc4b-5a7b-419c-ae15-62efed84eac4",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams.update({'font.size': 14})\n",
    "fig, ax = plt.subplots()\n",
    "# log10 of each language's lang_freqs value, computed once for points and labels\n",
    "log_x_values = np.log10(x_values)\n",
    "ax.scatter(x=log_x_values, y=y_values, marker='^')\n",
    "\n",
    "# label each point with its language code, just above and right of the marker\n",
    "for label, x_pos, y_pos in zip(relative_transfer_gap.index, log_x_values, y_values):\n",
    "    ax.annotate(label, xy=(x_pos, y_pos),\n",
    "                textcoords=\"offset points\",  # interpret xytext as an offset in points\n",
    "                xytext=(2, 2),               # 2 points right and 2 points up\n",
    "                ha='left', va='bottom',\n",
    "                fontsize=12)\n",
    "ax.set_xlabel(r'$\\log_{10}$ sentence frequency')\n",
    "ax.set_ylabel('clipped relative transfer gap')\n",
    "ax.set_title(r'2b$\\to$9b Transfer vs. Training Frequency')\n",
    "fig.savefig(\"results/figures/2b_9b_relative_transfer_gap_vs_frequency.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f6d19de-6e06-4912-ab9f-634fc13d7508",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same scatter with labels left of each marker and a wider y-range.\n",
    "plt.rcParams.update({'font.size': 14})\n",
    "fig, ax = plt.subplots()\n",
    "log_x_values = np.log10(x_values)\n",
    "ax.scatter(x=log_x_values, y=y_values, marker='^')\n",
    "for label, x_pos, y_pos in zip(relative_transfer_gap.index, log_x_values, y_values):\n",
    "    ax.annotate(label, xy=(x_pos, y_pos),\n",
    "                textcoords=\"offset points\",  # interpret xytext as an offset in points\n",
    "                xytext=(-10, 1),              # 10 points left, 1 point up\n",
    "                ha='left', va='bottom',\n",
    "                fontsize=12)\n",
    "# \\log renders upright; a bare 'log' inside $...$ is typeset as italic variables\n",
    "ax.set_xlabel(r'$\\log_{10}$ sentence frequency')\n",
    "# y_values come from relative_transfer_gap_clean, i.e. the ratio clipped to [0, 1]\n",
    "ax.set_ylabel('clipped relative transfer gap')\n",
    "ax.set_title(r'2b$\\to$9b Transfer vs. Training Frequency')\n",
    "ax.set_ylim([-0.5, 1.5])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a00d9cb-eea3-4d96-bb05-21b615b2a934",
   "metadata": {},
   "source": [
    "#### Rough per-split bar plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d1090f5-b31f-4d0f-80f3-a34279377f7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): xs, heights and yerrs are filled by the per-split loop in a later cell;\n",
    "# run that cell before this one.\n",
    "plt.rcParams.update({'font.size': 5})\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15,5))\n",
    "# matplotlib wants asymmetric error bars as [lower_errors, upper_errors]\n",
    "yerrs_formatted = [[e[0] for e in yerrs], [e[1] for e in yerrs]]\n",
    "ax.bar(xs, heights, yerr=yerrs_formatted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3810264a-b9a9-4598-bdb5-99b798ad0da0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-split bar data for gemma-2-9b: three bars per split (no steering / steering /\n",
    "# transferred steering) with asymmetric error bars from the Clopper-Pearson bounds.\n",
    "model_name = 'gemma-2-9b'\n",
    "\n",
    "xs = []\n",
    "heights = []\n",
    "yerrs = []\n",
    "acc_diff = []\n",
    "for split in tqdm(splits):\n",
    "    base_acc, steered_acc, transfer_steer_acc, other_lang, base_err, steered_err, transfer_steer_err = get_numbers(split, model_name)\n",
    "    xs += [f'{split}_no-st', f'{split}_st', f'{split}_tr']\n",
    "    heights += [base_acc, steered_acc, transfer_steer_acc]\n",
    "    # (distance down to the lower bound, distance up to the upper bound)\n",
    "    yerrs += [\n",
    "        (acc - lo, hi - acc)\n",
    "        for acc, (lo, hi) in ((base_acc, base_err), (steered_acc, steered_err), (transfer_steer_acc, transfer_steer_err))\n",
    "    ]\n",
    "    # positive when the transferred vector reaches the target language less often than the native one\n",
    "    acc_diff.append(steered_acc - transfer_steer_acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a760f1-560c-41b1-aa6f-7636cf65e48b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# [lower_errors, upper_errors] layout expected by matplotlib's yerr\n",
    "lower_errs = [e[0] for e in yerrs]\n",
    "upper_errs = [e[1] for e in yerrs]\n",
    "yerrs_formatted = [lower_errs, upper_errs]\n",
    "plt.bar(xs, heights, yerr=yerrs_formatted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1afa9401-cfb2-45e2-9b1e-f6b7f3154c75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# distribution over splits of (native - transferred) steering accuracy\n",
    "plt.hist(x=acc_diff)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efee6ec1-2169-4863-b70f-d2ece6226c1c",
   "metadata": {},
   "source": [
    "# Old experiment below: try with manually chosen SAE features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5f81270-2b45-4719-8cdb-6e13c0dad468",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gemma-2-2b in float16; `layer` is passed as stop_at_layer in the cells below.\n",
    "# For gemma-2-9b use 'gemma-2-9b' with layer = 33.\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    'gemma-2-2b',\n",
    "    cache_dir=CACHE_DIR,\n",
    "    device='cuda',\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "layer = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78dc574c-c640-43cc-b54e-931fd14bc843",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 16k-wide Gemma Scope residual SAE with index layer-1, loaded on CPU.\n",
    "# Release options are listed in sae_lens/pretrained_saes.yaml; for 9b use\n",
    "# 'gemma-scope-9b-pt-res-canonical'.\n",
    "sae, cfg_dict, _ = SAE.from_pretrained(\n",
    "    release='gemma-scope-2b-pt-res-canonical',\n",
    "    sae_id=f\"layer_{layer - 1}/width_16k/canonical\",\n",
    "    device='cpu',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb565ca1-d7fb-42ee-9fe7-0f196a7403a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep only the decoder matrix (GPU, float16) and free the rest of the SAE\n",
    "decoder_mat = sae.W_dec.to(device='cuda', dtype=torch.float16)\n",
    "del sae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cda8283b-a555-48f3-a055-a1e331dfbdbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every SAE decoder row, the mean projection of the residual stream returned by\n",
    "# stop_at_layer=layer over all non-BOS tokens, separately for the English and the\n",
    "# other-language side of each pair (one value per decoder row).\n",
    "english_vs_other_projections = {\n",
    "    split: {'english': [], 'other': []} for split in splits\n",
    "}\n",
    "\n",
    "for split in datasets.keys():\n",
    "    other_lang = next((lang for lang in reversed(split.split('-')) if lang != 'en'), None)\n",
    "    # Running float32 sums and token counts instead of keeping every sentence's\n",
    "    # [tokens, features] projection matrix on the CPU until the end of the split.\n",
    "    proj_sums = {'english': None, 'other': None}\n",
    "    n_tokens = {'english': 0, 'other': 0}\n",
    "    for data in tqdm(datasets[split]):\n",
    "        for side, prompt in (('english', data['en']), ('other', data[other_lang])):\n",
    "            # [0, 1:] drops the BOS position\n",
    "            activation = model(prompt, prepend_bos=True, stop_at_layer=layer)[0, 1:]\n",
    "            summed = (activation @ decoder_mat.T).float().sum(dim=0).cpu()\n",
    "            proj_sums[side] = summed if proj_sums[side] is None else proj_sums[side] + summed\n",
    "            n_tokens[side] += activation.shape[0]\n",
    "    for side in ('english', 'other'):\n",
    "        mean_proj = proj_sums[side] / n_tokens[side]\n",
    "        english_vs_other_projections[split][side] = mean_proj.to(decoder_mat.dtype).flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11688541-d505-4c02-a39d-cebf932b84af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# the 5 SAE directions with the largest (French - English) mean projection\n",
    "fr_minus_en = english_vs_other_projections['en-fr']['other'] - english_vs_other_projections['en-fr']['english']\n",
    "torch.topk(fr_minus_en, k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e27a188c-7261-4c31-aae3-29ca4e4c2a75",
   "metadata": {},
   "outputs": [],
   "source": [
    "from stitching.stitching_utils import open_experiment\n",
    "# The third value is bound to `stitch_beta`: naming it `beta` would shadow\n",
    "# scipy.stats.beta, which the Evaluate section uses for Clopper-Pearson intervals.\n",
    "P, Pinv, stitch_beta, bias, biasinv = open_experiment(2304, 3584, 'checkpoints/stitch_training_gemma-2-2b_to_gemma-2-9b_bidirectional_mse', 'fallen-glitter-8', device='cpu', biases=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02c243e3-3e70-4cff-bd10-f23a3466b445",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-picked SAE features for es / fr / de, saved under 9b file names.\n",
    "# NOTE(review): the file names assume `decoder_mat` is the gemma-2-9b SAE decoder, while\n",
    "# the loading cells above default to gemma-2-2b; load the 9b SAE before running this.\n",
    "feature_indices = {'es': 10845, 'fr': 6153, 'de': 3826}\n",
    "feature_vectors = {}\n",
    "for lang_code, feature_idx in feature_indices.items():\n",
    "    feature_vectors[f'{lang_code}_decoder'] = decoder_mat[feature_idx].cpu()\n",
    "\n",
    "# Map each decoder direction through Pinv (9b -> 2b per the output file name),\n",
    "# then rescale it to unit norm.\n",
    "Pinv_half = Pinv.to(torch.float16)\n",
    "transferred_feature_vectors = {}\n",
    "for key, vec in feature_vectors.items():\n",
    "    if 'decoder' in key:\n",
    "        vec = vec @ Pinv_half\n",
    "        vec = vec / vec.norm()\n",
    "    transferred_feature_vectors[key] = vec\n",
    "torch.save(feature_vectors, 'data/9b_language_steering_vectors.pt')\n",
    "torch.save(transferred_feature_vectors, 'data/9b_to_2b_transferred_language_steering_vectors.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adeb982b-0070-431a-b2da-3cd5c48d7d13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-picked SAE features for es / fr / de (rows of `decoder_mat`, presumably the 2b SAE).\n",
    "feature_indices = {'es': 5750, 'fr': 12962, 'de': 9404}\n",
    "feature_vectors = {f'{lang_code}_decoder': decoder_mat[feature_idx].cpu() for lang_code, feature_idx in feature_indices.items()}\n",
    "\n",
    "# Map each decoder direction through P (2b -> 9b per the stitching checkpoint name),\n",
    "# then rescale it to unit norm.\n",
    "P_half = P.to(torch.float16)\n",
    "transferred_feature_vectors = {}\n",
    "for key, vec in feature_vectors.items():\n",
    "    if 'decoder' in key:\n",
    "        vec = vec @ P_half\n",
    "        vec = vec / vec.norm()\n",
    "    transferred_feature_vectors[key] = vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d058e1-226b-4064-b2fe-ef9e08219a73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the native and transferred direction dictionaries\n",
    "for _obj, _path in [\n",
    "    (feature_vectors, 'data/language_steering_vectors.pt'),\n",
    "    (transferred_feature_vectors, 'data/transferred_language_steering_vectors.pt'),\n",
    "]:\n",
    "    torch.save(_obj, _path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efc51e14-8212-41dc-9440-6c55cf7a4313",
   "metadata": {},
   "source": [
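    "# Skip this for now.\n",
    "What we actually want to evaluate is how well steering works, so we fit the steering coefficient for each language pair: the mean projection of other-language activations onto that pair's steering vector."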
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f16b2bb-91bf-4ba2-b6ad-830ea858dd1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# steering vectors for gemma-2-2b (presumably difference-of-means, per the file name);\n",
    "# alternative: 'data/9b_to_2b_transferred_language_steering_vectors.pt'\n",
    "steering_vectors_path = 'data/2b_language_steering_vectors_diffmean.pt'\n",
    "steering_vectors = torch.load(steering_vectors_path, weights_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1befd0cf-da85-4963-9e91-4f30a7d9baa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inspect the loaded steering-vector dict\n",
    "steering_vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca2d4cbe-44f9-4d22-b539-93fc1d52dea6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cast decoder directions to bfloat16; any other entries pass through unchanged\n",
    "steering_vectors = {\n",
    "    name: (vec.to(torch.bfloat16) if 'decoder' in name else vec)\n",
    "    for name, vec in steering_vectors.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddfb5248-93fc-4758-9edf-d87eb4cd6125",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gemma-2-2b in bfloat16; `layer` is passed as stop_at_layer below.\n",
    "# (for gemma-2-9b: 'gemma-2-9b' in float16 with layer = 33)\n",
    "model = HookedTransformer.from_pretrained(\n",
    "    'gemma-2-2b',\n",
    "    cache_dir=CACHE_DIR,\n",
    "    device='cuda',\n",
    "    torch_dtype=torch.bfloat16,\n",
    ")\n",
    "layer = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3704a409-2ee3-4125-a01d-eb06e942b600",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean projection of the residual stream (non-BOS tokens, stop_at_layer=layer) onto\n",
    "# each pair's steering vector, separately for the English and the other-language side.\n",
    "english_vs_other_projections = {\n",
    "    split: {'english': [], 'other': []} for split in splits\n",
    "}\n",
    "\n",
    "for split in datasets.keys():\n",
    "    other_lang = next((lang for lang in reversed(split.split('-')) if lang != 'en'), None)\n",
    "    steering_vector = steering_vectors[f\"{other_lang}_decoder\"]\n",
    "    # move the vector to the GPU once per split instead of twice per sentence\n",
    "    steering_vector_gpu = steering_vector.to('cuda')\n",
    "    for data in tqdm(datasets[split]):\n",
    "        english_prompt = data['en']\n",
    "        other_prompt = data[other_lang]\n",
    "        # [0, 1:] drops the BOS position\n",
    "        english_activation = model(english_prompt, prepend_bos=True, stop_at_layer=layer)[0, 1:]\n",
    "        other_activation = model(other_prompt, prepend_bos=True, stop_at_layer=layer)[0, 1:]\n",
    "\n",
    "        english_vs_other_projections[split]['english'].append((english_activation @ steering_vector_gpu).cpu())\n",
    "        english_vs_other_projections[split]['other'].append((other_activation @ steering_vector_gpu).cpu())\n",
    "    # mean over every token of every sentence in the split\n",
    "    english_vs_other_projections[split]['english'] = torch.cat(english_vs_other_projections[split]['english']).mean()\n",
    "    english_vs_other_projections[split]['other'] = torch.cat(english_vs_other_projections[split]['other']).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "677b9d64-6b0d-478e-8256-70c3adacba6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-split mean projections (English vs. other language)\n",
    "english_vs_other_projections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ea977bb-a530-4514-bb99-dfd8fa4ccb48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# now format: split -> (coefficient, steering vector), where the coefficient is the\n",
    "# mean other-language projection onto that vector, scaled by `multiplier`\n",
    "multiplier = 1\n",
    "final_steering_vectors = {}\n",
    "for split in splits:\n",
    "    other_lang = next((lang for lang in reversed(split.split('-')) if lang != 'en'), None)\n",
    "    coefficient = english_vs_other_projections[split]['other'] * multiplier\n",
    "    final_steering_vectors[split] = (coefficient, steering_vectors[f\"{other_lang}_decoder\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "627fa223-75ed-45e9-92ff-73302216b2d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# split -> (coefficient, steering vector)\n",
    "coefficient_vectors_path = \"data/2b_language_steering_vectors_diffmean_with_coefficient.pt\"\n",
    "# transferred variant: \"data/9b_to_2b_transferred_language_steering_vectors_with_coefficient.pt\"\n",
    "torch.save(final_steering_vectors, coefficient_vectors_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81a0f97a-6e89-4a7c-920b-df615758ba93",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Generate responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a0119e1-132c-4049-a59d-3bb5517318c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from stitching.sae_utils import gemma_generate_with_hooks\n",
    "import functools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "892c4b64-4d03-4d8d-9ada-2e98f99d49e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generation settings. Other configurations used:\n",
    "#   model_name = 'gemma-2-9b', layer = 33,\n",
    "#   steering_vectors_file = \"data/transferred_language_steering_vectors_with_coefficient.pt\"\n",
    "#   steering_vectors_file = \"data/9b_to_2b_transferred_language_steering_vectors_with_coefficient.pt\"\n",
    "model_name, layer = 'gemma-2-2b', 20\n",
    "steering_vectors_file = \"data/2b_language_steering_vectors_diffmean_with_coefficient.pt\"\n",
    "# language pair to steer towards; looked up as a key of the loaded dict (e.g. 'en-es')\n",
    "instruction = 'en-fr'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb1b1b26-35b5-40b0-9857-d0f512fcf3e8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cf71a0f-30c1-4cba-a32c-c671aa6a96e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.inference_mode()\n",
    "def steering_hook(x, hook, steer, omega):\n",
    "    \"\"\"Forward hook that replaces x's component along `steer` with omega * steer.\n",
    "\n",
    "    Subtracts (x @ steer) * steer and then adds omega * steer, modifying `x` in place\n",
    "    and returning it. The subtraction removes the projection exactly only when `steer`\n",
    "    has unit norm (NOTE(review): not checked here). `hook` is unused.\n",
    "    \"\"\"\n",
    "    coeff = x.matmul(steer.reshape(-1, 1))\n",
    "    x.sub_(coeff * steer)\n",
    "    x.add_(omega * steer)\n",
    "    return x\n",
    "\n",
    "@torch.inference_mode()\n",
    "def activation_addition(x, hook, steer, omega):\n",
    "    \"\"\"Forward hook that adds omega * steer to `x` in place and returns it (`hook` is unused).\"\"\"\n",
    "    x.add_(omega * steer)\n",
    "    return x\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_generations(dataset, model_name, layer, steering_vectors, instruction, max_new_tokens=128, multiplier=1, verbose=False, device='cuda'):\n",
    "    \"\"\"Generate one response per prompt, optionally steering the residual stream.\n",
    "\n",
    "    Args:\n",
    "        dataset: DataFrame with a 'prompt_without_instruction' column (not a list of strings).\n",
    "        model_name: TransformerLens model name; loaded here in float16 on `device`.\n",
    "        layer: block whose hook_resid_pre is steered.\n",
    "        steering_vectors: None, or a dict mapping a language pair to (zbar, steering_vector).\n",
    "        instruction: key into steering_vectors; if absent (or steering_vectors is None) the\n",
    "            model generates without steering.\n",
    "        max_new_tokens: passed to gemma_generate_with_hooks as max_tokens_generated.\n",
    "        multiplier: scale applied to zbar before it is used as the steering coefficient.\n",
    "        verbose: print each prompt and its response.\n",
    "\n",
    "    Returns:\n",
    "        list of {\"prompt\": ..., \"response\": ...} dicts in dataset order (responses stripped).\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    torch_dtype = torch.float16\n",
    "    print(\"Loading\", model_name)\n",
    "    model = HookedTransformer.from_pretrained(model_name=model_name, device=device, cache_dir=CACHE_DIR, torch_dtype=torch_dtype)\n",
    "    # NOTE(review): substring test - any model name containing 'it' is treated as instruction-tuned\n",
    "    instruction_tuned = 'it' in model_name\n",
    "\n",
    "    # The hook depends only on the arguments, not on the row, so build it (and move the\n",
    "    # steering vector to device/dtype) once instead of once per prompt.\n",
    "    fwd_hooks = None\n",
    "    if steering_vectors is not None and instruction in steering_vectors.keys():\n",
    "        zbar, steering_vector = steering_vectors[instruction]\n",
    "        steering_func = functools.partial(steering_hook, steer=steering_vector.to(device).to(torch_dtype), omega=zbar*multiplier)\n",
    "        #steering_func = functools.partial(activation_addition, steer=steering_vector.to(device).to(torch_dtype), omega=zbar*multiplier)\n",
    "        fwd_hooks = [(f'blocks.{layer}.hook_resid_pre', steering_func)]\n",
    "\n",
    "    for _, row in tqdm(dataset.iterrows(), total=len(dataset)):\n",
    "        prompt = row['prompt_without_instruction']\n",
    "        if instruction_tuned:\n",
    "            messages_instruction = [\n",
    "                {\"role\": \"user\", \"content\": prompt}\n",
    "            ]\n",
    "            formatted_prompt = model.tokenizer.apply_chat_template(messages_instruction, tokenize=False, add_generation_prompt=True)\n",
    "        else:\n",
    "            formatted_prompt = f\"Q: {prompt}\\nA:\"\n",
    "        # no extra BOS for chat-templated prompts (presumably the template already adds one)\n",
    "        tokens = model.to_tokens(formatted_prompt, prepend_bos=not(instruction_tuned))\n",
    "\n",
    "        if fwd_hooks is None:\n",
    "            generation = gemma_generate_with_hooks(model, tokens, max_tokens_generated=max_new_tokens)[0]\n",
    "        else:\n",
    "            generation = gemma_generate_with_hooks(model, tokens, max_tokens_generated=max_new_tokens, fwd_hooks=fwd_hooks)[0]\n",
    "        response = generation.strip()\n",
    "        if verbose:\n",
    "            print(prompt, '\\n', response)\n",
    "        results.append({\"prompt\": prompt, \"response\": response})\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e63bfbd-99ba-4db4-92ac-eaba90cb10b5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# prompts from the IFEval single-instruction file, generated with the configured steering\n",
    "data_df = pd.read_json(path_or_buf='instruction_following_eval/ifeval_single_instr_format.jsonl', lines=True)\n",
    "steering_vectors = None if steering_vectors_file is None else torch.load(steering_vectors_file, weights_only=True)\n",
    "results_dict = get_generations(\n",
    "    data_df, model_name, layer, steering_vectors, instruction,\n",
    "    multiplier=1, verbose=True, device='cuda',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8b6ece7-29c4-49c9-acd6-24bf944c1f63",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0aa909a0-eef2-4d32-a677-a801cadb801d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name the output after the steering configuration so native-steering runs are not\n",
    "# stored as transferred ones; these are the three file names get_numbers reads.\n",
    "if steering_vectors_file is None:\n",
    "    responses_file = \"no_steering_responses.json\"\n",
    "elif 'transferred' in steering_vectors_file:\n",
    "    responses_file = \"transferred_steered_responses.json\"\n",
    "else:\n",
    "    responses_file = \"steered_responses.json\"\n",
    "output_dir = f\"language_steering/{instruction}/{model_name}/\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "pd.DataFrame(results_dict).to_json(os.path.join(output_dir, responses_file))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f0b8abc-bee7-4b21-b14e-65d390e97e12",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Evaluate responses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "856e7cc0-93d3-4f39-82db-e886c2df9738",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langdetect import detect\n",
    "from collections import Counter\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from stitching.sae_utils import gemma_generate_with_hooks\n",
    "import functools\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458992be-aa12-4587-a5ad-a7d52e314611",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Language pairs shown in the summary bar chart below.\n",
    "# NOTE(review): this rebinds the 20-pair `splits` defined at the top of the notebook.\n",
    "splits = [\n",
    "    'en-es',\n",
    "    'de-en',\n",
    "    'en-fr',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fe212e8-05f7-4990-ac2e-9d06719710d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_responses(dataframe, seed=0):\n",
    "    \"\"\"Detect the language of every entry in dataframe['response'].\n",
    "\n",
    "    Returns a list of langdetect language codes in the same order, with 'unk' for\n",
    "    responses on which langdetect raises (e.g. text with no detectable features).\n",
    "    `seed` fixes langdetect's randomized detector so repeated evaluations agree;\n",
    "    pass None for its default non-deterministic behaviour. This sets the\n",
    "    process-wide DetectorFactory.seed.\n",
    "    \"\"\"\n",
    "    from langdetect import DetectorFactory\n",
    "    DetectorFactory.seed = seed\n",
    "    detected = []\n",
    "    for response in dataframe['response']:\n",
    "        try:\n",
    "            detected.append(detect(response))\n",
    "        except Exception:\n",
    "            # Exception, not a bare except, so KeyboardInterrupt/SystemExit still stop the loop\n",
    "            detected.append('unk')\n",
    "    return detected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dab31b4-6971-456f-b077-28d44468520a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import beta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1f4ae5-c828-4e50-8b25-662de06fa0e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clopper_pearson_exact(x, n=163, alpha=0.05):\n",
    "    \"\"\"Exact (Clopper-Pearson) two-sided 1-alpha confidence interval for x successes in n trials.\n",
    "\n",
    "    Returns (lo, hi); lo is 0.0 when x == 0 and hi is 1.0 when x == n. The default\n",
    "    n=163 is kept for backward compatibility; get_numbers passes n explicitly.\n",
    "    \"\"\"\n",
    "    # Local import: other cells (e.g. the open_experiment unpacking) may rebind the global\n",
    "    # name `beta`, which would otherwise silently replace scipy.stats.beta here.\n",
    "    from scipy.stats import beta as beta_dist\n",
    "    lo = beta_dist.ppf(alpha/2,     x,   n-x+1) if x > 0   else 0.0\n",
    "    hi = beta_dist.ppf(1-alpha/2, x+1,   n-x)   if x < n   else 1.0\n",
    "    return lo, hi\n",
    "\n",
    "def _target_language_count(path, target_lang):\n",
    "    \"\"\"Return (#responses detected as target_lang, #responses) for a saved responses JSON file.\"\"\"\n",
    "    responses = pd.read_json(path)\n",
    "    return Counter(eval_responses(responses))[target_lang], len(responses)\n",
    "\n",
    "def get_numbers(split, model):\n",
    "    \"\"\"Target-language rates and 95% Clopper-Pearson intervals for one language pair.\n",
    "\n",
    "    Reads the unsteered, steered and transferred-steered responses for `model`, detects\n",
    "    each response's language and counts those in the pair's non-English language.\n",
    "\n",
    "    Returns:\n",
    "        (base_acc, steered_acc, transfer_steer_acc, other_lang,\n",
    "         base_err, steered_err, transfer_steer_err), where each *_err is a (lo, hi) tuple.\n",
    "    \"\"\"\n",
    "    # non-English half of the pair, e.g. 'de' for 'de-en'\n",
    "    other_lang = next((lang for lang in reversed(split.split('-')) if lang != 'en'), None)\n",
    "    main_dir = f\"language_steering/{split}/{model}\"\n",
    "    # the unsteered baseline is always read from the en-es directory\n",
    "    # (presumably because unsteered generations do not depend on the pair)\n",
    "    base_dir = f\"language_steering/en-es/{model}\"\n",
    "    # each file is read once, and each rate uses that file's own number of responses\n",
    "    base_count, base_total = _target_language_count(os.path.join(base_dir, \"no_steering_responses.json\"), other_lang)\n",
    "    steered_count, steered_total = _target_language_count(os.path.join(main_dir, \"steered_responses.json\"), other_lang)\n",
    "    transfer_steered_count, transfer_total = _target_language_count(os.path.join(main_dir, \"transferred_steered_responses.json\"), other_lang)\n",
    "    base_acc = base_count / base_total\n",
    "    steered_acc = steered_count / steered_total\n",
    "    transfer_steer_acc = transfer_steered_count / transfer_total\n",
    "    base_err = clopper_pearson_exact(base_count, base_total)\n",
    "    steered_err = clopper_pearson_exact(steered_count, steered_total)\n",
    "    transfer_steer_err = clopper_pearson_exact(transfer_steered_count, transfer_total)\n",
    "    return base_acc, steered_acc, transfer_steer_acc, other_lang, base_err, steered_err, transfer_steer_err"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0c223df-0d93-4fab-9fb2-e1f4679c1c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eecb3960-3f3d-49b1-8660-6252b53c8f52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grouped bar chart per model: one group of three bars (no steering / steering /\n",
    "# transferred steering) per language pair, with Clopper-Pearson error bars.\n",
    "matplotlib.rcParams.update({'font.size': 14})\n",
    "\n",
    "# the first three colours reproduce the original blue / red / green groups;\n",
    "# the rest are only used when `splits` has more than three pairs\n",
    "group_colors = ['tab:blue', 'tab:red', 'tab:green', 'tab:orange', 'tab:purple',\n",
    "                'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']\n",
    "fig, ax = plt.subplots(1, 2, figsize=(15, 5))\n",
    "for i, model_name in enumerate(['gemma-2-2b', 'gemma-2-9b']):\n",
    "    xs = []\n",
    "    heights = []\n",
    "    yerrs = []\n",
    "    x_positions = []\n",
    "    colors = []\n",
    "    legend_labels = []\n",
    "    for j, split in enumerate(splits):\n",
    "        base_acc, steered_acc, transfer_steer_acc, other_lang, base_err, steered_err, transfer_steer_err = get_numbers(split, model_name)\n",
    "        xs += ['no-st', 'st', 'tr']  # no steering, steering, transferred steering\n",
    "        heights += [base_acc, steered_acc, transfer_steer_acc]\n",
    "        yerrs += [\n",
    "            (acc - lo, hi - acc)\n",
    "            for acc, (lo, hi) in ((base_acc, base_err), (steered_acc, steered_err), (transfer_steer_acc, transfer_steer_err))\n",
    "        ]\n",
    "        # bars at 4j, 4j+1, 4j+2, leaving an empty slot between consecutive groups\n",
    "        x_positions += [4 * j, 4 * j + 1, 4 * j + 2]\n",
    "        colors += [group_colors[j % len(group_colors)]] * 3\n",
    "        legend_labels.append(f'en-{other_lang}')  # e.g. 'de-en' is shown as 'en-de'\n",
    "\n",
    "    # matplotlib expects asymmetric error bars as [lower_errors, upper_errors]\n",
    "    yerrs = [[yerr[0] for yerr in yerrs], [yerr[1] for yerr in yerrs]]\n",
    "    bars = ax[i].bar(x_positions, heights, color=colors, yerr=yerrs)\n",
    "    ax[i].set_xticks(x_positions, xs)\n",
    "    ax[i].set_title(model_name)\n",
    "    # heights are fractions in [0, 1] (count / total), not percentages\n",
    "    ax[i].set_ylabel('fraction of responses in target language')\n",
    "    ax[i].legend([bars[3 * j] for j in range(len(splits))], legend_labels)\n",
    "plt.savefig(\"results/figures/language_steering_results_w_fr.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43206d2b-5e7b-47eb-9e7f-085247f11070",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
