{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0cf98c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from matplotlib.patches import Patch\n",
    "\n",
    "\n",
    "from plotting_utils import get_history\n",
    "\n",
    "\n",
    "# Load the game database generated using skate-tournament/\n",
    "df = pd.read_csv('latest-game-database.csv')\n",
    "\n",
    "# List of players\n",
    "players = []\n",
    "\n",
    "name_mapping = {} # key: player_name, value: name you want to see in plots.\n",
    "colors = {} # key: player_name, value: color you want to see in plots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c75cdd06",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19143dec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare a summary table of final mu values for each model in each history set\n",
    "first_set = ['gpt-4o', 'gemini-2.0-flash', 'claude-3-haiku-20240307', 'claude-3-5-haiku-20241022']\n",
    "second_set = ['gpt-4o', 'claude-3-5-sonnet-20241022', 'gemini-2.0-flash', 'claude-3-haiku-20240307', 'claude-3-5-haiku-20241022']\n",
    "third_set = ['gpt-4o', 'gemini-2.0-flash', 'claude-3-haiku-20240307', 'claude-3-5-haiku-20241022', 'claude-sonnet-4-20250514']\n",
    "fourth_set = ['gpt-4o', 'claude-3-5-sonnet-20241022', 'gemini-2.0-flash', 'claude-3-haiku-20240307', 'claude-3-5-haiku-20241022', 'claude-sonnet-4-20250514']\n",
    "\n",
    "\n",
    "history = get_history('relative', first_set, df)\n",
    "second_history = get_history('relative', second_set, df)\n",
    "third_history = get_history('relative', third_set, df)\n",
    "fourth_history = get_history('relative', fourth_set, df)\n",
    "\n",
    "\n",
    "def get_final_mu(history_dict):\n",
    "    return {player: round(vals['mu'][-1],4) if len(vals['mu']) > 0 else np.nan for player, vals in history_dict.items()}\n",
    "\n",
    "mu_table = pd.DataFrame({\n",
    "    'Initial': get_final_mu(history),\n",
    "    'path1': get_final_mu(second_history),\n",
    "    'path2': get_final_mu(third_history),\n",
    "    'path3': get_final_mu(fourth_history),\n",
    "})\n",
    "\n",
    "mu_table = mu_table  # Transpose for easier reading (sets as rows, models as columns)\n",
    "mu_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d755f395",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Define a color for each player, consistent across plots\n",
    "\n",
    "\n",
    "def get_final_sigma(history_dict):\n",
    "    return {player: vals['sigma'][-1] if len(vals['sigma']) > 0 else np.nan for player, vals in history_dict.items()}\n",
    "\n",
    "# Prepare sigma tables matching mu_table\n",
    "sigma_table = pd.DataFrame({\n",
    "    'Initial': get_final_sigma(history),\n",
    "    'path1': get_final_sigma(second_history),\n",
    "    'path2': get_final_sigma(third_history),\n",
    "    'path3': get_final_sigma(fourth_history),\n",
    "})\n",
    "\n",
    "fig = plt.figure(figsize=(18, 8))\n",
    "gs = fig.add_gridspec(3, 3, width_ratios=[1, 1, 1], height_ratios=[1, 1, 1], wspace=0.5, hspace=0.5)\n",
    "\n",
    "# Axes for each plot\n",
    "ax_init = fig.add_subplot(gs[1, 0])      # Initial in the middle left\n",
    "ax_path1 = fig.add_subplot(gs[0, 1])     # Path1 up right\n",
    "ax_path2 = fig.add_subplot(gs[2, 1])     # Path2 down right\n",
    "ax_path3 = fig.add_subplot(gs[:, 2])     # Path3 full right\n",
    "\n",
    "axes = [ax_init, ax_path1, ax_path2, ax_path3]\n",
    "plot_titles = ['Initial', '+ Sonnet 3-5', '+ Sonnet-4', 'All Models']\n",
    "mu_cols = ['Initial', 'path1', 'path2', 'path3']\n",
    "\n",
    "# Plot each subplot with same bar height\n",
    "for ax, col, title in zip(axes, mu_cols, plot_titles):\n",
    "    if col in mu_table.columns:\n",
    "        sorted_mu = mu_table[col].sort_values(ascending=True)\n",
    "        sorted_sigma = sigma_table[col].reindex(sorted_mu.index)\n",
    "        bar_colors = [colors.get(player, '#cccccc') for player in sorted_mu.index]\n",
    "\n",
    "        ax.barh(\n",
    "            y=sorted_mu.index,\n",
    "            width=sorted_mu.values,\n",
    "            xerr=sorted_sigma.values,\n",
    "            color=bar_colors,\n",
    "            edgecolor='black',\n",
    "            linewidth=0.7,\n",
    "            alpha=0.9,\n",
    "            capsize=5\n",
    "        )\n",
    "        ax.set_title(title, fontsize=24)\n",
    "        ax.set_xlabel('Skill (Mu)', fontsize=20)\n",
    "        if ax == ax_init:\n",
    "            ax.set_ylabel('Model', fontsize=20)\n",
    "        else:\n",
    "            ax.set_ylabel('')\n",
    "        # Remove bar labels\n",
    "        ax.set_xlim(left=18, right=35)  # Set x-axis limits for all plots\n",
    "        ax.tick_params(axis='y', labelsize=11)\n",
    "        ax.set_yticklabels([])\n",
    "\n",
    "# Create a single legend outside the plots\n",
    "legend_handles = [Patch(facecolor=colors[p], edgecolor='black', label=p) for p in mu_table.index]\n",
    "\n",
    "# Use the sorted order of the final chart (path3)\n",
    "final_order = mu_table['path3'].sort_values(ascending=False).index\n",
    "legend_handles = [Patch(facecolor=colors.get(p, '#cccccc'), edgecolor='black', label=p) for p in final_order]\n",
    "\n",
    "# Order legend handles by the order in the final plot (path3 column, descending)\n",
    "final_order = mu_table['path3'].sort_values(ascending=False).index\n",
    "ordered_handles = [Patch(facecolor=colors.get(p, '#cccccc'), edgecolor='black', label=p) for p in final_order]\n",
    "\n",
    "fig.legend(handles=ordered_handles, labels=list(final_order), loc='center left', bbox_to_anchor=(0.92, 0.5), frameon=False, fontsize=20, title='')\n",
    "\n",
    "\n",
    "plt.tight_layout(rect=[0, 0, 0.9, 1])\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
