{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0cf98c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from matplotlib.patches import Patch\n",
    "\n",
    "\n",
    "from plotting_utils import get_history\n",
    "\n",
    "\n",
    "# Game database produced by skate-tournament/ (assumed: one row per game,\n",
    "# with 'task_setter' naming the model that set that game's task).\n",
    "df = pd.read_csv('latest-game-database.csv')\n",
    "\n",
    "# Views of the database with games set by the newest Sonnet models left out,\n",
    "# so those models' task-setting can be added back one stage at a time later on.\n",
    "df_without_sonnet4 = df[df['task_setter'] != 'claude-sonnet-4-20250514']\n",
    "df_without_sonnet_35_or_sonnet_4 = df[\n",
    "    ~df['task_setter'].isin(['claude-3-5-sonnet-20241022', 'claude-sonnet-4-20250514'])\n",
    "]\n",
    "\n",
    "# Per-player plot customisation, keyed by the raw player (model) name;\n",
    "# players missing from these dicts fall back to defaults in the plotting code.\n",
    "name_mapping = {}  # player_name -> label shown in plots\n",
    "colors = {}  # player_name -> colour used for that player's bars and legend patch\n",
    "\n",
    "# Placeholder list of players; not referenced by the cells shown here.\n",
    "players = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c75cdd06",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19143dec",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d755f395",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Rank the models in four stages (columns of mu_table / panels (a)-(d) below):\n",
    "#   Initial: only the four non-Sonnet models, on games without Sonnet-3.5/Sonnet-4 task setters.\n",
    "#   path1:   Sonnet-3.5 and Sonnet-4 join as players on those same games, so they are ranked\n",
    "#            without having set any questions themselves.\n",
    "#   path2:   games whose tasks were set by Sonnet-3.5 are added.\n",
    "#   path3:   games whose tasks were set by Sonnet-4 are added (full database).\n",
    "\n",
    "first_set = ['gpt-4o', 'gemini-2.0-flash', 'claude-3-haiku-20240307', 'claude-3-5-haiku-20241022']\n",
    "second_set = ['gpt-4o', 'claude-3-5-sonnet-20241022', 'gemini-2.0-flash', 'claude-3-haiku-20240307', 'claude-3-5-haiku-20241022', 'claude-sonnet-4-20250514']\n",
    "third_set = second_set\n",
    "fourth_set = second_set\n",
    "\n",
    "history = get_history('relative', first_set, df_without_sonnet_35_or_sonnet_4)\n",
    "second_history = get_history('relative', second_set, df_without_sonnet_35_or_sonnet_4)\n",
    "third_history = get_history('relative', third_set, df_without_sonnet4)\n",
    "fourth_history = get_history('relative', fourth_set, df)\n",
    "\n",
    "\n",
    "def get_final_mu(history_dict):\n",
    "    \"\"\"Return {player: last mu value rounded to 4 d.p.} for a get_history result.\n",
    "\n",
    "    Each value of history_dict is expected to hold a 'mu' sequence; players whose\n",
    "    sequence is empty map to NaN so they show up as missing in the table.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        player: round(vals['mu'][-1], 4) if len(vals['mu']) > 0 else np.nan\n",
    "        for player, vals in history_dict.items()\n",
    "    }\n",
    "\n",
    "\n",
    "# Rows are models, columns are stages. Deliberately not transposed: the plotting\n",
    "# cell below selects stages by column name (mu_table[col]).\n",
    "mu_table = pd.DataFrame({\n",
    "    'Initial': get_final_mu(history),\n",
    "    'path1': get_final_mu(second_history),\n",
    "    'path2': get_final_mu(third_history),\n",
    "    'path3': get_final_mu(fourth_history),\n",
    "})\n",
    "mu_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbc02f8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns  # NOTE(review): not used in this cell\n",
    "\n",
    "\n",
    "def get_final_sigma(history_dict):\n",
    "    \"\"\"Return {player: last sigma value} for a get_history result.\n",
    "\n",
    "    Counterpart of get_final_mu, but unrounded: each value of history_dict is\n",
    "    expected to hold a 'sigma' sequence, and empty sequences map to NaN.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        player: vals['sigma'][-1] if len(vals['sigma']) > 0 else np.nan\n",
    "        for player, vals in history_dict.items()\n",
    "    }\n",
    "\n",
    "\n",
    "# Final sigma per model and stage, laid out like mu_table (same column names).\n",
    "sigma_table = pd.DataFrame({\n",
    "    'Initial': get_final_sigma(history),\n",
    "    'path1': get_final_sigma(second_history),\n",
    "    'path2': get_final_sigma(third_history),\n",
    "    'path3': get_final_sigma(fourth_history),\n",
    "})\n",
    "\n",
    "fig, axes = plt.subplots(1, 4, figsize=(24, 6), sharey=True)\n",
    "\n",
    "# Short panel labels; earlier descriptive alternatives were\n",
    "# 'Initial', 'Weaker Models Score Stronger Models', '+Sonnet-3.5 to Game', '+Sonnet-4 to Game'.\n",
    "plot_titles = ['(a)', '(b)', '(c)', '(d)']\n",
    "mu_cols = ['Initial', 'path1', 'path2', 'path3']\n",
    "\n",
    "# Models in top-to-bottom display order. barh draws y=0 at the bottom, so the\n",
    "# list is reversed to put the first entry at the top of each panel.\n",
    "fixed_order = ['claude-sonnet-4-20250514', 'claude-3-5-sonnet-20241022', 'gemini-2.0-flash', 'gpt-4o', 'claude-3-5-haiku-20241022', 'claude-3-haiku-20240307']\n",
    "fixed_order = fixed_order[::-1]\n",
    "\n",
    "y_pos = np.arange(len(fixed_order))\n",
    "bar_colors = [colors.get(player, '#cccccc') for player in fixed_order]  # grey if no colour configured\n",
    "\n",
    "for ax, col, title in zip(axes, mu_cols, plot_titles):\n",
    "    if col not in mu_table.columns:\n",
    "        continue\n",
    "    # reindex puts bars in fixed_order; models absent from a stage get NaN (no visible bar).\n",
    "    ax.barh(\n",
    "        y=y_pos,\n",
    "        width=mu_table[col].reindex(fixed_order).values,\n",
    "        xerr=sigma_table[col].reindex(fixed_order).values,\n",
    "        color=bar_colors,\n",
    "        edgecolor='black',\n",
    "        linewidth=0.7,\n",
    "        alpha=0.9,\n",
    "        capsize=5,\n",
    "    )\n",
    "    ax.set_title(title, fontsize=30)\n",
    "    ax.set_xlabel(r'$\\mu$', fontsize=30)\n",
    "    # Same fixed x-range on every panel so bar lengths are comparable across stages.\n",
    "    ax.set_xlim(left=18, right=35)\n",
    "\n",
    "# With sharey=True all panels share one y-axis locator/formatter, so tick labels\n",
    "# must be set exactly once: calling set_yticklabels(['', ...]) on the other panels\n",
    "# would replace the shared formatter and blank the model names on panel (a) too.\n",
    "# plt.subplots(sharey=True) already hides the tick labels on the inner panels.\n",
    "axes[0].set_yticks(y_pos)\n",
    "axes[0].set_yticklabels([name_mapping.get(p, p) for p in fixed_order], fontsize=11)\n",
    "axes[0].set_ylabel('Model', fontsize=25)\n",
    "\n",
    "# Single legend below the panels, listing models top-to-bottom as in the bars.\n",
    "legend_models = fixed_order[::-1]\n",
    "ordered_handles = [\n",
    "    Patch(facecolor=colors.get(p, '#cccccc'), edgecolor='black', label=name_mapping.get(p, p))\n",
    "    for p in legend_models\n",
    "]\n",
    "fig.legend(handles=ordered_handles, loc='lower center', bbox_to_anchor=(0.5, -0.35), frameon=False, fontsize=25, title='', ncols=2)\n",
    "\n",
    "plt.tight_layout(rect=[0, 0, 0.95, 1])\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
