{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0cf98c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from plotting_utils import get_history\n",
    "\n",
    "\n",
    "# ---- Plot configuration: fill these in before running the plotting cell ----\n",
    "\n",
    "# Player names; the plotting cell uses them to select columns of the game database.\n",
    "players = []\n",
    "\n",
    "# player_name -> label shown in plot legends (names without an entry are shown unchanged).\n",
    "name_mapping = {}\n",
    "\n",
    "# player_name -> line color used in plots (names without an entry get a matplotlib default).\n",
    "colors = {}\n",
    "\n",
    "# ---- Data ----\n",
    "\n",
    "# Game database generated using skate-tournament/\n",
    "df = pd.read_csv('latest-game-database.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c75cdd06",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19143dec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_cumulative_average(ax, scores, label, color, final_step):\n",
    "    \"\"\"Plot the expanding mean of `scores` on `ax` and return its final value.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw on.\n",
    "        scores: pandas Series of scores, one per row, in row order.\n",
    "        label: Legend label for the curve.\n",
    "        color: Line color, or None to take the next color from the axes' cycle.\n",
    "        final_step: x position up to which a shorter curve is extended with a\n",
    "            dotted horizontal line at its last value.\n",
    "\n",
    "    Returns:\n",
    "        The last value of the expanding mean, or NaN if `scores` is empty.\n",
    "    \"\"\"\n",
    "    running_mean = scores.expanding().mean()\n",
    "    if running_mean.empty:\n",
    "        return np.nan\n",
    "\n",
    "    (curve,) = ax.plot(\n",
    "        range(len(running_mean)),\n",
    "        running_mean,\n",
    "        label=label,\n",
    "        color=color,\n",
    "        linewidth=3,\n",
    "    )\n",
    "    last_step = len(running_mean) - 1\n",
    "    last_value = running_mean.iloc[-1]\n",
    "    if last_step < final_step:\n",
    "        # Reuse the solid curve's resolved color: passing color=None again would\n",
    "        # pull the *next* color from the cycle, so the dotted tail would not\n",
    "        # match its curve and later curves would be shifted along the cycle.\n",
    "        ax.plot(\n",
    "            [last_step, final_step],\n",
    "            [last_value, last_value],\n",
    "            linestyle='dotted',\n",
    "            color=curve.get_color(),\n",
    "            linewidth=3,\n",
    "        )\n",
    "    return last_value\n",
    "\n",
    "\n",
    "def legend_rank(item):\n",
    "    \"\"\"Sort key for (label, handle) pairs: final value, missing/NaN ranked last.\"\"\"\n",
    "    value = last_vals_other.get(item[0], -np.inf)\n",
    "    return -np.inf if pd.isna(value) else value\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(2, 1, figsize=(14, 12), sharex=True)\n",
    "\n",
    "# Rows without a task setter cannot be attributed to any curve, so skip them.\n",
    "task_setters = df['task_setter'].dropna().unique()\n",
    "# NOTE(review): curves are drawn against each setter's row index (0..n-1) while\n",
    "# the dotted extension runs to the largest round_number; this assumes both are\n",
    "# on the same scale -- confirm against the database schema.\n",
    "max_rounds = df['round_number'].max()\n",
    "\n",
    "# Final cumulative average of the other models per setter; orders the legend.\n",
    "last_vals_other = {}\n",
    "\n",
    "for task_setter in task_setters:\n",
    "    task_setter_df = df[df['task_setter'] == task_setter]\n",
    "    color = colors.get(task_setter)\n",
    "\n",
    "    # First subplot: cumulative average score on own questions\n",
    "    plot_cumulative_average(\n",
    "        axes[0], task_setter_df[task_setter], task_setter, color, max_rounds\n",
    "    )\n",
    "\n",
    "    # Second subplot: cumulative average score of other models (excluding setter, oracle, potato)\n",
    "    answerers = [col for col in players if col not in (task_setter, 'oracle', 'potato')]\n",
    "    last_vals_other[task_setter] = plot_cumulative_average(\n",
    "        axes[1], task_setter_df[answerers].mean(axis=1), task_setter, color, max_rounds\n",
    "    )\n",
    "\n",
    "axes[0].set_ylabel('Cumulative Average p(correct)', fontsize=22)\n",
    "axes[0].set_title('Cumulative Average p(correct) on Own Questions', fontsize=22)\n",
    "axes[0].grid(True)\n",
    "\n",
    "axes[1].set_xlabel('Match Step', fontsize=22)\n",
    "axes[1].set_ylabel('Cumulative Average p(correct)', fontsize=22)\n",
    "axes[1].set_title('Cumulative Average p(correct) of Other Models (Excl. Setter) on Each Setter\\'s Questions', fontsize=22)\n",
    "axes[1].grid(True)\n",
    "\n",
    "# Sort legend by last value in the second plot (highest first)\n",
    "handles, labels = axes[1].get_legend_handles_labels()\n",
    "sorted_items = sorted(zip(labels, handles), key=legend_rank, reverse=True)\n",
    "if sorted_items:  # zip(*[]) cannot be unpacked, so skip the legend when nothing was plotted\n",
    "    sorted_labels, sorted_handles = zip(*sorted_items)\n",
    "    sorted_labels = [name_mapping.get(label, label) for label in sorted_labels]  # Map names for legend\n",
    "    axes[1].legend(sorted_handles, sorted_labels, fontsize=15, loc='lower center', bbox_to_anchor=(0.5, -0.3), ncol=3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
