{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "837bc8c4",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ea9e476",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib\n",
    "import os\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as mticker\n",
    "import matplotlib.colors as mcolors\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "# Theming\n",
    "sns.set_theme(style=\"whitegrid\")\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Computer Modern', 'DejaVu Serif', 'serif'],\n",
    "    'mathtext.fontset': 'cm',\n",
    "    'axes.formatter.use_mathtext': True,\n",
    "})\n",
    "\n",
    "plt.rcParams.update({\n",
    "    # Font Sizes (ICML template uses 10pt)\n",
    "    \"font.size\": 8,\n",
    "    \"axes.titlesize\": 8,\n",
    "    \"axes.labelsize\": 8,\n",
    "    \"legend.fontsize\": 8,\n",
    "    \"xtick.labelsize\": 6,\n",
    "    \"ytick.labelsize\": 6,\n",
    "\n",
    "    \"axes.linewidth\": 0.5,   # Plot Border\n",
    "    \"patch.linewidth\": 0.5,  # Bar Border\n",
    "    \"grid.linewidth\": 0.5,\n",
    "    \"xtick.major.pad\": 0,\n",
    "    \"ytick.major.pad\": 0,\n",
    "    \"xtick.minor.pad\": 0,\n",
    "    \"ytick.minor.pad\": 0,\n",
    "    \"hatch.linewidth\": 0.15,\n",
    "\n",
    "    'legend.borderpad': 0.2,      # Reduce border inside the legend box\n",
    "    'legend.labelspacing': 0.1,   # Reduce vertical spacing between legend entries\n",
    "\n",
    "    \"lines.linewidth\": 1,\n",
    "    \"lines.markersize\": 4,\n",
    "    \"lines.markeredgewidth\": 0.25,\n",
    "    \"lines.markeredgecolor\": \"white\",\n",
    "\n",
    "    \"figure.dpi\": 1500 # 1500, for final plots\n",
    "})\n",
    "\n",
    "# Custom formatter for delta values as percentages with +/- signs\n",
    "def delta_percent_formatter(x, pos):\n",
    "    \"\"\"Format a fractional delta as a signed whole-number percentage.\n",
    "\n",
    "    Args:\n",
    "        x: Value as a fraction (0.05 is rendered as +5%).\n",
    "        pos: Tick position passed by matplotlib's FuncFormatter; unused.\n",
    "\n",
    "    Returns:\n",
    "        The value times 100, formatted with no decimals and a % suffix.\n",
    "        Positive values get a leading +, and anything that rounds to zero\n",
    "        is returned as plain 0% (never +0% or -0%).\n",
    "    \"\"\"\n",
    "    rounded = f\"{x * 100:.0f}\"\n",
    "    # Branch on the rounded text, not the raw float, so that small non-zero\n",
    "    # values such as 0.004 or -0.004 do not come out as +0% / -0%.\n",
    "    if rounded in (\"0\", \"-0\"):\n",
    "        return \"0%\"\n",
    "    if x > 0:\n",
    "        return f\"+{rounded}%\"\n",
    "    return f\"{rounded}%\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1e87694",
   "metadata": {},
   "outputs": [],
   "source": [
    "DOUBLE_COLUMN_WIDTH = 6.75133\n",
    "SINGLE_COLUMN_WIDTH = 3.25063\n",
    "\n",
    "GRID_ALPHA = 0.4\n",
    "ACQUISITION_ORDER = ['Random', 'UltraFeedback', 'MaxMin', 'DeltaQwen', 'DeltaUCB', 'DRTS', 'InfoMax', 'DTS', 'MaxMinLCB']\n",
    "DATASET_ORDER = [\"UltraFeedback\", \"Skywork\", \"Combined\", \"Tulu 3\"]\n",
    "\n",
    "# Benchmark groups; BENCHMARKS is the downstream tasks followed by the reward-model benchmark.\n",
    "DOWNSTREAM_BENCHMARKS = ['gsm8k', 'ifeval', 'truthfulqa', 'alpacaeval_2']\n",
    "RM_BENCHMARKS = [\"rewardbench_2\"]\n",
    "BENCHMARKS = [*DOWNSTREAM_BENCHMARKS, *RM_BENCHMARKS]\n",
    "\n",
    "@dataclass\n",
    "class AcquisitionStyle:\n",
    "    \"\"\"Visual encoding of one acquisition function, shared across the plots.\"\"\"\n",
    "\n",
    "    marker: str  # marker symbol for line plots\n",
    "    hatch: str  # hatch pattern applied to bars\n",
    "    color: str  # hex colour used to build palettes\n",
    "    zorder: float  # draw order; the style table uses fractional values such as 2.1\n",
    "    dashes: Tuple[int, ...] | None  # dash pattern for lines; None is replaced by an empty string before being handed to seaborn\n",
    "\n",
    "HATCH_MULTIPLIER = 12  # each hatch string is its base character repeated this many times\n",
    "\n",
    "# One row per acquisition function: (name, marker, hatch character, colour, dashes, zorder).\n",
    "# Rows are grouped by colour family: reds, blues, greens, then the grey 'Original' reference.\n",
    "_ACQUISITION_STYLE_ROWS = [\n",
    "    ('Random', 'o', '', '#a63f3f', None, 2.1),\n",
    "    ('UltraFeedback', 's', '/', '#cb4d4d', (4, 2), 2.4),\n",
    "    ('MaxMin', '^', '\\\\', '#e06c6c', (1, 1), 2.7),\n",
    "    ('DeltaQwen', 'D', 'x', '#ef8f8f', (4, 2, 1, 2), 2.9),\n",
    "\n",
    "    ('DeltaUCB', 'o', '', '#3f3fa6', None, 2.3),\n",
    "    ('DRTS', 's', '/', '#4d4dcb', (4, 2), 2.6),\n",
    "\n",
    "    ('InfoMax', 'o', '', '#3fa63f', None, 2.2),\n",
    "    ('DTS', 's', '/', '#4dcb4d', (4, 2), 2.5),\n",
    "    ('MaxMinLCB', '^', '\\\\', '#6ce06c', (1, 1), 2.8),\n",
    "\n",
    "    ('Original', 'o', '', '#808080', (4, 2), 0),\n",
    "]\n",
    "\n",
    "# Insertion order follows the table, so derived palettes keep the same order.\n",
    "ACQUISITION_STYLES = {\n",
    "    name: AcquisitionStyle(\n",
    "        marker=marker,\n",
    "        hatch=hatch_char * HATCH_MULTIPLIER,\n",
    "        color=color,\n",
    "        dashes=dashes,\n",
    "        zorder=zorder,\n",
    "    )\n",
    "    for name, marker, hatch_char, color, dashes, zorder in _ACQUISITION_STYLE_ROWS\n",
    "}\n",
    "\n",
    "GREEN = \"green\"\n",
    "RED = \"red\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb183d2f",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d31fed4",
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists('full_results.csv'):\n",
    "    print(\"Loaded full results\")\n",
    "    data = pd.read_csv('full_results.csv', sep=',')\n",
    "else:\n",
    "    acquisition_function_mapping = {\n",
    "        \"random\": \"Random\",\n",
    "        \"ultrafeedback\": \"UltraFeedback\",\n",
    "        \"maxmin\": \"MaxMin\",\n",
    "        \"delta_qwen\": \"DeltaQwen\",\n",
    "        \"DeltaUCB\": \"DeltaUCB\",\n",
    "        \"DRTS\": \"DRTS\",\n",
    "        \"InfoMax\": \"InfoMax\",\n",
    "        \"DTS\": \"DTS\",\n",
    "        \"MaxMinLCB\": \"MaxMinLCB\",\n",
    "    }\n",
    "\n",
    "    base_model_scores = {\n",
    "        \"gsm8k\": 0.758,\n",
    "        \"ifeval\": 0.713,\n",
    "        \"truthfulqa\": 0.468,\n",
    "        \"alpacaeval_2\": 0.083,\n",
    "        \"rewardbench_2\": 0.290\n",
    "    }\n",
    "\n",
    "    data = pd.read_csv('results.csv', sep=',')\n",
    "\n",
    "    uf_dpo_sample_efficiency = pd.read_csv(\"ultrafeedback_dpo_sample_efficiency.csv\")\n",
    "    uf_rm_sample_efficiency = pd.read_csv(\"ultrafeedback_rm_sample_efficiency.csv\")\n",
    "\n",
    "    uf_sample_efficiency = pd.merge(\n",
    "        uf_dpo_sample_efficiency,\n",
    "        uf_rm_sample_efficiency,\n",
    "        on='Method',\n",
    "        suffixes=('_dpo', '_rm')\n",
    "    )\n",
    "\n",
    "    uf_sample_efficiency = uf_sample_efficiency[uf_sample_efficiency[\"Method\"] != \"SFT Base Model\"].copy().reset_index(drop=True)\n",
    "\n",
    "    uf_sample_efficiency = uf_sample_efficiency.rename(columns={\n",
    "        'Mean_rm': 'rewardbench_2',\n",
    "        'GSM8K': 'gsm8k',\n",
    "        'IF Eval': 'ifeval',\n",
    "        'Truthful QA': 'truthfulqa',\n",
    "        'Alpaca Eval': 'alpacaeval_2',\n",
    "    })\n",
    "\n",
    "    uf_sample_efficiency['num_train_samples'] = uf_sample_efficiency['Method'].apply(lambda x: int(x.split('_')[-1]))\n",
    "    uf_sample_efficiency['acquisition_function'] = uf_sample_efficiency['Method'].apply(lambda x: acquisition_function_mapping[\"_\".join(x.split('_')[:-1]).split('-')[-1]])\n",
    "    uf_sample_efficiency['po_algorithm'] = \"DPO\"\n",
    "    uf_sample_efficiency['judge'] = \"Qwen 3 235B\"\n",
    "    uf_sample_efficiency['dataset'] = \"UltraFeedback\"\n",
    "\n",
    "    # Add base model scores at num_train_samples = 0 for sample efficiency plots\n",
    "    for acq_name in acquisition_function_mapping.values():\n",
    "        uf_sample_efficiency.loc[len(uf_sample_efficiency)] = {\n",
    "            'dataset': 'UltraFeedback',\n",
    "            'judge': 'Qwen 3 235B',\n",
    "            'acquisition_function': acq_name,\n",
    "            'po_algorithm': 'DPO',\n",
    "            'num_train_samples': 0,\n",
    "            'gsm8k': 0,\n",
    "            'ifeval': 0,\n",
    "            'truthfulqa': 0,\n",
    "            'alpacaeval_2': 0,\n",
    "            'rewardbench_2': 0\n",
    "        }\n",
    "\n",
    "    uf_sample_efficiency = uf_sample_efficiency.drop(columns=['Type_dpo', 'Mean_dpo', 'Type_rm', 'Factuality', 'Focus', 'Math', 'Precise IF', 'Safety', 'Ties', 'Method'])\n",
    "    uf_sample_efficiency = uf_sample_efficiency[data.columns]\n",
    "    acq_order = list(acquisition_function_mapping.values())\n",
    "    uf_sample_efficiency['acq_func_order'] = uf_sample_efficiency['acquisition_function'].apply(lambda x: acq_order.index(x) if x in acq_order else -1)\n",
    "    uf_sample_efficiency = uf_sample_efficiency.sort_values(by=['acq_func_order', 'num_train_samples']).drop(columns=['acq_func_order']).reset_index(drop=True)\n",
    "    uf_sample_efficiency.to_csv(\"ultrafeedback_sample_efficiency.csv\", index=False)\n",
    "\n",
    "\n",
    "    # IPO / SimPO sample-efficiency runs. Each `Method` value is parsed below as\n",
    "    # <po_algorithm>_<acquisition key>_<num_train_samples>.\n",
    "    ipo_simpo_sample_efficiency = pd.read_csv(\"final_results4.csv\")\n",
    "    ipo_simpo_sample_efficiency = ipo_simpo_sample_efficiency.rename(columns={\n",
    "        'GSM8K': 'gsm8k',\n",
    "        'IF Eval': 'ifeval',\n",
    "        'Truthful QA': 'truthfulqa',\n",
    "        'Alpaca Eval': 'alpacaeval_2',\n",
    "    })\n",
    "    # This column is left empty for these rows. Use NaN rather than '' so it stays\n",
    "    # numeric after the concat below: an empty string turns the column into object\n",
    "    # dtype in memory, whereas reloading full_results.csv yields NaN, so the first\n",
    "    # run and cached runs would otherwise see different data (and the row-wise mean\n",
    "    # over this column would operate on strings).\n",
    "    ipo_simpo_sample_efficiency['rewardbench_2'] = np.nan\n",
    "    ipo_simpo_sample_efficiency['dataset'] = 'UltraFeedback'\n",
    "    ipo_simpo_sample_efficiency['judge'] = 'Qwen 3 235B'\n",
    "    # Last '_' token: number of training samples.\n",
    "    ipo_simpo_sample_efficiency['num_train_samples'] = ipo_simpo_sample_efficiency['Method'].apply(lambda x: int(x.split('_')[-1]))\n",
    "    # First '_' token: preference-optimisation algorithm.\n",
    "    ipo_simpo_sample_efficiency['po_algorithm'] = ipo_simpo_sample_efficiency['Method'].apply(lambda x: x.split('_')[0])\n",
    "    # Everything in between is the acquisition key; an unknown key raises KeyError.\n",
    "    ipo_simpo_sample_efficiency['acquisition_function'] = ipo_simpo_sample_efficiency['Method'].apply(lambda x: acquisition_function_mapping['_'.join(x.split('_')[1:-1])])\n",
    "    ipo_simpo_sample_efficiency = ipo_simpo_sample_efficiency.drop(columns=['Type', 'Mean', 'Method'])\n",
    "\n",
    "    data = pd.concat([data, ipo_simpo_sample_efficiency], ignore_index=True)\n",
    "    data = data.drop_duplicates().reset_index(drop=True)\n",
    "\n",
    "    data = data.assign(\n",
    "        num_train_samples_null=data['num_train_samples'].isna(),\n",
    "        dataset_order_idx=data['dataset'].apply(lambda x: DATASET_ORDER.index(x) if x in DATASET_ORDER else len(DATASET_ORDER)),\n",
    "        acquisition_order_idx=data['acquisition_function'].apply(\n",
    "            lambda x: ACQUISITION_ORDER.index(x) if x in ACQUISITION_ORDER else len(ACQUISITION_ORDER))\n",
    "    ).sort_values(\n",
    "        by=['num_train_samples_null', 'dataset_order_idx', 'po_algorithm', 'acquisition_order_idx', 'num_train_samples'],\n",
    "        ascending=[False, True, True, True, True]\n",
    "    ).drop(columns=['num_train_samples_null', 'dataset_order_idx', 'acquisition_order_idx']).reset_index(drop=True)\n",
    "\n",
    "    data.to_csv(\"full_results.csv\", index=False)\n",
    "\n",
    "# Per-dataset statistics: reuse the cached CSV when it exists, otherwise derive\n",
    "# the table from the raw analysis export and cache it for later runs.\n",
    "if os.path.exists('dataset_statistics.csv'):\n",
    "    datasets_data = pd.read_csv('dataset_statistics.csv', sep=',')\n",
    "    # Report the file that was actually read, and only after the read succeeded.\n",
    "    print(\"Loaded dataset statistics\")\n",
    "else:\n",
    "    def _dataset_name_token(name, position):\n",
    "        \"\"\"Return the '_'-separated token at `position`, or None when `name`\n",
    "        is not a string or contains no underscore.\"\"\"\n",
    "        if isinstance(name, str) and '_' in name:\n",
    "            return name.split('_')[position]\n",
    "        return None\n",
    "\n",
    "    # NOTE(review): position 1 keeps only the token after the first '_'; an\n",
    "    # acquisition key that itself contains '_' (e.g. delta_qwen) would be\n",
    "    # truncated - confirm the Dataset_Name format.\n",
    "    datasets_data = (\n",
    "        pd.read_csv('my_analysis.csv', sep=',')\n",
    "        .assign(\n",
    "            training=lambda df: df['Dataset_Name'].apply(_dataset_name_token, position=0),\n",
    "            acquisition_function=lambda df: df['Dataset_Name'].apply(_dataset_name_token, position=1),\n",
    "        )\n",
    "        .rename(columns={'Model': 'model'})\n",
    "        # Selecting the output columns also drops Dataset_Name.\n",
    "        .loc[:, ['training', 'acquisition_function', 'model', 'chosen_count', 'rejected_count']]\n",
    "    )\n",
    "\n",
    "    datasets_data.to_csv(\"dataset_statistics.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c37979a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-row aggregate scores: mean over the reward-model and the downstream benchmark columns.\n",
    "data[\"rm_mean_score\"] = data[RM_BENCHMARKS].mean(axis=1)\n",
    "data[\"downstream_mean_score\"] = data[DOWNSTREAM_BENCHMARKS].mean(axis=1)\n",
    "\n",
    "# Row masks shared by the subsets below.\n",
    "on_ultrafeedback = data['dataset'] == 'UltraFeedback'\n",
    "trained_with_dpo = data['po_algorithm'] == 'DPO'\n",
    "without_sample_count = data['num_train_samples'].isna()\n",
    "sweep_or_original = ~without_sample_count | (data['acquisition_function'] == 'Original')\n",
    "\n",
    "# PO-algorithm ablation: UltraFeedback rows without a sample count, minus the rewardbench_2 column.\n",
    "po_algo_ablation_raw_data = data[on_ultrafeedback & without_sample_count].drop(columns=['rewardbench_2'])\n",
    "\n",
    "# Dataset ablation and teaser use the same selection: DPO rows without a sample count.\n",
    "dataset_ablation_raw_data = data[trained_with_dpo & without_sample_count]\n",
    "teaser_raw_data = data[trained_with_dpo & without_sample_count]\n",
    "\n",
    "# Sample efficiency: UltraFeedback rows that have a sample count, plus the 'Original' rows.\n",
    "sample_efficiency_raw_data = data[on_ultrafeedback & sweep_or_original & trained_with_dpo].copy()\n",
    "sample_efficiency_ipo_simpo_raw_data = data[\n",
    "    on_ultrafeedback & sweep_or_original & data['po_algorithm'].isin(['IPO', 'SimPO'])\n",
    "].copy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04874ee7",
   "metadata": {},
   "source": [
    "# Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "598b17e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Dataset Ablation Plot\n",
    "# ==============================================================================\n",
    "\n",
    "dataset_ablation_data = dataset_ablation_raw_data.copy()\n",
    "\n",
    "# --- Style Setup ---\n",
    "acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_hatches = {k: v.hatch for k, v in ACQUISITION_STYLES.items()}\n",
    "\n",
    "# --- Figure Setup ---\n",
    "fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(DOUBLE_COLUMN_WIDTH, 2))\n",
    "\n",
    "# --- Plot Data ---\n",
    "# Left: downstream scores\n",
    "sns.barplot(\n",
    "    data=dataset_ablation_data,\n",
    "    x='dataset',\n",
    "    y='downstream_mean_score',\n",
    "    hue='acquisition_function',\n",
    "    palette=acquisition_colors,\n",
    "    edgecolor=\"white\",\n",
    "    order=DATASET_ORDER,\n",
    "    hue_order=ACQUISITION_ORDER + [\"Original\"],\n",
    "    ax=ax_left\n",
    ")\n",
    "# Right: reward model scores\n",
    "sns.barplot(\n",
    "    data=dataset_ablation_data,\n",
    "    x='dataset',\n",
    "    y='rm_mean_score',\n",
    "    hue='acquisition_function',\n",
    "    palette=acquisition_colors,\n",
    "    order=DATASET_ORDER,\n",
    "    hue_order=ACQUISITION_ORDER + [\"Original\"],\n",
    "    ax=ax_right,\n",
    ")\n",
    "\n",
    "# --- Apply Hatches ---\n",
    "n_hues = len(ACQUISITION_ORDER)\n",
    "n_groups = len(dataset_ablation_data['dataset'].unique())\n",
    "for ax in [ax_left, ax_right]:\n",
    "    for i, bar in enumerate(ax.patches):\n",
    "        hue_idx = i // n_groups\n",
    "        if hue_idx < n_hues:\n",
    "            acq_func = ACQUISITION_ORDER[hue_idx]\n",
    "            bar.set_hatch(acquisition_hatches[acq_func])\n",
    "\n",
    "# --- Legend ---\n",
    "ax_left.get_legend().remove()\n",
    "ax_right.get_legend().remove()\n",
    "handles, labels = ax_left.get_legend_handles_labels()\n",
    "for i, handle in enumerate(handles):\n",
    "    if i < len(ACQUISITION_ORDER):\n",
    "        acq_func = ACQUISITION_ORDER[i]\n",
    "        handle.set_hatch(acquisition_hatches[acq_func])\n",
    "fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.5, 1.1),\n",
    "    ncol=(len(acquisition_colors) - 1) // 2 + 1,\n",
    "    frameon=False,\n",
    ")\n",
    "\n",
    "# --- Axis Labels & Titles ---\n",
    "ax_left.set_xlabel('(a) Fine-tuned Models')\n",
    "ax_right.set_xlabel('(b) Reward Models')\n",
    "ax_left.set_ylabel('$\\\\Delta$Score', fontweight=\"bold\")\n",
    "ax_right.set_ylabel('$\\\\Delta$Score', fontweight=\"bold\")\n",
    "\n",
    "# --- Axis Ticks ---\n",
    "for ax in [ax_left, ax_right]:\n",
    "    for label in ax.get_xticklabels():\n",
    "        label.set_fontweight('bold')\n",
    "        label.set_fontsize(plt.rcParams[\"axes.labelsize\"])\n",
    "        label.set_rotation(20)\n",
    "        label.set_ha('right')\n",
    "        \n",
    "\n",
    "ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])\n",
    "ax_right.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])\n",
    "\n",
    "# Format y-ticks as +/- XX%\n",
    "for ax in [ax_left, ax_right]:\n",
    "    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))\n",
    "\n",
    "# --- Grid ---\n",
    "for ax in [ax_left, ax_right]:\n",
    "    ax.grid(alpha=GRID_ALPHA)\n",
    "    ax.grid(axis='x', alpha=0.0)\n",
    "\n",
    "# --- Axis Limits ---\n",
    "ax_left.set_ylim(0, 0.16)\n",
    "ax_right.set_ylim(0, 0.4)\n",
    "\n",
    "# --- Save & Show ---\n",
    "plt.tight_layout()\n",
    "fig.savefig(\"dataset_ablation.pdf\", format=\"pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc3f13fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Dataset Ablation Plot (Split Export)\n",
    "# ==============================================================================\n",
    "\n",
    "dataset_ablation_data = dataset_ablation_raw_data.copy()\n",
    "\n",
    "# --- Style Setup ---\n",
    "acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_hatches = {k: v.hatch for k, v in ACQUISITION_STYLES.items()}\n",
    "\n",
    "# Define single plot width (approx half of double column)\n",
    "SINGLE_PLOT_WIDTH = DOUBLE_COLUMN_WIDTH / 2\n",
    "HEIGHT = 1.5\n",
    "\n",
    "# ==========================================\n",
    "# 1. Left Plot (Fine-tuned Models)\n",
    "# ==========================================\n",
    "fig_left, ax_left = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))\n",
    "\n",
    "sns.barplot(\n",
    "    data=dataset_ablation_data,\n",
    "    x='dataset',\n",
    "    y='downstream_mean_score',\n",
    "    hue='acquisition_function',\n",
    "    palette=acquisition_colors,\n",
    "    edgecolor=\"white\",\n",
    "    order=DATASET_ORDER,\n",
    "    hue_order=ACQUISITION_ORDER + [\"Original\"],\n",
    "    ax=ax_left\n",
    ")\n",
    "ax_left.get_legend().remove() # We export legend separately\n",
    "ax_left.set_xlabel('')\n",
    "ax_left.set_ylabel('$\\\\Delta$Score', fontweight=\"bold\")\n",
    "ax_left.set_ylim(0, 0.16)\n",
    "ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])\n",
    "\n",
    "# ==========================================\n",
    "# 2. Right Plot (Reward Models)\n",
    "# ==========================================\n",
    "fig_right, ax_right = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))\n",
    "\n",
    "sns.barplot(\n",
    "    data=dataset_ablation_data,\n",
    "    x='dataset',\n",
    "    y='rm_mean_score',\n",
    "    hue='acquisition_function',\n",
    "    palette=acquisition_colors,\n",
    "    order=DATASET_ORDER,\n",
    "    hue_order=ACQUISITION_ORDER + [\"Original\"],\n",
    "    ax=ax_right,\n",
    ")\n",
    "ax_right.get_legend().remove()\n",
    "ax_right.set_xlabel('')\n",
    "ax_right.set_ylabel('$\\\\Delta$Score', fontweight=\"bold\")\n",
    "ax_right.set_ylim(0, 0.4)\n",
    "ax_right.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 3. Shared Formatting (Hatches & Ticks)\n",
    "# ==========================================\n",
    "n_hues = len(ACQUISITION_ORDER)\n",
    "n_groups = len(dataset_ablation_data['dataset'].unique())\n",
    "\n",
    "for ax in [ax_left, ax_right]:\n",
    "    # Apply Hatches\n",
    "    for i, bar in enumerate(ax.patches):\n",
    "        hue_idx = i // n_groups\n",
    "        if hue_idx < n_hues:\n",
    "            acq_func = ACQUISITION_ORDER[hue_idx]\n",
    "            bar.set_hatch(acquisition_hatches[acq_func])\n",
    "\n",
    "    # Format Ticks\n",
    "    for label in ax.get_xticklabels():\n",
    "        label.set_fontweight('bold')\n",
    "        label.set_fontsize(plt.rcParams[\"axes.labelsize\"] - 2)\n",
    "        # label.set_rotation(20)\n",
    "        # label.set_ha('right')\n",
    "    \n",
    "    # Format Y-Axis\n",
    "    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))\n",
    "    \n",
    "    # Grid\n",
    "    ax.grid(alpha=GRID_ALPHA)\n",
    "    ax.grid(axis='x', alpha=0.0)\n",
    "\n",
    "# ==========================================\n",
    "# 4. Legend Export\n",
    "# ==========================================\n",
    "# Extract handles from the left plot\n",
    "handles, labels = ax_left.get_legend_handles_labels()\n",
    "\n",
    "# Apply hatches to the legend handles to match the plot\n",
    "for i, handle in enumerate(handles):\n",
    "    if i < len(ACQUISITION_ORDER):\n",
    "        acq_func = ACQUISITION_ORDER[i]\n",
    "        handle.set_hatch(acquisition_hatches[acq_func])\n",
    "\n",
    "# Create a dedicated figure just for the legend\n",
    "# Width matches the full double column, height is small\n",
    "fig_leg = plt.figure(figsize=(DOUBLE_COLUMN_WIDTH, 0.5))\n",
    "\n",
    "fig_leg.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center',\n",
    "    ncol=(len(acquisition_colors) - 1) // 2 + 1,\n",
    "    frameon=False,\n",
    ")\n",
    "\n",
    "# ==========================================\n",
    "# 5. Save All Files\n",
    "# ==========================================\n",
    "# Save Left Plot\n",
    "fig_left.savefig(\n",
    "    \"dataset_ablation/left.pdf\", \n",
    "    format=\"pdf\", \n",
    "    bbox_inches=\"tight\", \n",
    "    pad_inches=0.02\n",
    ")\n",
    "\n",
    "# Save Right Plot\n",
    "fig_right.savefig(\n",
    "    \"dataset_ablation/right.pdf\", \n",
    "    format=\"pdf\", \n",
    "    bbox_inches=\"tight\", \n",
    "    pad_inches=0.02\n",
    ")\n",
    "\n",
    "# Save Legend\n",
    "fig_leg.savefig(\n",
    "    \"dataset_ablation/legend.pdf\", \n",
    "    format=\"pdf\", \n",
    "    bbox_inches=\"tight\", \n",
    "    pad_inches=0.02\n",
    ")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9987997",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# PO Algorithm Ablation Plot (with broken y-axis)\n",
    "# ==============================================================================\n",
    "\n",
    "po_algo_ablation_data = po_algo_ablation_raw_data.copy()\n",
    "\n",
    "# --- Style Setup ---\n",
    "acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_hatches = {k: v.hatch for k, v in ACQUISITION_STYLES.items()}\n",
    "\n",
    "# --- Figure Setup ---\n",
    "fig, (ax_top, ax_bottom) = plt.subplots(2, 1, sharex=True, figsize=(SINGLE_COLUMN_WIDTH, 1.5), gridspec_kw={\n",
    "    'height_ratios': [8, 1],\n",
    "    'hspace': 0.1\n",
    "})\n",
    "\n",
    "# --- Plot Data ---\n",
    "for ax in [ax_top, ax_bottom]:\n",
    "    sns.barplot(\n",
    "        data=po_algo_ablation_data,\n",
    "        x='po_algorithm',\n",
    "        y='downstream_mean_score',\n",
    "        hue='acquisition_function',\n",
    "        palette=acquisition_colors,\n",
    "        order=['DPO', 'IPO', 'SimPO'],\n",
    "        hue_order=ACQUISITION_ORDER,\n",
    "        ax=ax\n",
    "    )\n",
    "    ax.get_legend().remove()\n",
    "\n",
    "# --- Apply Hatches ---\n",
    "n_hues = len(ACQUISITION_ORDER)\n",
    "n_groups = len(po_algo_ablation_data['po_algorithm'].unique())\n",
    "for ax in [ax_top, ax_bottom]:\n",
    "    for i, bar in enumerate(ax.patches):\n",
    "        hue_idx = i // n_groups\n",
    "        if hue_idx < n_hues:\n",
    "            acq_func = ACQUISITION_ORDER[hue_idx]\n",
    "            bar.set_hatch(acquisition_hatches[acq_func])\n",
    "\n",
    "# --- Legend ---\n",
    "handles, labels = ax_top.get_legend_handles_labels()\n",
    "for i, handle in enumerate(handles):\n",
    "    if i < len(ACQUISITION_ORDER):\n",
    "        acq_func = ACQUISITION_ORDER[i]\n",
    "        handle.set_hatch(acquisition_hatches[acq_func])\n",
    "fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.45, 1.2),\n",
    "    ncol=3,\n",
    "    frameon=False,\n",
    "    borderpad=0.2,\n",
    "    columnspacing=0.7,\n",
    "    handletextpad=0.3,\n",
    "    handlelength=1.2,\n",
    "    handleheight=0.7,\n",
    ")\n",
    "\n",
    "# --- Axis Labels & Titles ---\n",
    "ax_bottom.set_xlabel('')\n",
    "ax_top.set_ylabel('$\\\\Delta$Score', y=0.4, fontweight=\"bold\")\n",
    "ax_bottom.set_ylabel('')\n",
    "\n",
    "# --- Axis Ticks ---\n",
    "ax_top.set_yticks(np.arange(0, 0.25, 0.05))\n",
    "ax_bottom.set_yticks([-0.25])\n",
    "for label in ax_bottom.get_xticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "    label.set_fontsize(plt.rcParams[\"axes.labelsize\"])\n",
    "\n",
    "# Format y-ticks as +/- XX%\n",
    "for ax in [ax_top, ax_bottom]:\n",
    "    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))\n",
    "\n",
    "# --- Grid ---\n",
    "for ax in [ax_top, ax_bottom]:\n",
    "    ax.grid(axis='y', alpha=GRID_ALPHA)\n",
    "    ax.grid(axis='x', alpha=0.0)\n",
    "\n",
    "# --- Axis Limits ---\n",
    "ax_top.set_ylim(-0.01, 0.22)\n",
    "ax_bottom.set_ylim(-0.30, -0.20)\n",
    "\n",
    "# --- Broken Axis Styling ---\n",
    "ax_top.spines['bottom'].set_visible(False)\n",
    "ax_bottom.spines['top'].set_visible(False)\n",
    "ax_top.tick_params(bottom=False)\n",
    "break_kwargs = {\n",
    "    'marker': [(-1, -0.5), (1, 0.5)],\n",
    "    'markersize': 6,\n",
    "    'linestyle': 'none',\n",
    "    'color': plt.rcParams['axes.edgecolor'],\n",
    "    'markeredgecolor': plt.rcParams['axes.edgecolor'],\n",
    "    'mew': plt.rcParams['axes.linewidth'],\n",
    "    'clip_on': False,\n",
    "}\n",
    "ax_top.plot([0, 1], [0, 0], transform=ax_top.transAxes, **break_kwargs)\n",
    "ax_bottom.plot([0, 1], [1, 1], transform=ax_bottom.transAxes, **break_kwargs)\n",
    "\n",
    "# --- Save & Show ---\n",
    "# plt.tight_layout()\n",
    "fig.savefig(\"po_algo_ablation.pdf\", format=\"pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c7da67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Sample Efficiency Plot (UltraFeedback)\n",
    "# ==============================================================================\n",
    "\n",
    "sample_efficiency_ultrafeedback_data = sample_efficiency_raw_data.copy()\n",
    "\n",
    "# --- Style Setup ---\n",
    "acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_markers = {k: v.marker for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_dashes = {k: v.dashes if v.dashes is not None else \"\" for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_zorder = {k: v.zorder for k, v in ACQUISITION_STYLES.items()}\n",
    "\n",
    "# --- Figure Setup ---\n",
    "fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(DOUBLE_COLUMN_WIDTH, 2.25), sharey=False)\n",
    "\n",
    "# --- Plot Data ---\n",
    "# Left: downstream scores\n",
    "sns.lineplot(\n",
    "    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data[\"acquisition_function\"] != \"Original\"],\n",
    "    x='num_train_samples',\n",
    "    y='downstream_mean_score',\n",
    "    hue='acquisition_function',\n",
    "    style='acquisition_function',\n",
    "    hue_order=ACQUISITION_ORDER,\n",
    "    style_order=ACQUISITION_ORDER,\n",
    "    palette=acquisition_colors,\n",
    "    markers=acquisition_markers,\n",
    "    dashes=acquisition_dashes,\n",
    "    ax=ax_left\n",
    ")\n",
    "# Right: reward model scores\n",
    "sns.lineplot(\n",
    "    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data[\"acquisition_function\"] != \"Original\"],\n",
    "    x='num_train_samples',\n",
    "    y='rm_mean_score',\n",
    "    hue='acquisition_function',\n",
    "    style='acquisition_function',\n",
    "    hue_order=ACQUISITION_ORDER,\n",
    "    style_order=ACQUISITION_ORDER,\n",
    "    palette=acquisition_colors,\n",
    "    markers=acquisition_markers,\n",
    "    dashes=acquisition_dashes,\n",
    "    ax=ax_right\n",
    ")\n",
    "\n",
    "# --- Add reference lines for Original UltraFeedback scores ---\n",
    "original_row = sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data[\"acquisition_function\"] == \"Original\"]\n",
    "\n",
    "for ax, y in [\n",
    "    (ax_left, original_row[\"downstream_mean_score\"].values[0]),\n",
    "    (ax_right, original_row[\"rm_mean_score\"].values[0]),\n",
    "]:\n",
    "    ax.axhline(\n",
    "        y=y,\n",
    "        color=ACQUISITION_STYLES[\"Original\"].color,\n",
    "        dashes=ACQUISITION_STYLES[\"Original\"].dashes,\n",
    "        label=\"Original\",\n",
    "        zorder=0,\n",
    "    )\n",
    "\n",
    "# --- Legend ---\n",
    "ax_left.get_legend().remove()\n",
    "ax_right.get_legend().remove()\n",
    "handles, labels = ax_left.get_legend_handles_labels()\n",
    "\n",
    "fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.5, 1.1),\n",
    "    ncol=(len(acquisition_colors) - 1) // 2 + 1,\n",
    "    frameon=False,\n",
    ")\n",
    "\n",
    "# --- Axis Labels & Titles ---\n",
    "ax_left.set_xlabel('Consumed Samples', fontweight=\"bold\")\n",
    "ax_right.set_xlabel('Consumed Samples', fontweight=\"bold\")\n",
    "ax_left.text(0.5, -0.275, '(a) Fine-tuned Models', transform=ax_left.transAxes, ha='center', fontsize=plt.rcParams['axes.labelsize'])\n",
    "ax_right.text(0.5, -0.275, '(b) Reward Models', transform=ax_right.transAxes, ha='center', fontsize=plt.rcParams['axes.labelsize'])\n",
    "ax_left.set_ylabel('Score $\\\\Delta$', fontweight=\"bold\")\n",
    "ax_right.set_ylabel('Score $\\\\Delta$', fontweight=\"bold\")\n",
    "\n",
    "ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])\n",
    "ax_right.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])\n",
    "\n",
    "# Format y-ticks as +/- XX%\n",
    "for ax in [ax_left, ax_right]:\n",
    "    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))\n",
    "\n",
    "# Format x-ticks compactly, e.g. 10000 -> '10k' and 2500 -> '2.5k'\n",
    "def thousands_formatter(x, pos):\n",
    "    \"\"\"Abbreviate tick values of a thousand or more with a 'k' suffix.\n",
    "\n",
    "    Args:\n",
    "        x: Tick value.\n",
    "        pos: Tick position passed by matplotlib's FuncFormatter; unused.\n",
    "\n",
    "    Returns:\n",
    "        x / 1000 in general ('g') format followed by 'k' when abs(x) >= 1000,\n",
    "        otherwise x truncated to an integer.\n",
    "    \"\"\"\n",
    "    if abs(x) >= 1000:\n",
    "        # 'g' keeps fractional thousands (2500 -> 2.5k); truncating with int()\n",
    "        # would label such a tick 2k.\n",
    "        return f\"{x / 1000:g}k\"\n",
    "    return f\"{int(x):d}\"\n",
    "\n",
    "for ax in [ax_left, ax_right]:\n",
    "    ax.xaxis.set_major_formatter(mticker.FuncFormatter(thousands_formatter))\n",
    "\n",
    "# --- Grid ---\n",
    "ax_left.grid(axis='y', alpha=GRID_ALPHA)\n",
    "ax_right.grid(axis='y', alpha=GRID_ALPHA)\n",
    "\n",
    "# --- Axis Limits ---\n",
    "def _padded_limits(values, pad=0.1):\n",
    "    \"\"\"Return (low, high) for `values`, each bound pushed outward by `pad` of its magnitude.\n",
    "\n",
    "    Scaling a bound by (1 + pad) only moves it away from the data when the bound\n",
    "    lies on the matching side of zero (low < 0, high > 0). A positive minimum or a\n",
    "    negative maximum is scaled by (1 - pad) instead, so the extreme points are not\n",
    "    clipped. Bounds that were already padded outward by a plain * 1.1 are unchanged.\n",
    "    \"\"\"\n",
    "    low, high = values.min(), values.max()\n",
    "    low = low * (1 + pad) if low < 0 else low * (1 - pad)\n",
    "    high = high * (1 + pad) if high > 0 else high * (1 - pad)\n",
    "    return low, high\n",
    "\n",
    "# Both panels share the x-range; each panel gets the y-range of its own score column.\n",
    "x_limits = _padded_limits(sample_efficiency_ultrafeedback_data['num_train_samples'])\n",
    "for ax, score_column in [(ax_left, 'downstream_mean_score'), (ax_right, 'rm_mean_score')]:\n",
    "    ax.set_xlim(*x_limits)\n",
    "    ax.set_ylim(*_padded_limits(sample_efficiency_ultrafeedback_data[score_column]))\n",
    "\n",
    "# --- Grid ---\n",
    "for ax in [ax_left, ax_right]:\n",
    "    ax.grid(alpha=GRID_ALPHA)\n",
    "\n",
    "# --- Apply Marker Edge Width ---\n",
    "for ax in [ax_left, ax_right]:\n",
    "    for line in ax.get_lines():\n",
    "        line.set_markeredgewidth(plt.rcParams['lines.markeredgewidth'])\n",
    "        line.set_markeredgecolor(plt.rcParams['lines.markeredgecolor'])\n",
    "\n",
    "# --- Robust Z-Order Fix ---\n",
    "color_map = {v.color.lower(): v.zorder for k, v in ACQUISITION_STYLES.items()}\n",
    "for ax in [ax_left, ax_right]:\n",
    "    for line in ax.get_lines():\n",
    "        line_color = mcolors.to_hex(line.get_color()).lower()[:7]\n",
    "        \n",
    "        if line_color in color_map:\n",
    "            line.set_zorder(color_map[line_color])\n",
    "            \n",
    "\n",
    "# --- Save & Show ---\n",
    "plt.tight_layout()\n",
    "fig.savefig(\"sample_efficiency.pdf\", format=\"pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8998423",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Sample Efficiency Plot (Split Export)\n",
    "# ==============================================================================\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib.ticker as mticker\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "sample_efficiency_ultrafeedback_data = sample_efficiency_raw_data.copy()\n",
    "\n",
    "# --- Style Setup ---\n",
    "acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_markers = {k: v.marker for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_dashes = {k: v.dashes if v.dashes is not None else \"\" for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_zorder = {k: v.zorder for k, v in ACQUISITION_STYLES.items()}\n",
    "\n",
    "# Define single plot width\n",
    "SINGLE_PLOT_WIDTH = DOUBLE_COLUMN_WIDTH / 2\n",
    "HEIGHT = 1.75\n",
    "\n",
    "# Helper function for common formatting\n",
    "def format_efficiency_plot(ax):\n",
    "    \"\"\"Apply the shared tick, grid, marker-edge and z-order styling to one axes.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes whose lines have already been drawn.\n",
    "    \"\"\"\n",
    "    def _thousands(x, pos):\n",
    "        # '10k' for whole thousands, '1.5k' for fractional ones (truncating would\n",
    "        # label 1000 and 1500 identically), plain integers below 1000.\n",
    "        if abs(x) < 1000:\n",
    "            return f\"{int(x):d}\"\n",
    "        thousands = x / 1000\n",
    "        return f\"{int(thousands):d}k\" if thousands == int(thousands) else f\"{thousands:g}k\"\n",
    "\n",
    "    # Ticks formatting\n",
    "    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))\n",
    "    ax.xaxis.set_major_formatter(mticker.FuncFormatter(_thousands))\n",
    "\n",
    "    # Grid\n",
    "    ax.grid(alpha=GRID_ALPHA)\n",
    "\n",
    "    # Z-order lookup keyed by the lower-cased style color\n",
    "    color_map = {style.color.lower(): style.zorder for style in ACQUISITION_STYLES.values()}\n",
    "\n",
    "    for line in ax.get_lines():\n",
    "        # Marker styles\n",
    "        line.set_markeredgewidth(plt.rcParams['lines.markeredgewidth'])\n",
    "        line.set_markeredgecolor(plt.rcParams['lines.markeredgecolor'])\n",
    "\n",
    "        # Z-order fix: match each line to its acquisition style via its color\n",
    "        try:\n",
    "            line_color = mcolors.to_hex(line.get_color()).lower()[:7]\n",
    "        except ValueError:\n",
    "            # Color could not be converted; keep this line's z-order as drawn.\n",
    "            # Only ValueError is caught so that unrelated errors (and Ctrl-C)\n",
    "            # are no longer silently swallowed.\n",
    "            continue\n",
    "        if line_color in color_map:\n",
    "            line.set_zorder(color_map[line_color])\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 1. Left Plot (Fine-tuned Models)\n",
    "# ==========================================\n",
    "fig_left, ax_left = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))\n",
    "\n",
    "# Plot Lines\n",
    "sns.lineplot(\n",
    "    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data[\"acquisition_function\"] != \"Original\"],\n",
    "    x='num_train_samples',\n",
    "    y='downstream_mean_score',\n",
    "    hue='acquisition_function',\n",
    "    style='acquisition_function',\n",
    "    hue_order=ACQUISITION_ORDER,\n",
    "    style_order=ACQUISITION_ORDER,\n",
    "    palette=acquisition_colors,\n",
    "    markers=acquisition_markers,\n",
    "    dashes=acquisition_dashes,\n",
    "    ax=ax_left\n",
    ")\n",
    "\n",
    "# Reference Line\n",
    "original_val_left = sample_efficiency_ultrafeedback_data[\n",
    "    sample_efficiency_ultrafeedback_data[\"acquisition_function\"] == \"Original\"\n",
    "][\"downstream_mean_score\"].values[0]\n",
    "\n",
    "ax_left.axhline(\n",
    "    y=original_val_left,\n",
    "    color=ACQUISITION_STYLES[\"Original\"].color,\n",
    "    dashes=ACQUISITION_STYLES[\"Original\"].dashes,\n",
    "    label=\"Original\",\n",
    "    zorder=0,\n",
    ")\n",
    "\n",
    "# --- Formatting ---\n",
    "ax_left.get_legend().remove()\n",
    "ax_left.set_xlabel('Consumed Samples', fontweight=\"bold\")\n",
    "ax_left.set_ylabel('Score $\\\\Delta$', fontweight=\"bold\")\n",
    "ax_left.set_yticks([0.00, 0.05, 0.10, 0.15]) \n",
    "\n",
    "# --- Inserted: Left Axis Limits ---\n",
    "ax_left.set_xlim(\n",
    "    sample_efficiency_ultrafeedback_data['num_train_samples'].min() * 1.1,\n",
    "    sample_efficiency_ultrafeedback_data['num_train_samples'].max() * 1.1\n",
    ")\n",
    "ax_left.set_ylim(\n",
    "    sample_efficiency_ultrafeedback_data['downstream_mean_score'].min() * 1.1,\n",
    "    sample_efficiency_ultrafeedback_data['downstream_mean_score'].max() * 1.1\n",
    ")\n",
    "\n",
    "format_efficiency_plot(ax_left)\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 2. Right Plot (Reward Models)\n",
    "# ==========================================\n",
    "fig_right, ax_right = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))\n",
    "\n",
    "# Plot Lines\n",
    "sns.lineplot(\n",
    "    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data[\"acquisition_function\"] != \"Original\"],\n",
    "    x='num_train_samples',\n",
    "    y='rm_mean_score',\n",
    "    hue='acquisition_function',\n",
    "    style='acquisition_function',\n",
    "    hue_order=ACQUISITION_ORDER,\n",
    "    style_order=ACQUISITION_ORDER,\n",
    "    palette=acquisition_colors,\n",
    "    markers=acquisition_markers,\n",
    "    dashes=acquisition_dashes,\n",
    "    ax=ax_right\n",
    ")\n",
    "\n",
    "# Reference Line\n",
    "original_val_right = sample_efficiency_ultrafeedback_data[\n",
    "    sample_efficiency_ultrafeedback_data[\"acquisition_function\"] == \"Original\"\n",
    "][\"rm_mean_score\"].values[0]\n",
    "\n",
    "ax_right.axhline(\n",
    "    y=original_val_right,\n",
    "    color=ACQUISITION_STYLES[\"Original\"].color,\n",
    "    dashes=ACQUISITION_STYLES[\"Original\"].dashes,\n",
    "    label=\"Original\",\n",
    "    zorder=0,\n",
    ")\n",
    "\n",
    "# --- Formatting ---\n",
    "ax_right.get_legend().remove()\n",
    "ax_right.set_xlabel('Consumed Samples', fontweight=\"bold\")\n",
    "ax_right.set_ylabel('Score $\\\\Delta$', fontweight=\"bold\")\n",
    "ax_right.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4]) \n",
    "\n",
    "# --- Inserted: Right Axis Limits ---\n",
    "ax_right.set_xlim(\n",
    "    sample_efficiency_ultrafeedback_data['num_train_samples'].min() * 1.1,\n",
    "    sample_efficiency_ultrafeedback_data['num_train_samples'].max() * 1.1\n",
    ")\n",
    "ax_right.set_ylim(\n",
    "    sample_efficiency_ultrafeedback_data['rm_mean_score'].min() * 1.1,\n",
    "    sample_efficiency_ultrafeedback_data['rm_mean_score'].max() * 1.1\n",
    ")\n",
    "\n",
    "format_efficiency_plot(ax_right)\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 3. Legend Export\n",
    "# ==========================================\n",
    "handles, labels = ax_left.get_legend_handles_labels()\n",
    "\n",
    "fig_leg = plt.figure(figsize=(DOUBLE_COLUMN_WIDTH, 0.5))\n",
    "fig_leg.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center',\n",
    "    ncol=(len(acquisition_colors) - 1) // 2 + 1,\n",
    "    frameon=False,\n",
    ")\n",
    "\n",
    "# ==========================================\n",
    "# 4. Save\n",
    "# ==========================================\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(\"sample_efficiency\", exist_ok=True)\n",
    "fig_left.savefig(\"sample_efficiency/left.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0.02)\n",
    "fig_right.savefig(\"sample_efficiency/right.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0.02)\n",
    "fig_leg.savefig(\"sample_efficiency/legend.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0.02)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53518bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Sample Efficiency (AlpacaEval Ablation) - Split Export\n",
    "# ==============================================================================\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib.ticker as mticker\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "sample_efficiency_ultrafeedback_data = sample_efficiency_raw_data.copy()\n",
    "\n",
    "# --- Data Prep ---\n",
    "# Calculate downstream mean WITHOUT AlpacaEval\n",
    "sample_efficiency_ultrafeedback_data['downstream_mean_no_alpaca'] = \\\n",
    "    sample_efficiency_ultrafeedback_data[['gsm8k', 'ifeval', 'truthfulqa']].mean(axis=1)\n",
    "\n",
    "# Calculate downstream mean WITH AlpacaEval\n",
    "sample_efficiency_ultrafeedback_data['downstream_mean_with_alpaca'] = \\\n",
    "    sample_efficiency_ultrafeedback_data[['gsm8k', 'ifeval', 'truthfulqa', 'alpacaeval_2']].mean(axis=1)\n",
    "\n",
    "# --- Style Setup ---\n",
    "acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_markers = {k: v.marker for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_dashes = {k: v.dashes if v.dashes is not None else \"\" for k, v in ACQUISITION_STYLES.items()}\n",
    "\n",
    "SINGLE_PLOT_WIDTH = DOUBLE_COLUMN_WIDTH / 2\n",
    "HEIGHT = 1.75\n",
    "\n",
    "# Helper function for common formatting\n",
    "def format_alpaca_plot(ax):\n",
    "    \"\"\"Apply the shared tick, grid and marker-edge styling to one ablation axes.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes whose lines have already been drawn.\n",
    "    \"\"\"\n",
    "    # Format y-ticks as +/- XX%\n",
    "    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))\n",
    "\n",
    "    # Format x-ticks as '10k'. Fractional thousands stay visible ('1.5k');\n",
    "    # truncating them would label 1000 and 1500 identically as '1k'.\n",
    "    def thousands_formatter(x, pos):\n",
    "        if abs(x) < 1000:\n",
    "            return f\"{int(x):d}\"\n",
    "        thousands = x / 1000\n",
    "        return f\"{int(thousands):d}k\" if thousands == int(thousands) else f\"{thousands:g}k\"\n",
    "    ax.xaxis.set_major_formatter(mticker.FuncFormatter(thousands_formatter))\n",
    "\n",
    "    # Grid & Markers\n",
    "    ax.grid(alpha=GRID_ALPHA)\n",
    "    for line in ax.get_lines():\n",
    "        line.set_markeredgewidth(plt.rcParams['lines.markeredgewidth'])\n",
    "\n",
    "# ==========================================\n",
    "# 1. Left Plot: WITH AlpacaEval\n",
    "# ==========================================\n",
    "fig_left, ax_left = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))\n",
    "\n",
    "sns.lineplot(\n",
    "    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data[\"acquisition_function\"] != \"Original\"],\n",
    "    x='num_train_samples',\n",
    "    y='downstream_mean_with_alpaca',\n",
    "    hue='acquisition_function',\n",
    "    style='acquisition_function',\n",
    "    hue_order=ACQUISITION_ORDER,\n",
    "    style_order=ACQUISITION_ORDER,\n",
    "    palette=acquisition_colors,\n",
    "    markers=acquisition_markers,\n",
    "    dashes=acquisition_dashes,\n",
    "    ax=ax_left\n",
    ")\n",
    "\n",
    "# Reference Line\n",
    "y_left = sample_efficiency_ultrafeedback_data.loc[\n",
    "    sample_efficiency_ultrafeedback_data[\"acquisition_function\"] == \"Original\", \n",
    "    \"downstream_mean_with_alpaca\"\n",
    "].values[0]\n",
    "\n",
    "ax_left.axhline(\n",
    "    y=y_left,\n",
    "    color=ACQUISITION_STYLES[\"Original\"].color,\n",
    "    dashes=ACQUISITION_STYLES[\"Original\"].dashes,\n",
    "    label=\"Original\",\n",
    "    zorder=0,\n",
    ")\n",
    "\n",
    "# Formatting\n",
    "ax_left.get_legend().remove()\n",
    "ax_left.set_xlabel('Consumed Samples', fontweight=\"bold\")\n",
    "ax_left.set_ylabel('Score $\\\\Delta$', fontweight=\"bold\")\n",
    "ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])\n",
    "\n",
    "# Limits\n",
    "ax_left.set_xlim(\n",
    "    sample_efficiency_ultrafeedback_data['num_train_samples'].min() * 1.1,\n",
    "    sample_efficiency_ultrafeedback_data['num_train_samples'].max() * 1.1\n",
    ")\n",
    "ax_left.set_ylim(\n",
    "    sample_efficiency_ultrafeedback_data['downstream_mean_with_alpaca'].min() * 1.1,\n",
    "    sample_efficiency_ultrafeedback_data['downstream_mean_with_alpaca'].max() * 1.1\n",
    ")\n",
    "\n",
    "format_alpaca_plot(ax_left)\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 2. Right Plot: WITHOUT AlpacaEval\n",
    "# ==========================================\n",
    "fig_right, ax_right = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))\n",
    "\n",
    "sns.lineplot(\n",
    "    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data[\"acquisition_function\"] != \"Original\"],\n",
    "    x='num_train_samples',\n",
    "    y='downstream_mean_no_alpaca',\n",
    "    hue='acquisition_function',\n",
    "    style='acquisition_function',\n",
    "    hue_order=ACQUISITION_ORDER,\n",
    "    style_order=ACQUISITION_ORDER,\n",
    "    palette=acquisition_colors,\n",
    "    markers=acquisition_markers,\n",
    "    dashes=acquisition_dashes,\n",
    "    ax=ax_right\n",
    ")\n",
    "\n",
    "# Reference Line\n",
    "y_right = sample_efficiency_ultrafeedback_data.loc[\n",
    "    sample_efficiency_ultrafeedback_data[\"acquisition_function\"] == \"Original\", \n",
    "    \"downstream_mean_no_alpaca\"\n",
    "].values[0]\n",
    "\n",
    "ax_right.axhline(\n",
    "    y=y_right,\n",
    "    color=ACQUISITION_STYLES[\"Original\"].color,\n",
    "    dashes=ACQUISITION_STYLES[\"Original\"].dashes,\n",
    "    label=\"Original\",\n",
    "    zorder=0,\n",
    ")\n",
    "\n",
    "# Formatting\n",
    "ax_right.get_legend().remove()\n",
    "ax_right.set_xlabel('Consumed Samples', fontweight=\"bold\")\n",
    "ax_right.set_ylabel('Score $\\\\Delta$', fontweight=\"bold\")\n",
    "ax_right.set_yticks([0.00, 0.05, 0.10, 0.15])\n",
    "\n",
    "# Limits\n",
    "ax_right.set_xlim(\n",
    "    sample_efficiency_ultrafeedback_data['num_train_samples'].min() * 1.1,\n",
    "    sample_efficiency_ultrafeedback_data['num_train_samples'].max() * 1.1\n",
    ")\n",
    "ax_right.set_ylim(\n",
    "    sample_efficiency_ultrafeedback_data['downstream_mean_no_alpaca'].min() * 1.1,\n",
    "    sample_efficiency_ultrafeedback_data['downstream_mean_no_alpaca'].max() * 1.1\n",
    ")\n",
    "\n",
    "format_alpaca_plot(ax_right)\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 3. Legend Export\n",
    "# ==========================================\n",
    "handles, labels = ax_left.get_legend_handles_labels()\n",
    "\n",
    "fig_leg = plt.figure(figsize=(DOUBLE_COLUMN_WIDTH, 0.5))\n",
    "fig_leg.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center',\n",
    "    ncol=(len(acquisition_colors) - 1) // 2 + 1,\n",
    "    frameon=False,\n",
    ")\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 4. Save\n",
    "# ==========================================\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(\"sample_efficiency_no_alpaca_eval\", exist_ok=True)\n",
    "fig_left.savefig(\"sample_efficiency_no_alpaca_eval/left.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0)\n",
    "fig_right.savefig(\"sample_efficiency_no_alpaca_eval/right.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0)\n",
    "fig_leg.savefig(\"sample_efficiency_no_alpaca_eval/legend.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "491238be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Teaser Radar Plot (Normalized)\n",
    "# ==============================================================================\n",
    "\n",
    "teaser_data = teaser_raw_data.copy()\n",
    "\n",
    "# --- Style Setup ---\n",
    "acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_markers = {k: v.marker for k, v in ACQUISITION_STYLES.items()}\n",
    "\n",
    "# Legend display names with symbols\n",
    "legend_names = {\n",
    "    'DeltaUCB': r'DeltaUCB$^\\dagger$',\n",
    "    'DRTS': r'DRTS$^\\dagger$',\n",
    "    'DTS': r'DTS$^*$',\n",
    "}\n",
    "\n",
    "# --- Data Preparation ---\n",
    "benchmark_to_label = {\n",
    "    \"truthfulqa\": \"TruthfulQA\",\n",
    "    \"gsm8k\": \"GSM8K\",\n",
    "    \"rewardbench_2\": \"RewardBench 2\",\n",
    "    \"ifeval\": \"IFEval\",\n",
    "    \"alpacaeval_2\": \"AlpacaEval 2\"\n",
    "}\n",
    "num_labels = len(benchmark_to_label)\n",
    "\n",
    "# Extract data for each acquisition function and calculate mean over datasets\n",
    "teaser_data = teaser_data[teaser_data[\"acquisition_function\"].isin([\n",
    "    'DeltaUCB',\n",
    "    'UltraFeedback',\n",
    "    'DTS',\n",
    "    'DRTS',\n",
    "    'DeltaQwen',\n",
    "])]\n",
    "teaser_data = (\n",
    "    teaser_data.groupby(['acquisition_function'])[list(benchmark_to_label.keys())]\n",
    "    .mean()\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "# Fraction of the plot radius at which each benchmark's best score is drawn.\n",
    "max_radius_fraction = 0.9\n",
    "\n",
    "# The upper limit is the value represented by the outer ring (radius 1.0).\n",
    "# The best score is drawn at radius `max_radius_fraction`, so the ring stands\n",
    "# for max / fraction; using the raw max would understate the ring's tick label.\n",
    "benchmark_to_limits = {\n",
    "    benchmark: (0.0, float(teaser_data[benchmark].max()) / max_radius_fraction)\n",
    "    for benchmark in benchmark_to_label\n",
    "}\n",
    "benchmark_to_ticks = {\n",
    "    benchmark: np.linspace(lower, upper, 5)[1:]\n",
    "    for benchmark, (lower, upper) in benchmark_to_limits.items()\n",
    "}\n",
    "\n",
    "# Normalize now that we have saved the original limits and ticks\n",
    "teaser_data = teaser_data / teaser_data.max() * max_radius_fraction\n",
    "\n",
    "y_max = 1.05\n",
    "angle_offset = np.pi / 2 - (2 * np.pi / num_labels) * 1\n",
    "angles = (np.linspace(0, 2 * np.pi, num_labels, endpoint=False) + angle_offset).tolist()\n",
    "angles += angles[:1]\n",
    "\n",
    "# --- Figure Setup ---\n",
    "fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH, SINGLE_COLUMN_WIDTH * 1.1), subplot_kw=dict(polar=True))\n",
    "\n",
    "# --- Plot Data ---\n",
    "for acq_func in ACQUISITION_ORDER:\n",
    "    if acq_func not in teaser_data.index:\n",
    "        continue\n",
    "\n",
    "    values_normalized = teaser_data.loc[acq_func, list(benchmark_to_label.keys())].tolist()\n",
    "    values_closed = values_normalized + [values_normalized[0]]\n",
    "    color = acquisition_colors[acq_func]\n",
    "    marker = acquisition_markers[acq_func]\n",
    "    label = legend_names.get(acq_func, acq_func)  # Use custom name if available\n",
    "\n",
    "    ax.plot(angles, values_closed, color=color, marker=marker, label=label, zorder=10, mew=0)\n",
    "\n",
    "# --- Legend ---\n",
    "ax.legend(\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.5, -0.1),\n",
    "    ncol=3,\n",
    "    frameon=False,\n",
    ")\n",
    "\n",
    "# --- Axis Labels & Titles ---\n",
    "ax.set_xticks(angles[:-1])\n",
    "ax.set_xticklabels([])\n",
    "label_rotations = {\n",
    "    'GSM8K': 0,\n",
    "    'TruthfulQA': -72,\n",
    "    'RewardBench 2': 72,\n",
    "    'IFEval': -36,\n",
    "    'AlpacaEval 2': 36\n",
    "}\n",
    "for angle, label in zip(angles[:-1], benchmark_to_label.values()):\n",
    "    rotation = label_rotations.get(label, 0)\n",
    "    ax.text(angle, 1.25, label, fontweight='bold',\n",
    "            ha='center', va='center', rotation=rotation)\n",
    "\n",
    "# --- Axis Ticks ---\n",
    "# Only show the outermost tick (at 1.0 of the max value) for readability\n",
    "tick_positions_normalized = [1.0]\n",
    "ax.set_yticks(tick_positions_normalized)\n",
    "ax.set_yticklabels([])\n",
    "for angle, benchmark in zip(angles[:-1], benchmark_to_label.keys()):\n",
    "    tick_val = benchmark_to_ticks[benchmark][-1]  # Get the outermost tick value\n",
    "    tick_pos = tick_positions_normalized[0]\n",
    "    # Format as +XX% instead of decimal\n",
    "    tick_percent = tick_val * 100\n",
    "    tick_label = f\"+{tick_percent:.0f}%\" if tick_percent > 0 else f\"{tick_percent:.0f}%\"\n",
    "    # Use radial offset (not vertical) so labels are pushed outward consistently\n",
    "    ax.text(angle, tick_pos + 0.04, tick_label,\n",
    "            color='black', fontsize=plt.rcParams['xtick.labelsize'],\n",
    "            ha='center', va='center', zorder=1)\n",
    "\n",
    "# --- Grid ---\n",
    "ax.yaxis.grid(True, linestyle='--', color='gray', alpha=GRID_ALPHA)\n",
    "ax.xaxis.grid(True, linestyle='--', color='gray', alpha=GRID_ALPHA)\n",
    "\n",
    "# --- Axis Limits ---\n",
    "ax.set_ylim(0, y_max)\n",
    "ax.set_rlim(0, y_max)\n",
    "ax.spines['polar'].set_visible(False)\n",
    "\n",
    "# --- Save & Show ---\n",
    "plt.tight_layout()\n",
    "fig.savefig(\"teaser.pdf\", format=\"pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "197acb41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Dataset Statistics Plot (stacked bars)\n",
    "# ==============================================================================\n",
    "\n",
    "plot_data = datasets_data.copy()\n",
    "\n",
    "# (key, training, acquisition_function) for every dataset variant, in plot order\n",
    "_dataset_filters = [\n",
    "    (\"baseline_random\", \"baselines\", \"random\"),\n",
    "    (\"baseline_ultrafeedback\", \"baselines\", \"ultrafeedback\"),\n",
    "    (\"baseline_maxmin\", \"baselines\", \"maxmin\"),\n",
    "    (\"dpo_infomax\", \"dpo\", \"InfoMax\"),\n",
    "    (\"rm_infomax\", \"rm\", \"InfoMax\"),\n",
    "    (\"dpo_dts\", \"dpo\", \"DTS\"),\n",
    "    (\"rm_dts\", \"rm\", \"DTS\"),\n",
    "    (\"dpo_maxminlcb\", \"dpo\", \"MaxMinLCB\"),\n",
    "    (\"rm_maxminlcb\", \"rm\", \"MaxMinLCB\"),\n",
    "    (\"dpo_drts\", \"dpo\", \"DRTS\"),\n",
    "    (\"rm_drts\", \"rm\", \"DRTS\"),\n",
    "    (\"dpo_deltaucb\", \"dpo\", \"DeltaUCB\"),\n",
    "    (\"rm_deltaucb\", \"rm\", \"DeltaUCB\"),\n",
    "]\n",
    "\n",
    "# One independent copy of the matching rows per variant\n",
    "dataset_to_data = {\n",
    "    key: plot_data[\n",
    "        (plot_data['training'] == training) & (plot_data['acquisition_function'] == acquisition)\n",
    "    ].copy()\n",
    "    for key, training, acquisition in _dataset_filters\n",
    "}\n",
    "\n",
    "model_name_map = {\n",
    "    \"Qwen/Qwen2.5-0.5B-Instruct\": \"Qwen 2.5 0.5B\",\n",
    "    \"Qwen/Qwen2.5-72B-Instruct\": \"Qwen 2.5 72B\",\n",
    "    \"Qwen/Qwen3-0.6B\": \"Qwen 3 0.6B\",\n",
    "    \"Qwen/Qwen3-1.7B\": \"Qwen 3 1.7B\",\n",
    "    \"Qwen/Qwen3-14B\": \"Qwen 3 14B\",\n",
    "    \"Qwen/Qwen3-30B-A3B\": \"Qwen 3 30B A3B\",\n",
    "    \"Qwen/Qwen3-32B\": \"Qwen 3 32B\",\n",
    "    \"Qwen/Qwen3-235B-A22B\": \"Qwen 3 235B A22B\",\n",
    "    \"meta-llama/Llama-3.1-8B-Instruct\": \"Llama 3.1 8B\",\n",
    "    \"meta-llama/Llama-3.2-1B-Instruct\": \"Llama 3.2 1B\",\n",
    "    \"meta-llama/Llama-3.2-3B-Instruct\": \"Llama 3.2 3B\",\n",
    "    \"meta-llama/Llama-3.3-70B-Instruct\": \"Llama 3.3 70B\",\n",
    "    \"microsoft/Phi-4-mini-instruct\": \"Phi 4 Mini\",\n",
    "    \"microsoft/phi-4\": \"Phi 4\",\n",
    "    \"mistralai/Mistral-Small-24B-Instruct-2501\": \"Mistral Small\",\n",
    "    \"mistralai/Mistral-Large-Instruct-2411\": \"Mistral Large\",\n",
    "    \"nvidia/Llama-3_3-Nemotron-Super-49B-v1\": \"Nemotron Super 49B\",\n",
    "    \"nvidia/Llama-3.1-Nemotron-70B-Instruct-HF\": \"Nemotron 70B\",\n",
    "    \"nvidia/Llama-3_1-Nemotron-Ultra-253B-v1\": \"Nemotron Ultra 253B\",\n",
    "    \"google/gemma-3-1b-it\": \"Gemma 3 1B\",\n",
    "    \"google/gemma-3-4b-it\": \"Gemma 3 4B\",\n",
    "    \"google/gemma-3-12b-it\": \"Gemma 3 12B\",\n",
    "    \"google/gemma-3-27b-it\": \"Gemma 3 27B\",\n",
    "    \"HuggingFaceTB/SmolLM2-1.7B-Instruct\": \"SmolLM2 1.7B\",\n",
    "    \"CohereLabs/c4ai-command-a-03-2025\": \"Command A\",\n",
    "    \"deepseek-ai/DeepSeek-V3\": \"DeepSeek V3\",\n",
    "    \"allenai/OLMo-2-0325-32B-Instruct\": \"OLMo 2 32B\",\n",
    "    \"allenai/Llama-3.1-Tulu-3-70B\": \"Tulu 70B\",\n",
    "    \"allenai/Llama-3.1-Tulu-3-405B\": \"Tulu 405B\",\n",
    "    \"moonshotai/Moonlight-16B-A3B-Instruct\": \"Moonlight 16B A3B\",\n",
    "}\n",
    "\n",
    "# Define which datasets/variants to show and their pretty subplot titles\n",
    "plot_datasets = [\n",
    "    (\"baseline_random\", \"Random\"),\n",
    "    (\"baseline_ultrafeedback\", \"UltraFeedback\"),\n",
    "    (\"baseline_maxmin\", \"MaxMin\"),\n",
    "    (\"dpo_infomax\", \"DPO: InfoMax\"),\n",
    "    (\"rm_infomax\", \"RM: InfoMax\"),\n",
    "    (\"dpo_dts\", \"DPO: DTS\"),\n",
    "    (\"rm_dts\", \"RM: DTS\"),\n",
    "    (\"dpo_maxminlcb\", \"DPO: MaxMinLCB\"),\n",
    "    (\"rm_maxminlcb\", \"RM: MaxMinLCB\"),\n",
    "    (\"dpo_drts\", \"DPO: DRTS\"),\n",
    "    (\"rm_drts\", \"RM: DRTS\"),\n",
    "    (\"dpo_deltaucb\", \"DPO: DeltaUCB\"),\n",
    "    (\"rm_deltaucb\", \"RM: DeltaUCB\"),\n",
    "]\n",
    "\n",
    "# Output each plot individually, saved as dataset_statistics/{dataset_key}.pdf\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(\"dataset_statistics\", exist_ok=True)\n",
    "\n",
    "# The second tuple element is the pretty title; figures are saved untitled, so it is unused.\n",
    "for dataset_key, _title in plot_datasets:\n",
    "    fig, ax = plt.subplots(figsize=(DOUBLE_COLUMN_WIDTH, 2.4))\n",
    "\n",
    "    # Index the rows by display name (falling back to the raw model id)\n",
    "    df = dataset_to_data[dataset_key].set_index('model')\n",
    "    df.index = [model_name_map.get(m, m) for m in df.index]\n",
    "\n",
    "    models = list(df.index)\n",
    "    chosen_counts = df['chosen_count']\n",
    "    rejected_counts = df['rejected_count']\n",
    "\n",
    "    # Stacked bars using native matplotlib: rejected at the bottom, chosen on top\n",
    "    ax.bar(\n",
    "        models,\n",
    "        rejected_counts,\n",
    "        color=RED,\n",
    "        label=\"Rejected\",\n",
    "    )\n",
    "    ax.bar(\n",
    "        models,\n",
    "        chosen_counts,\n",
    "        color=GREEN,\n",
    "        label=\"Chosen\",\n",
    "        bottom=rejected_counts,\n",
    "    )\n",
    "\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_ylabel('Counts')\n",
    "\n",
    "    ax.set_xlim(-0.75, len(models) - 0.25)\n",
    "    ax.set_xticks(range(len(models)))\n",
    "    ax.set_xticklabels(models, rotation=45, ha='right')\n",
    "\n",
    "    # Remove title per instructions\n",
    "    ax.set_title(\"\")\n",
    "\n",
    "    # Horizontal grid lines only\n",
    "    ax.grid(axis='y', alpha=GRID_ALPHA)\n",
    "    ax.grid(axis='x', alpha=0.0)\n",
    "\n",
    "    # Tight layout and save; close the figure so the loop does not accumulate open figures\n",
    "    plt.tight_layout()\n",
    "    fig.savefig(f\"dataset_statistics/{dataset_key}.pdf\", format=\"pdf\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c6f324c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==============================================================================\n",
    "# Sample Efficiency Plot - IPO, SimPO (Split Export)\n",
    "# ==============================================================================\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib.ticker as mticker\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "ipo_data = sample_efficiency_ipo_simpo_raw_data[(sample_efficiency_ipo_simpo_raw_data[\"po_algorithm\"] == \"IPO\") & (sample_efficiency_ipo_simpo_raw_data[\"acquisition_function\"] != \"Original\")].copy()\n",
    "simpo_data = sample_efficiency_ipo_simpo_raw_data[(sample_efficiency_ipo_simpo_raw_data[\"po_algorithm\"] == \"SimPO\") & (sample_efficiency_ipo_simpo_raw_data[\"acquisition_function\"] != \"Original\")].copy()\n",
    "\n",
    "# --- Style Setup ---\n",
    "acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_markers = {k: v.marker for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_dashes = {k: v.dashes if v.dashes is not None else \"\" for k, v in ACQUISITION_STYLES.items()}\n",
    "acquisition_zorder = {k: v.zorder for k, v in ACQUISITION_STYLES.items()}\n",
    "\n",
    "# Define single plot width\n",
    "SINGLE_PLOT_WIDTH = DOUBLE_COLUMN_WIDTH / 2\n",
    "HEIGHT = 1.75\n",
    "\n",
    "# Helper function for common formatting\n",
    "def format_efficiency_plot(ax):\n",
    "    \"\"\"Apply the shared tick, grid, marker-edge and z-order styling to one axes.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes whose lines have already been drawn.\n",
    "    \"\"\"\n",
    "    def _thousands(x, pos):\n",
    "        # '10k' for whole thousands, '1.5k' for fractional ones (truncating would\n",
    "        # label 1000 and 1500 identically), plain integers below 1000.\n",
    "        if abs(x) < 1000:\n",
    "            return f\"{int(x):d}\"\n",
    "        thousands = x / 1000\n",
    "        return f\"{int(thousands):d}k\" if thousands == int(thousands) else f\"{thousands:g}k\"\n",
    "\n",
    "    # Ticks formatting\n",
    "    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))\n",
    "    ax.xaxis.set_major_formatter(mticker.FuncFormatter(_thousands))\n",
    "\n",
    "    # Grid\n",
    "    ax.grid(alpha=GRID_ALPHA)\n",
    "\n",
    "    # Z-order lookup keyed by the lower-cased style color\n",
    "    color_map = {style.color.lower(): style.zorder for style in ACQUISITION_STYLES.values()}\n",
    "\n",
    "    for line in ax.get_lines():\n",
    "        # Marker styles\n",
    "        line.set_markeredgewidth(plt.rcParams['lines.markeredgewidth'])\n",
    "        line.set_markeredgecolor(plt.rcParams['lines.markeredgecolor'])\n",
    "\n",
    "        # Z-order fix: match each line to its acquisition style via its color\n",
    "        try:\n",
    "            line_color = mcolors.to_hex(line.get_color()).lower()[:7]\n",
    "        except ValueError:\n",
    "            # Color could not be converted; keep this line's z-order as drawn.\n",
    "            # Only ValueError is caught so that unrelated errors (and Ctrl-C)\n",
    "            # are no longer silently swallowed.\n",
    "            continue\n",
    "        if line_color in color_map:\n",
    "            line.set_zorder(color_map[line_color])\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# Shared panel builder (used by both plots below)\n",
    "# ==========================================\n",
    "def _plot_efficiency_panel(data, yticks, y_bottom):\n",
    "    \"\"\"Draw one sample-efficiency panel on a new figure and return ``(fig, ax)``.\n",
    "\n",
    "    Args:\n",
    "        data: DataFrame providing the 'num_train_samples', 'downstream_mean_score'\n",
    "            and 'acquisition_function' columns.\n",
    "        yticks: Explicit y-axis tick locations.\n",
    "        y_bottom: Lower y-axis limit.\n",
    "\n",
    "    Returns:\n",
    "        The newly created ``(figure, axes)`` pair.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))\n",
    "\n",
    "    # One line per acquisition function; colour, marker and dash pattern all\n",
    "    # come from the per-function style lookups.\n",
    "    sns.lineplot(\n",
    "        data=data,\n",
    "        x='num_train_samples',\n",
    "        y='downstream_mean_score',\n",
    "        hue='acquisition_function',\n",
    "        style='acquisition_function',\n",
    "        hue_order=ACQUISITION_ORDER,\n",
    "        style_order=ACQUISITION_ORDER,\n",
    "        palette=acquisition_colors,\n",
    "        markers=acquisition_markers,\n",
    "        dashes=acquisition_dashes,\n",
    "        ax=ax\n",
    "    )\n",
    "\n",
    "    # No reference line for the 'Original' baseline is drawn on these panels.\n",
    "\n",
    "    # --- Formatting ---\n",
    "    # The per-axes legend is removed here.\n",
    "    ax.get_legend().remove()\n",
    "    ax.set_xlabel('Consumed Samples', fontweight=\"bold\")\n",
    "    ax.set_ylabel('Score $\\\\Delta$', fontweight=\"bold\")\n",
    "    ax.set_yticks(yticks)\n",
    "\n",
    "    # --- Axis limits ---\n",
    "    samples = data['num_train_samples']\n",
    "    scores = data['downstream_mean_score']\n",
    "    # NOTE(review): the left limit is min * 1.1, which lies to the right of the\n",
    "    # smallest sample count whenever that count is positive -- confirm this is intended.\n",
    "    ax.set_xlim(samples.min() * 1.1, samples.max() * 1.1)\n",
    "    ax.set_ylim(y_bottom, scores.max() * 1.1)\n",
    "\n",
    "    format_efficiency_plot(ax)\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 1. Left Plot (IPO)\n",
    "# ==========================================\n",
    "fig_left, ax_left = _plot_efficiency_panel(\n",
    "    ipo_data,\n",
    "    yticks=[0.00, 0.05, 0.10, 0.15],\n",
    "    y_bottom=-0.051,  # fixed lower bound rather than one derived from the data\n",
    ")\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 2. Right Plot (SimPO)\n",
    "# ==========================================\n",
    "fig_right, ax_right = _plot_efficiency_panel(\n",
    "    simpo_data,\n",
    "    yticks=[-0.05, 0.0, 0.1, 0.2, 0.3, 0.4],\n",
    "    y_bottom=simpo_data['downstream_mean_score'].min() * 1.5,\n",
    ")\n",
    "\n",
    "\n",
    "# ==========================================\n",
    "# 3. Legend Export\n",
    "# ==========================================\n",
    "# The legend entries are taken from the left panel and drawn alone on a separate figure.\n",
    "handles, labels = ax_left.get_legend_handles_labels()\n",
    "\n",
    "fig_leg = plt.figure(figsize=(DOUBLE_COLUMN_WIDTH, 0.5))\n",
    "fig_leg.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='center',\n",
    "    # Column count is ceil(number of acquisition styles / 2).\n",
    "    ncol=(len(acquisition_colors) - 1) // 2 + 1,\n",
    "    frameon=False,\n",
    ")\n",
    "\n",
    "# ==========================================\n",
    "# 4. Save\n",
    "# ==========================================\n",
    "# savefig does not create missing directories, so make sure the target exists first.\n",
    "os.makedirs(\"sample_efficiency_ipo_simpo\", exist_ok=True)\n",
    "\n",
    "fig_left.savefig(\"sample_efficiency_ipo_simpo/left.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0.02)\n",
    "fig_right.savefig(\"sample_efficiency_ipo_simpo/right.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0.02)\n",
    "fig_leg.savefig(\"sample_efficiency_ipo_simpo/legend.pdf\", format=\"pdf\", bbox_inches=\"tight\", pad_inches=0.02)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
