{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71e31e6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "from collections import defaultdict\n",
    "from matplotlib import rcParams\n",
    "\n",
    "# LaTeX preamble for matplotlib text rendering.\n",
    "rcParams['text.latex.preamble'] = r'\\usepackage{amsfonts}'\n",
    "\n",
    "# Sycophancy results, restricted to the rows scored by the GPT-5-mini (medium) judge.\n",
    "data = pd.read_csv('../data/results/sycophancy_recent.csv')\n",
    "data = data[data['judge'] == 'GPT-5-mini (medium)']\n",
    "\n",
    "# Drop fine-tuned checkpoints (solver_id containing 'trained') except the 'best_model' one,\n",
    "# which is kept and relabelled as Qwen3-4B-FT.\n",
    "is_trained = data['solver_id'].str.contains('trained')\n",
    "is_best_model = data['solver_id'].str.contains('best_model')\n",
    "data = data[is_best_model | ~is_trained]\n",
    "data['solver'] = np.where(data['solver_id'].str.contains('best_model'), 'Qwen3-4B-FT', data['solver'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df77a152",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column dtypes and non-null counts of the filtered results.\n",
    "data.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "195179be",
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_problems = data.problem.unique()\n",
    "# Problem ids containing 'matharena' are counted as final-answer, all others as proof-style.\n",
    "n_final_answer = sum('matharena' in problem for problem in unique_problems)\n",
    "print('No. final answer:', n_final_answer)\n",
    "print('No. proof-style:', len(unique_problems) - n_final_answer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4d40429",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct problems per competition.\n",
    "data.groupby('competition')['problem'].nunique(dropna=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c86d89",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Strip any parenthesised suffix from solver names: \"Name (suffix)\" -> \"Name\".\n",
    "data['solver'] = data[\"solver\"].apply(lambda x: x if \" (\" not in x else x.split(\" (\")[0])\n",
    "\n",
    "# Rows per normalised solver name; kept as the last expression so the notebook displays it.\n",
    "Counter(data[\"solver\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d6dc048",
   "metadata": {},
   "outputs": [],
   "source": [
    "minimal_model_ids = [\n",
    "    \"openai/o3\", \"openai/o4-mini--high\", \"gemini/gemini-pro-2.5\", \"qwen/qwen3_235b_a22b\"\n",
    "]\n",
    "\n",
    "# Per-row score: the mean of `incorrect` when it holds a list, otherwise the value itself.\n",
    "data[\"mean_score\"] = data[\"incorrect\"].apply(lambda x: np.mean(x) if isinstance(x, list) else x)\n",
    "\n",
    "def get_data_performance(minimal_model_ids, comps=None, split=None):\n",
    "    \"\"\"Mean score per solver, partitioned by the set of solvers that attempted each problem.\n",
    "\n",
    "    A problem is kept only if every solver in `minimal_model_ids` has a row for it; solvers whose\n",
    "    name contains \"(\" are ignored. Kept problems are grouped by the sorted, comma-joined names of\n",
    "    the solvers present for them, and `mean_score` is averaged per solver within each group.\n",
    "\n",
    "    Args:\n",
    "        minimal_model_ids: solver names that must all be present for a problem to count.\n",
    "        comps: optional collection of competition names to restrict to; None keeps every row.\n",
    "        split: unused; accepted so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        (problem_partitions, n_problems_per_partition): partition key -> {solver: mean score},\n",
    "        and partition key -> number of scores collected for the partition's first solver.\n",
    "    \"\"\"\n",
    "    data_filtered = data if comps is None else data[data[\"competition\"].isin(comps)]\n",
    "    # Drop the excluded solvers once, instead of re-filtering the frame for every problem.\n",
    "    data_filtered = data_filtered[~data_filtered[\"solver\"].str.contains(\"(\", regex=False)]\n",
    "\n",
    "    problem_partitions = dict()\n",
    "    # A single pass over the problems, visited in sorted problem order.\n",
    "    for _, df_problem_id in data_filtered.groupby(\"problem\", sort=True):\n",
    "        present_solvers = set(df_problem_id[\"solver\"])\n",
    "        if not all(solver in present_solvers for solver in minimal_model_ids):\n",
    "            continue\n",
    "        models_appended = \",\".join(sorted(present_solvers))\n",
    "        partition = problem_partitions.setdefault(\n",
    "            models_appended, {solver: [] for solver in df_problem_id[\"solver\"].unique()}\n",
    "        )\n",
    "        for solver, score in zip(df_problem_id[\"solver\"], df_problem_id[\"mean_score\"]):\n",
    "            partition[solver].append(score)\n",
    "\n",
    "    # NOTE(review): counts the scores of the partition's first solver, which equals the number of\n",
    "    # problems only if every solver has exactly one row per problem.\n",
    "    n_problems_per_partition = {k: len(next(iter(v.values()))) for k, v in problem_partitions.items()}\n",
    "    problem_partitions = {\n",
    "        partition: {solver: np.mean(scores) for solver, scores in models.items()}\n",
    "        for partition, models in problem_partitions.items()\n",
    "    }\n",
    "    return problem_partitions, n_problems_per_partition\n",
    "\n",
    "problem_partitions, n_problems_per_partition = get_data_performance(minimal_model_ids)\n",
    "\n",
    "model_mapper = {\n",
    "    \"DeepSeek-R1-Qwen3-8B\": r\"R1-8B\",\n",
    "    \"gemini-2.5-pro\": r\"Gemini-2.5-Pro\",\n",
    "    \"Gemini Pro 3 Preview\": r\"Gemini-3-Pro\",\n",
    "    \"o4-mini\": r\"o4-mini\",\n",
    "    \"GPT OSS 120B\": r\"OSS-120B\",\n",
    "    \"GPT-5\": r\"GPT-5\",\n",
    "    \"Qwen3-4B\": r\"Qwen3-4B\",\n",
    "    \"Qwen3-235B-A22B\": r\"Qwen3-235B\",\n",
    "    'DeepSeek-v3.1': r\"DS-V3.1\",\n",
    "    'Grok 4': r\"Grok 4\",\n",
    "    'Grok 4 Fast': r\"Grok 4 Fast\",\n",
    "    # \"BrokenMath-Qwen3-4B\": r\"BrokenMath-Qwen3-4B\"\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35d6f7e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.image as mpimg\n",
    "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n",
    "\n",
    "def plot(model_scores, column_mapper, save_file=None, label_size=18, shift_factor=0.05,\n",
    "         colors=(\"#cf3b3b\", \"#e7969c\", \"#f2b2b2\"), add_height=0.1, legend=True, logos=True, error_bars=True, errors=None,\n",
    "         error_bar_kws={\"capthick\":0}, add_values=False, values_label_size=18, width=18, legend_fontsize=14,\n",
    "         n_samples=504):\n",
    "    \"\"\"Grouped bar chart: one group per model, one bar per entry of `column_mapper`.\n",
    "\n",
    "    Args:\n",
    "        model_scores: dict mapping model name -> list of values, one per entry of `column_mapper`.\n",
    "        column_mapper: bar (legend) labels, in the same order as each value list.\n",
    "        save_file: optional path; the figure is saved there with a tight bounding box.\n",
    "        label_size: font size of the tick labels.\n",
    "        shift_factor: horizontal offset of each model name per character of the name.\n",
    "        colors: bar colours; the first len(column_mapper) entries are used.\n",
    "        add_height: head-room added above the largest value when setting the y-limit.\n",
    "        legend: whether to show the legend.\n",
    "        logos: hide the x tick labels and draw provider logos plus model names below the axis.\n",
    "        error_bars: whether to draw error bars.\n",
    "        errors: optional dict shaped like `model_scores` with the error of every bar. When None,\n",
    "            each bar gets 2 * sqrt(p * (1 - p) / n_samples), or 0 unless 0 < p < 1.\n",
    "        error_bar_kws: keyword arguments overriding the default `ax.errorbar` styling.\n",
    "        add_values: write each non-zero value as a percentage above its bar.\n",
    "        values_label_size: font size of those percentages.\n",
    "        width: figure width in inches (the height is fixed at 6).\n",
    "        legend_fontsize: font size of the legend.\n",
    "        n_samples: sample size used by the default error estimate.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(width, 6), dpi=300)\n",
    "\n",
    "    # Long format: one row per (model, column) pair.\n",
    "    df = pd.DataFrame.from_dict(model_scores, orient='index', columns=column_mapper)\n",
    "    df.reset_index(inplace=True)\n",
    "    df.rename(columns={'index': 'Model'}, inplace=True)\n",
    "    df_melted = df.melt(id_vars='Model', var_name='Accuracy Type', value_name='Accuracy Value')\n",
    "\n",
    "    sns.barplot(\n",
    "        data=df_melted, x='Model', y='Accuracy Value', hue='Accuracy Type',\n",
    "        palette=colors[:len(column_mapper)]\n",
    "    )\n",
    "\n",
    "    if error_bars:\n",
    "        if errors is None:\n",
    "            df_melted['error'] = df_melted['Accuracy Value'].apply(\n",
    "                lambda p: 2 * np.sqrt(p * (1 - p) / n_samples) if 0 < p < 1 else 0\n",
    "            )\n",
    "        else:\n",
    "            # Melt the provided errors the same way and attach them to the scores.\n",
    "            df_errors = pd.DataFrame.from_dict(errors, orient='index', columns=column_mapper)\n",
    "            df_errors.reset_index(inplace=True)\n",
    "            df_errors.rename(columns={'index': 'Model'}, inplace=True)\n",
    "            df_errors_melted = df_errors.melt(id_vars='Model', var_name='Accuracy Type', value_name='error')\n",
    "            df_melted = pd.merge(df_melted, df_errors_melted, on=['Model', 'Accuracy Type'])\n",
    "\n",
    "        # Default styling, overridable through error_bar_kws.\n",
    "        kws = dict(fmt='none', capsize=4, ecolor='black', elinewidth=1)\n",
    "        if error_bar_kws:\n",
    "            kws.update(error_bar_kws)\n",
    "\n",
    "        # Assumes ax.patches follow the row order of df_melted (all models for the first column,\n",
    "        # then all models for the next column, ...).\n",
    "        x_coords, y_coords, error_vals = [], [], []\n",
    "        for patch, (_, row) in zip(ax.patches, df_melted.iterrows()):\n",
    "            height = patch.get_height()\n",
    "            if height > 0:  # no error bar on zero-height bars\n",
    "                x_coords.append(patch.get_x() + patch.get_width() / 2.)\n",
    "                y_coords.append(height)\n",
    "                error_vals.append(row['error'])\n",
    "        ax.errorbar(x=x_coords, y=y_coords, yerr=error_vals, **kws)\n",
    "\n",
    "    # \"Not Applicable\" on zero-height bars, placed at the bar centre in data coordinates so it\n",
    "    # lines up with its bar whatever the number of plotted models.\n",
    "    for patch in ax.patches:\n",
    "        if patch.get_height() == 0 and patch.get_width() > 0:\n",
    "            ax.text(\n",
    "                patch.get_x() + patch.get_width() / 2, 0.03, \"Not Applicable\",\n",
    "                ha=\"center\", va=\"bottom\", fontsize=16, rotation=90,\n",
    "            )\n",
    "\n",
    "    # Percentages above each non-zero bar (above its error bar when error bars are drawn).\n",
    "    if add_values:\n",
    "        text_offset = 0.01\n",
    "        for patch, (_, row) in zip(ax.patches, df_melted.iterrows()):\n",
    "            height = patch.get_height()\n",
    "            if height > 0:\n",
    "                y_pos = height + row['error'] if error_bars else height\n",
    "                ax.text(\n",
    "                    patch.get_x() + patch.get_width() / 2,\n",
    "                    y_pos + text_offset,\n",
    "                    f'{height*100:.1f}%',\n",
    "                    ha='center', va='bottom',\n",
    "                    fontsize=values_label_size\n",
    "                )\n",
    "\n",
    "    if logos:\n",
    "        ax.set_xticklabels([''] * len(ax.get_xticks()))\n",
    "    x_positions = ax.get_xticks()  # x location of each model's bar group\n",
    "    models = [model_mapper.get(model, model) for model in df['Model'].unique()]\n",
    "\n",
    "    # Raw model name -> (logo file, extra horizontal shift of the name label, third value unused here).\n",
    "    get_png = {\n",
    "        \"DeepSeek-R1-Qwen3-8B\": (\"deepseek.png\", 0.0, 0.00),\n",
    "        \"gemini-2.5-pro\": (\"gemini.png\", -0.02, 0.04),\n",
    "        \"Gemini Pro 3 Preview\": (\"gemini.png\", -0.02, 0.04),\n",
    "        \"o4-mini\": (\"openai.png\", 0.0, 0.04),\n",
    "        \"GPT OSS 120B\": (\"openai.png\", 0.0, 0.04),\n",
    "        \"GPT-5\": (\"openai.png\", 0.0, 0.04),\n",
    "        \"Qwen3-4B\": (\"qwen.png\", 0.04, 0),\n",
    "        \"Qwen3-235B-A22B\": (\"qwen.png\", 0.04, 0),\n",
    "        'DeepSeek-v3.1': (\"deepseek.png\", 0.0, 0.00),\n",
    "        'Grok 4': (\"xai.png\", 0.0, 0),\n",
    "        'Grok 4 Fast': (\"xai.png\", 0.0, 0),\n",
    "        # \"BrokenMath-Qwen3-4B\": (\"brokenmath.png\", 0, 0)\n",
    "    }\n",
    "    # Re-key by display name, the same way `models` is mapped above.\n",
    "    get_png = {model_mapper.get(k, k): v for k, v in get_png.items()}\n",
    "\n",
    "    if logos:\n",
    "        for x, model_name in zip(x_positions, models):\n",
    "            # Models without an entry get no logo and no extra label shift.\n",
    "            logo_file, name_shift, _ = get_png.get(model_name, (None, 0.0, 0.0))\n",
    "            if logo_file is not None:\n",
    "                arr_img = mpimg.imread(logo_file)\n",
    "                imagebox = OffsetImage(arr_img, zoom=0.03)\n",
    "                ab = AnnotationBbox(\n",
    "                    imagebox,\n",
    "                    (x, -0.17),\n",
    "                    frameon=False,\n",
    "                    xycoords=('data', 'axes fraction'),  # x in data coords, y as a fraction of the axes\n",
    "                    box_alignment=(0.5, 0)               # anchor at the bottom centre of the image\n",
    "                )\n",
    "                ax.add_artist(ab)\n",
    "\n",
    "            # Model name just below the axis, shifted right in proportion to its length.\n",
    "            ax.text(\n",
    "                x + shift_factor * len(model_name) + name_shift,\n",
    "                -0.025,\n",
    "                model_name,\n",
    "                ha='right',\n",
    "                va='top',\n",
    "                fontsize=16,\n",
    "                transform=ax.get_xaxis_transform()  # x in data coords, y in axes fraction\n",
    "            )\n",
    "\n",
    "    plt.xlabel('')\n",
    "    plt.ylabel('')\n",
    "    plt.tick_params(axis='both', which='major', labelsize=label_size)\n",
    "    if legend:\n",
    "        leg = plt.legend(loc='upper left', fontsize=legend_fontsize)\n",
    "        leg.get_frame().set_alpha(1)\n",
    "    else:\n",
    "        leg = plt.legend([], [], loc='upper left', fontsize=legend_fontsize)\n",
    "        leg.set_visible(False)\n",
    "\n",
    "    plt.tick_params(axis='y', which='both', left=False, right=False)\n",
    "    sns.despine(left=True, bottom=True)\n",
    "    ax.set_facecolor((0.97, 0.97, 0.97))  # light grey background\n",
    "\n",
    "    max_value = df_melted['Accuracy Value'].max()\n",
    "    ax.set_ylim(0, max_value + add_height)\n",
    "    if max_value < 1:\n",
    "        n_ticks = int((max_value + add_height) * 5)  # one tick every 20%\n",
    "        ax.set_yticks([i / 5 for i in range(n_ticks + 1)])\n",
    "        # Raw f-string so \\% stays a literal backslash-percent without an invalid-escape warning.\n",
    "        ax.set_yticklabels([rf'${i * 20}\\%$' for i in range(n_ticks + 1)])\n",
    "\n",
    "    if save_file:\n",
    "        fig.savefig(save_file, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7081b9b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total cost per solver.\n",
    "data.groupby('solver')['cost'].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2bbaa1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall sycophancy rate (mean of mean_score) for every model in model_mapper.\n",
    "model_scores = {model: [data.loc[data.solver == model, 'mean_score'].mean()] for model in model_mapper}\n",
    "sorted_models = sorted(model_scores, key=lambda model: model_scores[model])\n",
    "\n",
    "# Only a handful of models are shown in the overview figure.\n",
    "overview_models = ['GPT-5', 'Gemini Pro 3 Preview', 'Grok 4', 'DeepSeek-v3.1']\n",
    "model_scores = {model: model_scores[model] for model in overview_models}\n",
    "plot(model_scores, [\"All\"], save_file=\"results_overview.pdf\", add_values=True, values_label_size=18,width=8, legend=False, shift_factor=0.04)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3bb44b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models in the most recently built model_scores dict.\n",
    "model_scores.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "125be8b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the four Qwen3-4B rates below are hard-coded rather than computed in this notebook;\n",
    "# they follow the order of the four column labels passed to plot.\n",
    "# NOTE(review): the output file name is spelled 'mititgation_results.pdf'.\n",
    "model_scores = {\n",
    "    \"Qwen3-4B\": [0.556, 0.438, 0.446, 0.494]\n",
    "}\n",
    "plot(model_scores, [\"No mitigation\", \"Standard mitigation\", \"Contradiction prompt\", \"Awareness prompt\"], save_file=\"mititgation_results.pdf\", colors=(\"#cf3b3b\", \"#c7d4b3\", \"#a5bb86\", \"#71973f\"), add_values=True, values_label_size=18,width=8, legend=True, shift_factor=0.005,add_height=0.4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19342e1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem classification (column `types`), with '/' in problem ids replaced by '_'.\n",
    "classification_data = pd.read_csv('../data/results/classification.csv').assign(\n",
    "    problem_id=lambda frame: frame['problem_id'].str.replace('/', '_')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "504a437c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gettype(x):\n",
    "    \"\"\"Return the `types` entry of the first classification row whose problem_id equals x.\n",
    "\n",
    "    Prints x and returns None when the problem is not in classification_data.\n",
    "    \"\"\"\n",
    "    matching_types = classification_data.loc[classification_data.problem_id == x, 'types']\n",
    "    if matching_types.empty:\n",
    "        print(x)\n",
    "        return None\n",
    "    return matching_types.iloc[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b758b5ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the per-row topic lookup; ids missing from classification_data are printed.\n",
    "data.problem.apply(gettype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc2c09ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the topic classification to every row (None where the problem is unclassified).\n",
    "data['types'] = data.problem.apply(gettype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0a246be",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_scores = defaultdict(list)\n",
    "model_errors = defaultdict(list)\n",
    "topics = ['Algebra', 'Combinatorics', 'Number Theory', 'Geometry']\n",
    "\n",
    "# `types` is None for problems missing from the classification table, so guard before `in`.\n",
    "# The masks do not depend on the model, so they are built once per topic.\n",
    "topic_masks = {\n",
    "    topic: data.types.apply(lambda x, topic=topic: isinstance(x, str) and topic in x)\n",
    "    for topic in topics\n",
    "}\n",
    "\n",
    "for model in model_mapper:\n",
    "    for problem_type in topics:\n",
    "        scores = data.loc[(data.solver == model) & topic_masks[problem_type], 'mean_score']\n",
    "        rate = scores.mean()\n",
    "        model_scores[model].append(rate)\n",
    "        # Two binomial standard errors, 2 * sqrt(p * (1 - p) / n) (the form plot() uses by default).\n",
    "        model_errors[model].append(2 * np.sqrt(rate * (1 - rate) / len(scores)) if len(scores) else np.nan)\n",
    "model_scores = dict(sorted(model_scores.items(), key=lambda x: sum(x[1])))\n",
    "sorted_models = list(model_scores.keys())\n",
    "plot(model_scores, topics, errors=model_errors, save_file=\"sycophancy_by_topic.pdf\", add_values=True, values_label_size=10,width=28, legend=True, shift_factor=0.03, colors=[\"#cf3b3b\", \"#c7d4b3\", \"#f7bf5f\", \"#98cdf0\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f05af15c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct solver ids that survived the filtering in the first cell.\n",
    "data.solver_id.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7dba2a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "category_counts = defaultdict(list)\n",
    "grades = ['incorrect', 'detected', 'corrected', 'correct']\n",
    "\n",
    "# For each model: number of rows per true_grade, rescaled by 504 / (number of rows for that model).\n",
    "for model in model_mapper:\n",
    "    model_rows = data[data.solver == model]\n",
    "    n_model_rows = len(model_rows)\n",
    "    for grade in grades:\n",
    "        n_grade_rows = len(model_rows[model_rows.true_grade == grade])\n",
    "        category_counts[model].append(n_grade_rows * 504 / n_model_rows)\n",
    "category_counts = {model: category_counts[model] for model in sorted_models}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "603d8a95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display names for the true_grade categories.\n",
    "category_mapper = dict(\n",
    "    incorrect='Sycophant',\n",
    "    correct='Ideal',\n",
    "    detected='Detected',\n",
    "    corrected='Corrected',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be9dd981",
   "metadata": {},
   "outputs": [],
   "source": [
    "grade_order = ['incorrect', 'detected', 'corrected', 'correct']\n",
    "plot(\n",
    "    category_counts,\n",
    "    [category_mapper[grade] for grade in grade_order],\n",
    "    save_file=\"sycophancy_categories.pdf\",\n",
    "    values_label_size=18,\n",
    "    width=24,\n",
    "    shift_factor=0.04,\n",
    "    colors=(\"#cf3b3b\", \"#c7d4b3\", \"#a5bb86\", \"#71973f\"),\n",
    "    error_bars=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "160f74fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "utility_data = pd.read_csv(\"../data/results/utility_recent_comps.csv\")\n",
    "# Same solver-name normalisation as for `data`, and '/' -> '_' in problem ids.\n",
    "utility_data['solver'] = utility_data['solver'].apply(lambda name: name.split(\" (\")[0] if \" (\" in name else name)\n",
    "utility_data['problem'] = utility_data['problem'].apply(lambda pid: pid.replace('/', '_'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e32e171d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Topic classification for every utility row (None where the problem is unclassified).\n",
    "utility_data['types'] = utility_data.problem.apply(gettype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f481c90b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_scores = defaultdict(list)\n",
    "\n",
    "# Per model: [mean of mean_score in data, mean accuracy in utility_data].\n",
    "for model in model_mapper:\n",
    "    sycophancy_rate = data.loc[data.solver == model, 'mean_score'].mean()\n",
    "    utility_rate = utility_data.loc[utility_data.solver == model, 'accuracy'].mean()\n",
    "    model_scores[model].extend([sycophancy_rate, utility_rate])\n",
    "model_scores = dict(sorted(model_scores.items(), key=lambda item: item[1]))\n",
    "sorted_models = list(model_scores)\n",
    "plot(model_scores, [\"Sycophancy\", \"Utility\"], save_file=\"full_results.png\", add_values=True, values_label_size=18,width=28,add_height=0.15, legend=True, shift_factor=0.03, legend_fontsize=17)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2600f3c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean utility accuracy of GPT-5.\n",
    "utility_data.loc[utility_data.solver == 'GPT-5', 'accuracy'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0254b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_scores = defaultdict(list)\n",
    "model_errors = defaultdict(list)\n",
    "topics = ['Algebra', 'Combinatorics', 'Number Theory', 'Geometry']\n",
    "\n",
    "# `types` is None for problems missing from the classification table, so guard before `in`.\n",
    "# The masks do not depend on the model, so they are built once per topic.\n",
    "utility_topic_masks = {\n",
    "    topic: utility_data.types.apply(lambda x, topic=topic: isinstance(x, str) and topic in x)\n",
    "    for topic in topics\n",
    "}\n",
    "\n",
    "for model in model_mapper:\n",
    "    for problem_type in topics:\n",
    "        accuracies = utility_data.loc[(utility_data.solver == model) & utility_topic_masks[problem_type], 'accuracy']\n",
    "        rate = accuracies.mean()\n",
    "        model_scores[model].append(rate)\n",
    "        # Two binomial standard errors, 2 * sqrt(p * (1 - p) / n) (the form plot() uses by default).\n",
    "        model_errors[model].append(2 * np.sqrt(rate * (1 - rate) / len(accuracies)) if len(accuracies) else np.nan)\n",
    "model_scores = dict(sorted(model_scores.items(), key=lambda x: sum(x[1])))\n",
    "sorted_models = list(model_scores.keys())\n",
    "plot(model_scores, topics, errors=model_errors, save_file=\"utility_by_topic.pdf\", add_values=True, values_label_size=10,width=32, legend=True, shift_factor=0.03, colors=(\"#cf3b3b\", \"#c7d4b3\", \"#f7bf5f\", \"#98cdf0\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b15abaa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_biggest_subset(df, threshold=0.02):\n",
    "    \"\"\"Greedily drop proof-style rows until their mean accuracy is within `threshold` of the\n",
    "    mean accuracy of the final-answer rows.\n",
    "\n",
    "    Rows whose split is 'matharena' or 'answer' are always kept. Rows with split == 'proofs' are\n",
    "    visited in order; while the mean accuracy of the proof rows still kept lies outside\n",
    "    [target - threshold, target + threshold], a row is dropped when its own accuracy lies on the\n",
    "    same side of that band as the current mean. Rows with any other split are discarded. This is a\n",
    "    single greedy pass, not an exact search for the largest balanced subset.\n",
    "\n",
    "    Args:\n",
    "        df: frame with at least 'split' and 'accuracy' columns.\n",
    "        threshold: half-width of the accepted band around the target accuracy.\n",
    "\n",
    "    Returns:\n",
    "        The final-answer rows followed by the kept proof rows (the latter re-indexed from 0).\n",
    "    \"\"\"\n",
    "    df_ans = df[df['split'].isin(['matharena', 'answer'])]\n",
    "    df_proofs = df[df['split'] == 'proofs'].reset_index(drop=True)  # positional index for the mask\n",
    "\n",
    "    target_accuracy = df_ans['accuracy'].mean()\n",
    "    lower = target_accuracy - threshold\n",
    "    upper = target_accuracy + threshold\n",
    "\n",
    "    # Running sum and count of the kept proof rows: O(1) per step instead of re-summing the mask.\n",
    "    total_accuracy = df_proofs['accuracy'].sum()\n",
    "    n_kept = len(df_proofs)\n",
    "    proofs_mask = np.ones(len(df_proofs), dtype=bool)\n",
    "\n",
    "    for i, accuracy in enumerate(df_proofs['accuracy']):\n",
    "        if n_kept == 0:\n",
    "            break\n",
    "        current_avg = total_accuracy / n_kept\n",
    "        if lower <= current_avg <= upper:\n",
    "            break\n",
    "        # Drop the row only if it lies on the same side of the band as the current mean.\n",
    "        if (current_avg > upper and accuracy > upper) or (current_avg < lower and accuracy < lower):\n",
    "            proofs_mask[i] = False\n",
    "            total_accuracy -= accuracy\n",
    "            n_kept -= 1\n",
    "\n",
    "    return pd.concat([df_ans, df_proofs[proofs_mask]])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0420c452",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Balance each solver's rows separately and stack the results with a flat index.\n",
    "# (groupby(...).apply would also hand the grouping column to the function, which pandas\n",
    "# deprecates, and would return a (solver, row) MultiIndex.)\n",
    "utility_data_balanced = pd.concat(\n",
    "    [extract_biggest_subset(group) for _, group in utility_data.groupby('solver')],\n",
    "    ignore_index=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2f331b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One balanced subset per solver, stacked with a fresh RangeIndex.\n",
    "solver_groups = (rows for _, rows in utility_data.groupby(\"solver\"))\n",
    "utility_data_balanced = pd.concat(map(extract_biggest_subset, solver_groups), ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d31e897e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): shows whatever model_scores was last assigned (the utility-by-topic scores when run top to bottom).\n",
    "model_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc861533",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the (solver, problem) pairs that survived balancing in utility_data_balanced.\n",
    "# An inner merge on de-duplicated pairs keeps each matching row of `data` once, in order.\n",
    "balanced_pairs = utility_data_balanced[['solver', 'problem']].drop_duplicates()\n",
    "data_balanced = data.merge(balanced_pairs, on=['solver', 'problem'], how='inner')\n",
    "\n",
    "# NOTE(review): the final-answer partition is matched on `competition`, while\n",
    "# extract_biggest_subset uses these labels on `split`; confirm `competition` takes these values.\n",
    "partition_final_answer = ['matharena', 'answer']\n",
    "partition_proofs = data_balanced[~data_balanced.competition.isin(partition_final_answer)].competition.unique()\n",
    "\n",
    "model_scores = defaultdict(list)\n",
    "\n",
    "for model in model_mapper:\n",
    "    for partition in [partition_final_answer, partition_proofs]:\n",
    "        data_model_partition = data_balanced[(data_balanced.solver == model) & (data_balanced.competition.isin(partition))]\n",
    "        model_scores[model].append(data_model_partition['mean_score'].mean())\n",
    "# Order models by their proof-style score.\n",
    "model_scores = dict(sorted(model_scores.items(), key=lambda x: x[1][1]))\n",
    "plot(model_scores, [\"Final-answer\", \"Proof-style\"], save_file=\"type_comparison.pdf\", add_values=True, values_label_size=15, colors=(\"#cf3b3b\", \"#e7969c\", \"#649ddc\"), shift_factor=0.035, width=26, legend_fontsize=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81f7f988",
   "metadata": {},
   "outputs": [],
   "source": [
    "utility_data_balanced['is_proof'] = utility_data_balanced['split'].eq('proofs')\n",
    "# Number of balanced rows and mean accuracy, proof-style vs. everything else.\n",
    "(\n",
    "    utility_data_balanced.groupby('is_proof')['accuracy']\n",
    "    .agg(['count', 'mean'])\n",
    "    .rename(columns={'mean': 'rate'})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f9cced",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per model: [mean of mean_score in data, mean accuracy in utility_data].\n",
    "model_scores = defaultdict(list)\n",
    "for model in model_mapper:\n",
    "    model_scores[model].append(data.loc[data.solver == model, 'mean_score'].mean())\n",
    "    model_scores[model].append(utility_data.loc[utility_data.solver == model, 'accuracy'].mean())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c68679f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import pearsonr\n",
    "\n",
    "# Correlation across models between the first and second score of each model_scores entry.\n",
    "sycophancy_rates = [scores[0] for scores in model_scores.values()]\n",
    "utility_rates = [scores[1] for scores in model_scores.values()]\n",
    "pearsonr(sycophancy_rates, utility_rates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "224bbee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every row of data: did the same solver get accuracy > 0 on the same problem in utility_data?\n",
    "# Uses the first utility row per (solver, problem) pair; pairs without a utility row count as\n",
    "# unsolved. The lookup table is built once instead of scanning utility_data for every row.\n",
    "first_utility_rows = utility_data.drop_duplicates(subset=['solver', 'problem'], keep='first')\n",
    "first_accuracy = dict(zip(zip(first_utility_rows.solver, first_utility_rows.problem), first_utility_rows.accuracy))\n",
    "is_solved = [bool(first_accuracy.get((solver, problem), 0) > 0) for solver, problem in zip(data.solver, data.problem)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48dd9e9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the per-row solved flags computed in the previous cell.\n",
    "data['is_solved'] = is_solved"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6f8fb94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean `incorrect` per solver on the 'generic' split.\n",
    "data.loc[data.split == 'generic'].groupby('solver')['incorrect'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "427f8972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per (is_solved, solver) on the 'generic' split: mean of `incorrect` (stored as avg_accuracy),\n",
    "# row count, and two binomial standard errors.\n",
    "generic_rows = data[data.split == 'generic']\n",
    "res = (\n",
    "    generic_rows.groupby(['is_solved', 'solver'])['incorrect']\n",
    "    .agg(avg_accuracy='mean', count='count')\n",
    "    .reset_index()\n",
    ")\n",
    "res['err'] = 2 * (res.avg_accuracy * (1 - res.avg_accuracy) / res['count']) ** 0.5\n",
    "res.sort_values(by='solver')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f641ef1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_scores = defaultdict(list)\n",
    "\n",
    "# res has one row per (is_solved, solver); collect [unsolved, solved] values per model.\n",
    "for model in model_mapper:\n",
    "    for partition in [False, True]:\n",
    "        data_model_partition = res[(res.solver == model) & (res.is_solved==partition)]\n",
    "        model_scores[model].append(data_model_partition['avg_accuracy'].mean())\n",
    "model_scores = dict(sorted(model_scores.items(), key=lambda x: x[1][1]))\n",
    "model_scores = {s: model_scores[s] for s in model_scores if s in ['GPT-5', 'DeepSeek-v3.1', 'Qwen3-235B-A22B', 'Grok 4', 'Gemini Pro 3 Preview']}\n",
    "# Own file name: \"type_comparison.pdf\" is already written by the final-answer vs proof-style figure.\n",
    "plot(model_scores, [\"Unsolved\", \"Solved\"], save_file=\"solved_comparison.pdf\", add_values=True, values_label_size=15, colors=(\"#cf3b3b\", \"#e7969c\", \"#649ddc\"), shift_factor=0.035, width=12, legend_fontsize=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25ff3ab2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean `incorrect` per solver on the 'generic' split, as a plain dict.\n",
    "data.loc[data.split == 'generic'].groupby('solver')['incorrect'].mean().to_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "089f5986",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _two_standard_errors(values):\n",
    "    \"\"\"2 * sqrt(p * (1 - p) / n), with p the mean of `values` and n their count.\"\"\"\n",
    "    rate = values.mean()\n",
    "    return 2 * (rate * (1 - rate) / len(values)) ** 0.5\n",
    "\n",
    "data[data.split == 'generic'].groupby('solver')['incorrect'].apply(_two_standard_errors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1e14636",
   "metadata": {},
   "outputs": [],
   "source": [
    "self_data = pd.read_csv('../data/results/sycophancy_more.csv')\n",
    "\n",
    "# Same normalisation as for `data`: strip \" (...)\" solver-name suffixes and add a per-row mean_score.\n",
    "self_data['solver'] = self_data['solver'].apply(lambda name: name.split(\" (\")[0] if \" (\" in name else name)\n",
    "self_data['mean_score'] = self_data['incorrect'].apply(lambda value: np.mean(value) if isinstance(value, list) else value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c07f8ef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import matplotlib.image as mpimg\n",
    "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n",
    "\n",
    "def plot_stacked(model_scores, column_mapper, save_file=None, label_size=18, shift_factor=0.05,\n",
    "         colors=(\"#cf3b3b\", \"#e7969c\", \"#f2b2b2\"), legend=True, logos=True,\n",
    "         add_values=False, values_label_size=18, width=18):\n",
    "    \"\"\"Draw one overlaid bar pair per model (second series hatched, on top of the first).\n",
    "\n",
    "    Args:\n",
    "        model_scores: dict mapping model id -> [first-series value, second-series value];\n",
    "            values are fractions and are rendered as percentages.\n",
    "        column_mapper: the two series names, in value order; used as DataFrame columns\n",
    "            and legend labels.\n",
    "        save_file: optional output path (saved with bbox_inches='tight').\n",
    "        label_size: font size of the tick labels.\n",
    "        shift_factor: per-character x shift of the model-name labels under the bars.\n",
    "        colors: bar colours; only the first two entries are used.\n",
    "        legend: whether to draw the legend.\n",
    "        logos: draw provider logos and model names below the bars; when False the\n",
    "            display names become plain x tick labels instead.\n",
    "        add_values: annotate bars with their value as a percentage.\n",
    "        values_label_size: font size of those annotations.\n",
    "        width: figure width in inches (the height is fixed at 6).\n",
    "\n",
    "    Reads the notebook-level `model_mapper` dict (model id -> display name); ids missing\n",
    "    from it are shown unchanged.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(width, 6), dpi=300)\n",
    "    first_col, second_col = list(column_mapper)[:2]\n",
    "\n",
    "    # --- prepare data: one row per model, model id in the 'Model' column ---\n",
    "    df = (\n",
    "        pd.DataFrame.from_dict(model_scores, orient='index', columns=column_mapper)\n",
    "        .rename_axis('Model')\n",
    "        .reset_index()\n",
    "    )\n",
    "    first_vals = df[first_col].tolist()\n",
    "    second_vals = df[second_col].tolist()\n",
    "    positions = list(range(len(df)))\n",
    "    bar_width = 0.7\n",
    "\n",
    "    # --- bars: the second series is overlaid on (not stacked above) the first ---\n",
    "    for i, (v1, v2) in enumerate(zip(first_vals, second_vals)):\n",
    "        ax.bar(i, v1, width=bar_width, color=colors[0], zorder=2)\n",
    "        ax.bar(i, v2, width=bar_width, color=colors[1], hatch='//', zorder=3, edgecolor='white')\n",
    "\n",
    "    # --- \"Not Applicable\" on bars whose value is exactly zero ---\n",
    "    for i, pair in enumerate(zip(first_vals, second_vals)):\n",
    "        for value in pair:\n",
    "            if value == 0:\n",
    "                ax.text(i, 0.03, \"Not Applicable\", ha=\"center\", va=\"bottom\", fontsize=16, rotation=90)\n",
    "\n",
    "    # --- percentage labels ---\n",
    "    # First series: black, just above its bar. Second series: white, drawn only when its\n",
    "    # top is at least 0.04 below the first bar's top.\n",
    "    if add_values:\n",
    "        for i, (v1, v2) in enumerate(zip(first_vals, second_vals)):\n",
    "            if v1 > 0:\n",
    "                ax.text(i, v1, f'{v1*100:.1f}%', ha='center', va='bottom',\n",
    "                        fontsize=values_label_size, color='black')\n",
    "            if v2 > 0 and not (v2 + 0.04 > v1):\n",
    "                ax.text(i, v2, f'{v2*100:.1f}%', ha='center', va='bottom',\n",
    "                        fontsize=values_label_size, color='white')\n",
    "\n",
    "    # --- x axis: one fixed tick per bar ---\n",
    "    # Fixing the tick locations first avoids matplotlib's FixedFormatter warning and keeps\n",
    "    # exactly one label per bar.\n",
    "    ax.set_xticks(positions)\n",
    "    if logos:\n",
    "        ax.set_xticklabels([''] * len(positions))\n",
    "    else:\n",
    "        ax.set_xticklabels([model_mapper.get(m, m) for m in df['Model']])\n",
    "\n",
    "    # model id -> (logo file, extra x shift of the name label, third entry unused here)\n",
    "    get_png = {\n",
    "        \"DeepSeek-R1-Qwen3-8B\": (\"deepseek.png\", 0.0, 0.00),\n",
    "        \"gemini-2.5-pro\": (\"gemini.png\", 0.0, 0.04),\n",
    "        \"o4-mini\": (\"openai.png\", 0.0, 0.04),\n",
    "        \"GPT OSS 120B\": (\"openai.png\", 0.0, 0.04),\n",
    "        \"GPT-5\": (\"openai.png\", 0.0, 0.04),\n",
    "        \"Qwen3-4B\": (\"qwen.png\", 0.04, 0),\n",
    "        \"Qwen3-235B-A22B\": (\"qwen.png\", 0.04, 0),\n",
    "        'DeepSeek-v3.1': (\"deepseek.png\", 0.0, 0.00),\n",
    "        'Grok 4': (\"xai.png\", 0.0, 0),\n",
    "        'Grok 4 Fast': (\"xai.png\", 0.0, 0),\n",
    "    }\n",
    "\n",
    "    if logos:\n",
    "        for i, model_name in enumerate(df['Model']):\n",
    "            logo_info = get_png.get(model_name)\n",
    "            if logo_info is not None:\n",
    "                try:\n",
    "                    arr_img = mpimg.imread(logo_info[0])\n",
    "                except FileNotFoundError:\n",
    "                    print(f\"Logo for {model_name} not found.\")\n",
    "                else:\n",
    "                    ab = AnnotationBbox(\n",
    "                        OffsetImage(arr_img, zoom=0.03),\n",
    "                        (i, -0.17),\n",
    "                        frameon=False,\n",
    "                        xycoords=('data', 'axes fraction'),\n",
    "                        box_alignment=(0.5, 0),\n",
    "                    )\n",
    "                    ax.add_artist(ab)\n",
    "            mapped_name = model_mapper.get(model_name, model_name)\n",
    "            name_shift = logo_info[1] if logo_info is not None else 0\n",
    "            # Name is right-aligned just below the axis: x in data units, y in axes fraction.\n",
    "            ax.text(\n",
    "                i + shift_factor * len(mapped_name) + name_shift,\n",
    "                -0.025,\n",
    "                mapped_name,\n",
    "                ha='right',\n",
    "                va='top',\n",
    "                fontsize=13,\n",
    "                transform=ax.get_xaxis_transform(),\n",
    "            )\n",
    "\n",
    "    ax.set_xlabel('')\n",
    "    ax.set_ylabel('')\n",
    "    ax.tick_params(axis='both', which='major', labelsize=label_size)\n",
    "    if legend:\n",
    "        # Bars carry no labels, so the legend is built from explicit proxy patches.\n",
    "        from matplotlib.patches import Patch\n",
    "        legend_elements = [\n",
    "            Patch(facecolor=colors[0], label=first_col),\n",
    "            Patch(facecolor=colors[1], hatch='//', edgecolor='white', label=second_col),\n",
    "        ]\n",
    "        leg = ax.legend(handles=legend_elements, loc='upper left', fontsize=18)\n",
    "        leg.get_frame().set_alpha(1)\n",
    "    # legend=False: no legend is created at all (calling ax.legend() on these unlabeled\n",
    "    # bars would only log a 'No artists with labels found' warning).\n",
    "\n",
    "    ax.tick_params(axis='y', which='both', left=False, right=False)\n",
    "    sns.despine(ax=ax, left=True, bottom=True)\n",
    "    ax.set_facecolor((0.97, 0.97, 0.97))\n",
    "\n",
    "    # --- y axis: 0 .. max + 0.1 with a labelled tick every 20% ---\n",
    "    max_value = df[list(column_mapper)].max().max()\n",
    "    ax.set_ylim(0, max_value + 0.1)\n",
    "    n_ticks = int((max_value + 0.1) * 5)\n",
    "    ax.set_yticks([k / 5 for k in range(n_ticks + 1)])\n",
    "    ax.set_yticklabels([f'{k * 20}%' for k in range(n_ticks + 1)])\n",
    "\n",
    "    if save_file:\n",
    "        fig.savefig(save_file, bbox_inches='tight')\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dd31594",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per model: [self-sycophancy rate, baseline rate from `data`]; models without rows in\n",
    "# self_data are skipped, a missing baseline leaves the list with one entry.\n",
    "self_model_scores = defaultdict(list)\n",
    "for model in model_mapper:\n",
    "    self_rows = self_data[self_data.solver == model]\n",
    "    if self_rows.empty:\n",
    "        continue\n",
    "    self_model_scores[model].append(self_rows['mean_score'].mean())\n",
    "    baseline_rows = data[data.solver == model]\n",
    "    if not baseline_rows.empty:\n",
    "        self_model_scores[model].append(baseline_rows['mean_score'].mean())\n",
    "\n",
    "self_model_scores = dict(sorted(self_model_scores.items(), key=lambda item: item[1]))\n",
    "plot_stacked(\n",
    "    self_model_scores,\n",
    "    [\"Self-Sycophancy\", \"Baseline\"],\n",
    "    save_file=\"self_sycophancy.pdf\",\n",
    "    add_values=True,\n",
    "    values_label_size=16,\n",
    "    width=8,\n",
    "    legend=True,\n",
    "    shift_factor=0.043,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb4f51f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results with the optimized (hint) prompt; solver names lose any \" (...)\" suffix.\n",
    "hint_data = pd.read_csv('../data/results/sycophancy_hint.csv')\n",
    "hint_data['solver'] = hint_data['solver'].apply(lambda name: name.split(\" (\")[0] if \" (\" in name else name)\n",
    "hint_data['mean_score'] = hint_data['incorrect'].apply(lambda v: v if not isinstance(v, list) else np.mean(v))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90c36d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per model: [baseline rate from `data`, rate with the optimized (hint) prompt].\n",
    "# The baseline is always stored first (NaN when `data` has no rows for the model), so a\n",
    "# missing baseline can never shift the hint score into the \"Baseline\" column.\n",
    "hint_model_scores = {}\n",
    "for model in model_mapper:\n",
    "    hint_rows = hint_data[hint_data.solver == model]\n",
    "    if hint_rows.empty:\n",
    "        continue\n",
    "    # mean() over an empty selection is NaN -> no visible baseline bar for this model\n",
    "    baseline = data.loc[data.solver == model, 'mean_score'].mean()\n",
    "    hint_model_scores[model] = [baseline, hint_rows['mean_score'].mean()]\n",
    "\n",
    "hint_model_scores = dict(sorted(hint_model_scores.items(), key=lambda x: x[1]))\n",
    "plot_stacked(hint_model_scores, [\"Baseline\", \"Optimized Prompt\"], save_file=\"prompt_engineering.pdf\", add_values=True, values_label_size=16,width=8, legend=True, shift_factor=0.043, colors=(\"#f1717a\", \"#a4c4e4\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecc0ec38",
   "metadata": {},
   "outputs": [],
   "source": [
    "data['solver'].unique()  # distinct solver names present in `data`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41147161",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.loc[data['solver'] == 'Qwen3-4B-FT', 'incorrect'].mean()  # mean `incorrect` of the fine-tuned Qwen3-4B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ae8f032",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded rates, one per label below (in order).\n",
    "# NOTE(review): the run these numbers come from is not recorded here -- TODO link it.\n",
    "agentic_results = {\n",
    "    'Qwen3-4B': [0.556, 0.409, 0.502, 0.431],\n",
    "    'Qwen3-235B-A22B': [0.651, 0.383, 0.565, 0.575],\n",
    "}\n",
    "\n",
    "# `plot` is the multi-series bar helper defined elsewhere in this notebook.\n",
    "plot(\n",
    "    agentic_results,\n",
    "    [\"Baseline\", \"Pass@4\", \"Best-of-4\", \"Iterative Agent\"],\n",
    "    save_file=\"agentic_results.pdf\",\n",
    "    add_values=True,\n",
    "    values_label_size=13,\n",
    "    width=8,\n",
    "    legend=True,\n",
    "    add_height=0.3,\n",
    "    legend_fontsize=13.5,\n",
    "    shift_factor=0.018,\n",
    "    colors=(\"#cf3b3b\", \"#a4c4e4\", \"#e7969c\", \"#f2b2b2\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad05148d",
   "metadata": {},
   "outputs": [],
   "source": [
    "confidence_csv = '../data/results/sycophancy_confidence.csv'\n",
    "confidence_data = pd.read_csv(confidence_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dc69c01",
   "metadata": {},
   "outputs": [],
   "source": [
    "confidence_data.groupby('solver')['incorrect'].mean()  # all rows, before NaN confidences are dropped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae3ed0ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep rows with a confidence value and strip any \" (...)\" suffix from solver names.\n",
    "# dropna(...).assign(...) builds a new frame, so no column is written on a filtered slice\n",
    "# (which is what triggers pandas' SettingWithCopyWarning).\n",
    "confidence_data = confidence_data.dropna(subset=['confidence']).assign(\n",
    "    solver=lambda frame: frame['solver'].apply(lambda x: x if \" (\" not in x else x.split(\" (\")[0])\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd5c8f18",
   "metadata": {},
   "outputs": [],
   "source": [
    "confidence_data.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88741418",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "from statsmodels.nonparametric.kde import KDEUnivariate\n",
    "import os\n",
    "def plot_histogram_roc(df, kde_only=False, fontsize_label=18, fontsize_tick=14, fontsize_legend=16, fontsize_model_name=18, colors=(\"#84a4c4\",\"#cf3b3b\",  \"#e7969c\", \"#f2b2b2\"), filename=None):\n",
    "    \"\"\"One column per solver: per-grade KDE of the confidence (top row) and, unless\n",
    "    `kde_only`, the ROC curve of confidence as a predictor of `incorrect == False` (bottom row).\n",
    "\n",
    "    Args:\n",
    "        df: rows with 'solver', 'true_grade', 'confidence' (the KDE is clipped to 0-100)\n",
    "            and, for the ROC row, an 'incorrect' flag (bool or 0/1).\n",
    "        kde_only: draw only the KDE row.\n",
    "        fontsize_label, fontsize_tick, fontsize_legend, fontsize_model_name: font sizes.\n",
    "        colors: KDE line colours, cycled over the grade types.\n",
    "        filename: optional output path (saved with bbox_inches='tight').\n",
    "\n",
    "    Uses the notebook-level `model_mapper` (model id -> display name) for the names under\n",
    "    each column; ids missing from it are shown unchanged.\n",
    "    \"\"\"\n",
    "    # model id -> logo tuple; only the file name (first entry) is used here\n",
    "    get_png = {\n",
    "        \"DeepSeek-R1-Qwen3-8B\": (\"deepseek.png\", 0.0, 0.00),\n",
    "        \"gemini-2.5-pro\": (\"gemini.png\", -0.02, 0.04),\n",
    "        \"o4-mini\": (\"openai.png\", 0.0, 0.04),\n",
    "        \"GPT OSS 120B\": (\"openai.png\", 0.0, 0.04),\n",
    "        \"GPT-5\": (\"openai.png\", 0.0, 0.04),\n",
    "        \"Qwen3-4B\": (\"qwen.png\", 0.04, 0),\n",
    "        \"Qwen3-235B-A22B\": (\"qwen.png\", 0.04, 0),\n",
    "        'DeepSeek-v3.1': (\"deepseek.png\", 0.0, 0.00),\n",
    "        'Grok 4': (\"xai.png\", 0.0, 0),\n",
    "        'Grok 4 Fast': (\"xai.png\", 0.0, 0)\n",
    "    }\n",
    "\n",
    "    # With 4+ distinct grades 'correct' is labelled 'Ideal', otherwise 'Non-Sycophant'.\n",
    "    grade_mapper = {\n",
    "        \"correct\": \"Ideal\" if len(df['true_grade'].unique()) >= 4 else \"Non-Sycophant\",\n",
    "        \"incorrect\": \"Sycophant\",\n",
    "        \"detected\": \"Detected\",\n",
    "        \"corrected\": \"Corrected\"\n",
    "    }\n",
    "\n",
    "    unique_solvers = sorted(df['solver'].unique())\n",
    "    N = len(unique_solvers)\n",
    "\n",
    "    # One column per solver; 1 row (KDE) or 2 rows (KDE + ROC).\n",
    "    rows = 1 if kde_only else 2\n",
    "    figsize = (6 * N, 7 if kde_only else 11)\n",
    "    fig, axes = plt.subplots(rows, N, figsize=figsize, squeeze=False)\n",
    "\n",
    "    # Computed once on the whole frame so every solver uses the same grade -> colour mapping.\n",
    "    grade_types = df['true_grade'].unique()\n",
    "\n",
    "    for i, solver_name in enumerate(unique_solvers):\n",
    "        solver_df = df[df['solver'] == solver_name].copy()\n",
    "\n",
    "        # --- row 1: confidence distribution per grade ---\n",
    "        ax1 = axes[0, i]\n",
    "        for j, grade in enumerate(grade_types):\n",
    "            grade_data = solver_df[solver_df['true_grade'] == grade]['confidence']\n",
    "            # KDEUnivariate raises on a zero bandwidth, which happens when the values are\n",
    "            # constant; with fewer than two distinct values there is no density to draw.\n",
    "            if grade_data.nunique() < 2:\n",
    "                continue\n",
    "            kde = KDEUnivariate(grade_data.values)\n",
    "            kde.fit(bw=\"normal_reference\", clip=(0, 100))\n",
    "            ax1.plot(kde.support, kde.density, label=str(grade_mapper.get(grade, grade)), color=colors[j % len(colors)])\n",
    "\n",
    "        ax1.set_xlim(0, 100)\n",
    "        ax1.tick_params(axis='x', labelsize=fontsize_tick)\n",
    "        ax1.tick_params(axis='y', labelsize=fontsize_tick)\n",
    "        leg = ax1.legend(fontsize=fontsize_legend, title='Confidence', title_fontsize=fontsize_legend)\n",
    "        leg.get_frame().set_alpha(1)\n",
    "        ax1.grid(False)\n",
    "        ax1.spines['top'].set_visible(False)\n",
    "        ax1.spines['right'].set_visible(False)\n",
    "        ax1.set_facecolor((0.97,0.97,0.97))\n",
    "        ax1.tick_params(labelleft=False)\n",
    "        # Only this KDE axis: without ax=, seaborn strips the left/bottom spines of every\n",
    "        # axis in the figure, including the ROC panels.\n",
    "        sns.despine(ax=ax1, left=True, bottom=True)\n",
    "\n",
    "        # --- row 2: ROC curve ---\n",
    "        if not kde_only:\n",
    "            ax2 = axes[1, i]\n",
    "            # Positive class: rows where `incorrect` is False. astype(bool) keeps `~` a logical\n",
    "            # NOT even if the column was read as 0/1 integers.\n",
    "            y_true = ~solver_df['incorrect'].astype(bool)\n",
    "            y_scores = solver_df['confidence']\n",
    "\n",
    "            fpr_initial, tpr_initial, _ = roc_curve(y_true, y_scores)\n",
    "            roc_auc_initial = auc(fpr_initial, tpr_initial)\n",
    "            # Flip the score when it ranks the classes the wrong way, so the AUC shown is >= 0.5.\n",
    "            if roc_auc_initial < 0.5:\n",
    "                y_scores = 100 - y_scores\n",
    "\n",
    "            fpr, tpr, _ = roc_curve(y_true, y_scores)\n",
    "            roc_auc_display = auc(fpr, tpr)\n",
    "\n",
    "            ax2.plot(fpr, tpr, color=\"#cf3b3b\", lw=3)\n",
    "            ax2.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--', alpha=0.7)\n",
    "\n",
    "            ax2.set_xlabel(f'FPR (AUC = {roc_auc_display:.2f})', fontsize=fontsize_label)\n",
    "            ax2.set_xlim(0.0, 1.0)\n",
    "            ax2.set_ylim(0.0, 1.0)\n",
    "            ax2.grid(False)\n",
    "            ax2.spines['top'].set_visible(False)\n",
    "            ax2.spines['right'].set_visible(False)\n",
    "\n",
    "            ticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]\n",
    "            tick_labels = [f'{int(t*100)}%' for t in ticks]\n",
    "            ax2.set_xticks(ticks)\n",
    "            ax2.set_xticklabels(tick_labels, fontsize=fontsize_tick)\n",
    "            ax2.set_yticks(ticks)\n",
    "            ax2.set_facecolor((0.97,0.97,0.97))\n",
    "            # Only the first column shows the TPR label and y tick labels.\n",
    "            if i == 0:\n",
    "                ax2.set_ylabel('TPR', fontsize=fontsize_label)\n",
    "                ax2.set_yticklabels(tick_labels, fontsize=fontsize_tick)\n",
    "            else:\n",
    "                ax2.tick_params(labelleft=False)\n",
    "\n",
    "        # --- model name and logo below the bottom axis of this column ---\n",
    "        if model_mapper:\n",
    "            logo_ax = axes[-1, i]\n",
    "            model_name = model_mapper.get(solver_name, solver_name)\n",
    "            logo_ax.text(0.5, -0.23, model_name, ha='center', va='top',\n",
    "                         fontsize=fontsize_model_name, transform=logo_ax.transAxes)\n",
    "            if solver_name in get_png:\n",
    "                logo_filename, _, _ = get_png[solver_name]\n",
    "                try:\n",
    "                    arr_img = mpimg.imread(logo_filename)\n",
    "                except FileNotFoundError:\n",
    "                    print(f\"Logo for {solver_name} not found.\")\n",
    "                else:\n",
    "                    imagebox = OffsetImage(arr_img, zoom=0.08)\n",
    "                    ab = AnnotationBbox(imagebox, (0.5, -0.4), frameon=False,\n",
    "                                        xycoords='axes fraction', box_alignment=(0.5, 1))\n",
    "                    logo_ax.add_artist(ab)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    if filename is not None:\n",
    "        # bbox_inches='tight' keeps the names/logos drawn below the axes inside the saved file.\n",
    "        plt.savefig(filename, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deb5fe2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary grades: everything except 'incorrect' counts as 'correct'.\n",
    "confidence_data_simplified = confidence_data.assign(\n",
    "    true_grade=np.where(confidence_data['true_grade'] == 'incorrect', 'incorrect', 'correct')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d8e8a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_histogram_roc(\n",
    "    confidence_data_simplified,\n",
    "    fontsize_model_name=36,\n",
    "    fontsize_label=20,\n",
    "    fontsize_tick=18,\n",
    "    fontsize_legend=18,\n",
    "    filename='black_box_distributions.pdf',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0188594",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_histogram_roc(\n",
    "    confidence_data,\n",
    "    kde_only=True,\n",
    "    fontsize_model_name=32,\n",
    "    fontsize_label=18,\n",
    "    fontsize_tick=16,\n",
    "    filename='black_box_distributions_full.pdf',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e3de885",
   "metadata": {},
   "outputs": [],
   "source": [
    "# .copy(): later column assignments then act on an independent frame rather than on a\n",
    "# filtered slice of confidence_data (which triggers pandas' SettingWithCopyWarning).\n",
    "confidence_data_major = confidence_data[confidence_data.solver.isin(['Qwen3-4B', 'Qwen3-235B-A22B'])].copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e43658a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the part of each problem id before the first '-' (assumed: the rest is a per-sample\n",
    "# suffix -- TODO confirm). assign() returns a new frame, so nothing is written on a slice.\n",
    "confidence_data_major = confidence_data_major.assign(\n",
    "    problem=lambda frame: frame['problem'].apply(lambda x: x.split('-')[0])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a8ecc9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples per (solver, problem) considered by the selection strategies below.\n",
    "n = 16\n",
    "\n",
    "\n",
    "def summarize(x):\n",
    "    \"\"\"Score confidence-based selection over the first `n` rows of one (solver, problem) group.\n",
    "\n",
    "    Returns a Series with the `incorrect` value of: the best case ('pass_at_n', the minimum\n",
    "    `incorrect` over those rows), the lowest-confidence row, the highest-confidence row and the\n",
    "    row whose confidence is closest to the median. Ties resolve to the first such row.\n",
    "    \"\"\"\n",
    "    head = x.iloc[:n]\n",
    "    conf = head['confidence']\n",
    "    lowest = head.loc[conf.idxmin()]\n",
    "    highest = head.loc[conf.idxmax()]\n",
    "    median_row = head.loc[(conf - conf.median()).abs().idxmin()]\n",
    "    return pd.Series({\n",
    "        'pass_at_n': head['incorrect'].min(),\n",
    "        'lowest': lowest['incorrect'],\n",
    "        'highest': highest['incorrect'],\n",
    "        'median': median_row['incorrect'],\n",
    "    })\n",
    "\n",
    "\n",
    "# Only the columns summarize() reads are passed to apply: letting apply see the grouping\n",
    "# columns is deprecated since pandas 2.2.\n",
    "confidence_scores = (\n",
    "    confidence_data_major\n",
    "    .groupby(['solver', 'problem'])[['confidence', 'incorrect']]\n",
    "    .apply(summarize)\n",
    "    .reset_index()\n",
    "    .groupby('solver')\n",
    "    .agg(\n",
    "        pass_at_n=('pass_at_n', 'mean'),\n",
    "        lowest=('lowest', 'mean'),\n",
    "        highest=('highest', 'mean'),\n",
    "        median=('median', 'mean'),\n",
    "    )\n",
    "    .reset_index()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24282c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per model: [baseline rate on `data`, pass@n, lowest-confidence@n, highest-confidence@n].\n",
    "confidence_model_scores = defaultdict(list)\n",
    "for _, score_row in confidence_scores.iterrows():\n",
    "    model = score_row['solver']\n",
    "    baseline = data.loc[data.solver == model, 'mean_score'].mean()\n",
    "    confidence_model_scores[model].append(baseline)\n",
    "    confidence_model_scores[model].extend(score_row[col] for col in ['pass_at_n', 'lowest', 'highest'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a5526f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot(\n",
    "    confidence_model_scores,\n",
    "    [\"Baseline\", f\"Pass@{n}\", f\"Low@{n}\", f\"High@{n}\"],\n",
    "    save_file=\"self_confidence.pdf\",\n",
    "    add_values=True,\n",
    "    values_label_size=12,\n",
    "    width=9,\n",
    "    legend_fontsize=16,\n",
    "    legend=True,\n",
    "    shift_factor=0.018,\n",
    "    colors=(\"#cf3b3b\", \"#a4c4e4\", \"#e7969c\", \"#f2b2b2\"),\n",
    "    add_height=0.35,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67215181",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits_csv = '../data/results/sycophancy_logits.csv'\n",
    "logits_data = pd.read_csv(logits_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ab981c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits_data['solver'] = logits_data['solver'].map(lambda name: name.split(\" (\")[0] if \" (\" in name else name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d541ccf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary grades: everything except 'incorrect' counts as 'correct'.\n",
    "logits_data_simplified = logits_data.assign(\n",
    "    true_grade=np.where(logits_data['true_grade'] == 'incorrect', 'incorrect', 'correct')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "add2ab30",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "from statsmodels.nonparametric.kde import KDEUnivariate\n",
    "import warnings\n",
    "\n",
    "# Suppress specific warnings from statsmodels KDE for cleaner output\n",
    "warnings.filterwarnings(\"ignore\", message=\"Symmetrically trimming bandwidth\")\n",
    "\n",
    "def plot_solver_metrics(df,\n",
    "                        metrics_to_plot,\n",
    "                        fontsize_label=18,\n",
    "                        fontsize_tick=14,\n",
    "                        fontsize_legend=16,\n",
    "                        fontsize_title=22,\n",
    "                        colors=(\"#84a4c4\",\"#cf3b3b\", \"#e7969c\", \"#f2b2b2\"),\n",
    "                        filename=None,\n",
    "                        titles=('Entropy', 'Tail@10', 'Confidence')):\n",
    "    \"\"\"For a SINGLE solver, plot per-grade KDEs (top row) and ROC curves (bottom row),\n",
    "    one column per metric.\n",
    "\n",
    "    Args:\n",
    "        df: rows for exactly one solver, with 'solver', 'true_grade', an 'incorrect' flag\n",
    "            (bool or 0/1) and one column per metric. Only 'correct' and 'incorrect'\n",
    "            grades are drawn in the KDE row.\n",
    "        metrics_to_plot: metric column names, one subplot column each.\n",
    "        fontsize_label, fontsize_tick, fontsize_legend: font sizes.\n",
    "        fontsize_title: currently unused; kept for backward compatibility.\n",
    "        colors: KDE line colours, indexed by grade.\n",
    "        filename: optional output path (saved with bbox_inches='tight').\n",
    "        titles: column titles matched to `metrics_to_plot` by position; a metric without a\n",
    "            title (or every metric, if None) gets its column name in title case.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `df` does not contain exactly one solver.\n",
    "\n",
    "    The ROC positive class is `incorrect == False`. A metric may point either way: when\n",
    "    its AUC is below 0.5 the score is negated, so the displayed AUC is always >= 0.5.\n",
    "    \"\"\"\n",
    "    unique_solvers = df['solver'].unique()\n",
    "    if len(unique_solvers) != 1:\n",
    "        raise ValueError(f\"This function is designed for a single solver, but found {len(unique_solvers)}: {unique_solvers}\")\n",
    "\n",
    "    N = len(metrics_to_plot)\n",
    "    grade_mapper = {\n",
    "        \"correct\": \"Non-Sycophant\",\n",
    "        \"incorrect\": \"Sycophant\",\n",
    "    }\n",
    "    grade_types = ['correct', 'incorrect']\n",
    "\n",
    "    # 2 rows (KDE, ROC) x N columns (one per metric)\n",
    "    fig, axes = plt.subplots(2, N, figsize=(6 * N, 9), squeeze=False)\n",
    "\n",
    "    for i, metric_col in enumerate(metrics_to_plot):\n",
    "        ax1 = axes[0, i]  # KDE\n",
    "        ax2 = axes[1, i]  # ROC\n",
    "        if titles is not None and i < len(titles):\n",
    "            metric_title = titles[i]\n",
    "        else:\n",
    "            metric_title = metric_col.replace('_', ' ').title()\n",
    "\n",
    "        # --- row 1: metric distribution per grade ---\n",
    "        for j, grade in enumerate(grade_types):\n",
    "            grade_data = df[df['true_grade'] == grade][metric_col]\n",
    "            # KDEUnivariate raises on a zero bandwidth (constant values); fewer than two\n",
    "            # distinct values leave nothing to estimate.\n",
    "            if grade_data.nunique() < 2:\n",
    "                continue\n",
    "            kde = KDEUnivariate(grade_data.values)\n",
    "            kde.fit(bw=\"normal_reference\")\n",
    "            ax1.plot(kde.support, kde.density, label=grade_mapper[grade], color=colors[j % len(colors)])\n",
    "\n",
    "        ax1.set_title(metric_title, fontsize=fontsize_label, pad=15)\n",
    "        ax1.set_xlim(df[metric_col].min(), df[metric_col].max())\n",
    "        ax1.tick_params(axis='x', labelsize=fontsize_tick)\n",
    "        ax1.tick_params(axis='y', labelsize=fontsize_tick, labelleft=False)\n",
    "        leg = ax1.legend(fontsize=fontsize_legend)\n",
    "        leg.get_frame().set_alpha(1)\n",
    "        ax1.grid(False)\n",
    "        sns.despine(left=True, bottom=True, ax=ax1)\n",
    "        ax1.set_facecolor((0.97, 0.97, 0.97))\n",
    "\n",
    "        # --- row 2: ROC curve (positive class: incorrect == False) ---\n",
    "        y_true = ~df['incorrect'].astype(bool)\n",
    "        y_scores = df[metric_col]\n",
    "        fpr_initial, tpr_initial, _ = roc_curve(y_true, y_scores)\n",
    "        if auc(fpr_initial, tpr_initial) < 0.5:\n",
    "            # Any strictly decreasing transform reverses the ranking; negation works for\n",
    "            # every metric scale, not just 0-100.\n",
    "            y_scores = -y_scores\n",
    "        fpr, tpr, _ = roc_curve(y_true, y_scores)\n",
    "        roc_auc_display = auc(fpr, tpr)\n",
    "\n",
    "        ax2.plot(fpr, tpr, color=\"#cf3b3b\", lw=3)\n",
    "        ax2.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--', alpha=0.7)\n",
    "\n",
    "        ax2.set_xlabel(f'FPR (AUC = {roc_auc_display:.2f})', fontsize=fontsize_label)\n",
    "        ax2.set_xlim(0.0, 1.0)\n",
    "        ax2.set_ylim(0.0, 1.0)\n",
    "        ticks = np.linspace(0, 1, 6)\n",
    "        tick_labels = [f'{round(t * 100)}%' for t in ticks]\n",
    "        ax2.set_xticks(ticks)\n",
    "        ax2.set_xticklabels(tick_labels, fontsize=fontsize_tick)\n",
    "        ax2.set_yticks(ticks)\n",
    "        ax2.grid(False)\n",
    "        sns.despine(ax=ax2)\n",
    "        ax2.set_facecolor((0.97, 0.97, 0.97))\n",
    "        # Only the first column shows the TPR label and y tick labels.\n",
    "        if i == 0:\n",
    "            ax2.set_ylabel('TPR', fontsize=fontsize_label)\n",
    "            ax2.set_yticklabels(tick_labels, fontsize=fontsize_tick)\n",
    "        else:\n",
    "            ax2.tick_params(labelleft=False)\n",
    "\n",
    "    # rect leaves the top 4% of the figure free.\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.96])\n",
    "    if filename:\n",
    "        plt.savefig(filename, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f49fb45",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_solver_metrics(\n",
    "    logits_data_simplified,\n",
    "    ['entropy', 'tail_confidence', 'full_confidence'],\n",
    "    filename='white_box_distribution.pdf',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b66d434",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded (sycophancy, utility) rates for the base model and its fine-tuned variant.\n",
    "# NOTE(review): the run these numbers come from is not recorded here -- TODO link it.\n",
    "model_scores = {\n",
    "    \"Qwen3-4B\": [0.556, 0.334],\n",
    "    \"BrokenMath-Qwen3-4B\": [0.51, 0.379],\n",
    "}\n",
    "plot(\n",
    "    model_scores,\n",
    "    [\"Sycophancy\", \"Utility\"],\n",
    "    save_file=\"finetuning.pdf\",\n",
    "    values_label_size=24,\n",
    "    width=10,\n",
    "    legend=True,\n",
    "    shift_factor=0.017,\n",
    "    add_values=True,\n",
    "    legend_fontsize=22,\n",
    "    add_height=0.3,\n",
    "    label_size=24,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14342adc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
