{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f25e98b1",
   "metadata": {},
   "source": [
    "### Comparing All Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "090df8be",
   "metadata": {},
   "outputs": [],
   "source": [
    "import trueskill\n",
    "import json\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from IPython.display import display\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3a256a9",
   "metadata": {},
   "source": [
    "#### TODO\n",
    "- To reproduce this experiment, the user must run the augmentation-strategy-experimenting notebooks for all augmentation strategies.\n",
    "- They can then fill in the paths to the resulting game JSON files in the `game_data` cell below and reproduce Figures 2 and 11."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e26c8e54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths to the saved game JSON files: six games per augmentation strategy.\n",
    "# TODO: fill in every path with the output of the matching augmentation-strategy-experimenting\n",
    "# notebook; loading fails on any path left empty.\n",
    "game_data = {\n",
    "    # No info\n",
    "    'no_info': {\n",
    "        'game_1': '',\n",
    "        'game_2': '',\n",
    "        'game_3': '',\n",
    "        'game_4': '',\n",
    "        'game_5': '',\n",
    "        'game_6': '',\n",
    "    },\n",
    "    # Historical tasks\n",
    "    'historical_task': {\n",
    "        'game_1': '',\n",
    "        'game_2': '',\n",
    "        'game_3': '',\n",
    "        'game_4': '',\n",
    "        'game_5': '',\n",
    "        'game_6': '',\n",
    "    },\n",
    "    # Historical performance\n",
    "    'historical_perf': {\n",
    "        'game_1': '',\n",
    "        'game_2': '',\n",
    "        'game_3': '',\n",
    "        'game_4': '',\n",
    "        'game_5': '',\n",
    "        'game_6': '',\n",
    "    },\n",
    "    # Full personal context\n",
    "    'full_personal_context': {\n",
    "        'game_1': '',\n",
    "        'game_2': '',\n",
    "        'game_3': '',\n",
    "        'game_4': '',\n",
    "        'game_5': '',\n",
    "        'game_6': '',\n",
    "    },\n",
    "    # Full context\n",
    "    'full_context': {\n",
    "        'game_1': '',\n",
    "        'game_2': '',\n",
    "        'game_3': '',\n",
    "        'game_4': '',\n",
    "        'game_5': '',\n",
    "        'game_6': '',\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "635fd189",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot colour for each player tag. The short aliases added at the end ('gemini', 'sonnet3-5',\n",
    "# 'haiku3') are the group names returned by get_base_model_name in the figure cells; each one\n",
    "# reuses the colour of a full model tag.\n",
    "colors = dict([\n",
    "    ('gpt-4o', '#ab68ff'),                      # OpenAI violet\n",
    "    ('claude-3-5-sonnet-20241022', '#e49b00'),  # deep amber\n",
    "    ('gemini-2.0-flash', '#1a73e8'),            # Google blue\n",
    "    ('claude-3-haiku-20240307', '#d45d00'),     # burnt orange\n",
    "    ('claude-3-5-haiku-20241022', '#00796b'),   # teal green\n",
    "    ('claude-sonnet-4-20250514', '#8e24aa'),    # purple, distinct from the OpenAI violet\n",
    "    ('oracle', '#ff0000'),                      # red\n",
    "    ('potato', '#8b5e3c'),                      # earthy brown\n",
    "])\n",
    "colors.update({\n",
    "    'gemini': colors['gemini-2.0-flash'],\n",
    "    'sonnet3-5': colors['claude-3-5-sonnet-20241022'],\n",
    "    'haiku3': colors['claude-3-haiku-20240307'],\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef245930",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_game_data(path):\n",
    "    \"\"\"Load one saved game JSON file and return the parsed object.\n",
    "\n",
    "    Args:\n",
    "        path: Path of the game file.\n",
    "    \"\"\"\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "def _resolve_game(strategy, game, src):\n",
    "    \"\"\"Return the game for one game_data entry, loading it from disk if it is still a path.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the path was left empty in the game_data cell.\n",
    "    \"\"\"\n",
    "    if isinstance(src, dict):\n",
    "        # Already loaded by an earlier run of this cell; keeps the cell safe to re-run.\n",
    "        return src\n",
    "    if not src:\n",
    "        raise ValueError(f'No path configured for {strategy}/{game}; fill it in the game_data cell above.')\n",
    "    return load_game_data(src)\n",
    "\n",
    "\n",
    "# Replace every path in game_data with the loaded game (a new dict; nothing is mutated in place).\n",
    "game_data = {\n",
    "    strategy: {game: _resolve_game(strategy, game, src) for game, src in games.items()}\n",
    "    for strategy, games in game_data.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a02f7b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_history(SCORING, players, games, n_rounds=50, success_threshold=0.55, margin=0.05):\n",
    "    \"\"\"Replay the games round by round and record every player's TrueSkill trajectory.\n",
    "\n",
    "    Each question in a round is one pairwise comparison between the game's two players,\n",
    "    judged on the task setter's scores and rated as a TrueSkill win/loss/draw\n",
    "    (the size of the score difference is not used).\n",
    "\n",
    "    Args:\n",
    "        SCORING: 'absolute' - a player wins when only they reach `success_threshold`;\n",
    "            'relative' - a player wins when they beat the other by more than `margin`.\n",
    "            Every other outcome is a draw.\n",
    "        players: Player tags to track; must include every tag appearing in `games`.\n",
    "        games: Loaded game dicts; g['player_data']['player_tags'] holds the two players\n",
    "            and g['round_data'] is keyed by the round number as a string.\n",
    "        n_rounds: Rounds '0' .. str(n_rounds - 1) are replayed; later rounds are ignored.\n",
    "        success_threshold: Score at which a question counts as solved ('absolute' only).\n",
    "        margin: Score difference within which a comparison is a draw ('relative' only).\n",
    "\n",
    "    Returns:\n",
    "        {player: {'mu': [...], 'sigma': [...]}} with one entry appended per rating update\n",
    "        involving that player.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If SCORING is neither 'absolute' nor 'relative'.\n",
    "    \"\"\"\n",
    "    if SCORING not in ('absolute', 'relative'):\n",
    "        # Fail fast: any other value would leave the new ratings unset below.\n",
    "        raise ValueError(f'SCORING must be absolute or relative, got {SCORING!r}')\n",
    "\n",
    "    env = trueskill.TrueSkill(mu=25, sigma=25/3, beta=25/6, tau=0, draw_probability=0.1)\n",
    "\n",
    "    players = set(players)\n",
    "    ratings = {player: env.Rating() for player in players}\n",
    "    # mu and sigma of each player after every update that involves them\n",
    "    history = {player: {'mu': [], 'sigma': []} for player in players}\n",
    "\n",
    "    for round_num in range(n_rounds):\n",
    "        round_key = str(round_num)\n",
    "        for g in games:\n",
    "            p1, p2 = g['player_data']['player_tags']\n",
    "            if round_key not in g['round_data']:\n",
    "                continue\n",
    "            round_info = g['round_data'][round_key]\n",
    "            scores = round_info['scores']\n",
    "            setters = [q['task_setter_player_tag'] for q in round_info['questions']]\n",
    "            for ts in setters:\n",
    "                # scores[ts][player]: presumably the player's score on ts's question\n",
    "                data = scores[ts]\n",
    "                score1, score2 = data[p1], data[p2]\n",
    "                r1, r2 = ratings[p1], ratings[p2]\n",
    "\n",
    "                if SCORING == 'absolute':\n",
    "                    p1_wins = score1 >= success_threshold and score2 < success_threshold\n",
    "                    p2_wins = score1 < success_threshold and score2 >= success_threshold\n",
    "                else:  # 'relative'\n",
    "                    p1_wins = score1 - score2 > margin\n",
    "                    p2_wins = score1 - score2 < -margin\n",
    "\n",
    "                if p1_wins:\n",
    "                    new_r1, new_r2 = env.rate_1vs1(r1, r2)\n",
    "                elif p2_wins:\n",
    "                    new_r2, new_r1 = env.rate_1vs1(r2, r1)\n",
    "                else:\n",
    "                    new_r1, new_r2 = env.rate_1vs1(r1, r2, drawn=True)\n",
    "\n",
    "                ratings[p1] = new_r1\n",
    "                ratings[p2] = new_r2\n",
    "                for p in (p1, p2):\n",
    "                    history[p]['mu'].append(ratings[p].mu)\n",
    "                    history[p]['sigma'].append(ratings[p].sigma)\n",
    "\n",
    "    return history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25861013",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a56c7d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cf76807",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43462ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default ranking style, and every player tag that appears in any loaded game.\n",
    "# (The Fig. 2 cell recomputes the players per augmentation strategy and loops over both styles.)\n",
    "ranking_style = 'relative'\n",
    "players = sorted({\n",
    "    tag\n",
    "    for games_by_id in game_data.values()\n",
    "    for g in games_by_id.values()\n",
    "    for tag in g['player_data']['player_tags']\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a99df728",
   "metadata": {},
   "source": [
    "## Fig.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cab5177d",
   "metadata": {},
   "outputs": [],
   "source": [
    "aug_strats = ['no_info', 'historical_task', 'historical_perf', 'full_personal_context', 'full_context']\n",
    "ranking_styles = ['relative', 'absolute']\n",
    "MU_WINDOW = 100  # each player's score is the mean of their last MU_WINDOW mu values\n",
    "\n",
    "\n",
    "def summarize_history(history, players, window=MU_WINDOW):\n",
    "    \"\"\"Return {player: (mean of the last `window` mu values, final sigma)}.\n",
    "\n",
    "    Players with no rating updates in `history` get (nan, nan) instead of raising.\n",
    "    \"\"\"\n",
    "    stats = {}\n",
    "    for player in players:\n",
    "        mus = history[player]['mu']\n",
    "        sigmas = history[player]['sigma']\n",
    "        stats[player] = (np.mean(mus[-window:]), sigmas[-1]) if mus else (np.nan, np.nan)\n",
    "    return stats\n",
    "\n",
    "\n",
    "# results[aug_strat][ranking_style] = {player: (avg_mu, sigma)}\n",
    "results = {}\n",
    "for a_s in aug_strats:\n",
    "    games = list(game_data[a_s].values())\n",
    "    players = sorted({p for g in games for p in g['player_data']['player_tags']})\n",
    "    results[a_s] = {rs: summarize_history(get_history(rs, players, games), players) for rs in ranking_styles}\n",
    "\n",
    "# All unique players across all strategies, for consistent ordering\n",
    "all_players = sorted({p for a_s in aug_strats for rs in ranking_styles for p in results[a_s][rs]})\n",
    "\n",
    "\n",
    "def get_base_model_name(player):\n",
    "    \"\"\"Map a player tag to the model group it is plotted under (unknown tags map to themselves).\n",
    "\n",
    "    NOTE(review): every tag containing 'sonnet' (including claude-sonnet-4) is grouped as\n",
    "    'sonnet3-5' and every 'haiku' tag (including claude-3-5-haiku) as 'haiku3' - confirm intended.\n",
    "    \"\"\"\n",
    "    if 'gpt-4o' in player:\n",
    "        return 'gpt-4o'\n",
    "    elif 'gemini' in player:\n",
    "        return 'gemini'\n",
    "    elif 'sonnet' in player:\n",
    "        return 'sonnet3-5'\n",
    "    elif 'haiku' in player:\n",
    "        return 'haiku3'\n",
    "    else:\n",
    "        return player\n",
    "\n",
    "\n",
    "# Model groups in plotting order\n",
    "base_models = ['gpt-4o', 'gemini', 'sonnet3-5', 'haiku3']\n",
    "\n",
    "as_mapping = {'no_info': 'no info',\n",
    "              'historical_task': 'historical task',\n",
    "              'historical_perf': 'historical perf',\n",
    "              'full_personal_context': 'full personal context',\n",
    "              'full_context': 'full context'}\n",
    "\n",
    "n_rows, n_cols = len(ranking_styles), len(aug_strats)\n",
    "# squeeze=False keeps axes 2-D even for a single row or column\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), sharey=True, squeeze=False)\n",
    "\n",
    "for i, rs in enumerate(ranking_styles):\n",
    "    for j, a_s in enumerate(aug_strats):\n",
    "        ax = axes[i, j]\n",
    "        model_mus = {bm: [] for bm in base_models}\n",
    "        model_sigmas = {bm: [] for bm in base_models}\n",
    "        for player, (mu, sigma) in results[a_s][rs].items():\n",
    "            bm = get_base_model_name(player)\n",
    "            if bm not in model_mus:\n",
    "                # Tags outside base_models (e.g. 'oracle' or 'potato', if present) are not\n",
    "                # plotted and would otherwise raise a KeyError here.\n",
    "                continue\n",
    "            model_mus[bm].append(mu)\n",
    "            model_sigmas[bm].append(sigma)\n",
    "        avg_mus = [np.nanmean(model_mus[bm]) if model_mus[bm] else np.nan for bm in base_models]\n",
    "        avg_sigmas = [np.nanmean(model_sigmas[bm]) if model_sigmas[bm] else np.nan for bm in base_models]\n",
    "\n",
    "        bar_colors = [colors.get(bm, '#000000') for bm in base_models]  # black if a group has no colour\n",
    "        ax.bar(range(len(base_models)), avg_mus, yerr=avg_sigmas, color=bar_colors, capsize=5)\n",
    "        ax.set_xticks(range(len(base_models)))\n",
    "        ax.set_ylim(bottom=18)  # axis starts at 18, not 0; ratings start at mu=25 in get_history\n",
    "        ax.grid(axis='y', linestyle='--', alpha=0.5)\n",
    "        if j == 0:\n",
    "            ax.set_ylabel(f'{rs.capitalize()} TrueSkill', fontsize=20)\n",
    "        if i == n_rows - 1:\n",
    "            ax.set_xlabel(as_mapping[a_s], fontsize=25)\n",
    "\n",
    "handles = [mpatches.Patch(color=colors[bm], label=bm) for bm in base_models]\n",
    "fig.legend(handles=handles, loc='lower center', ncol=len(base_models), bbox_to_anchor=(0.5, -0.1), fontsize=30)\n",
    "\n",
    "plt.tight_layout(rect=[0, 0.03, 1, 0.97])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "050e1f08",
   "metadata": {},
   "source": [
    "### Fig. 11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b182077",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uses aug_strats, base_models and get_base_model_name from the Fig. 2 cell above.\n",
    "MAX_QUESTIONS = 50  # mean curves are extended (dotted) up to this question index\n",
    "N_COLS = 3\n",
    "\n",
    "\n",
    "def _rounds_in_order(round_data):\n",
    "    \"\"\"Return the rounds of one game sorted by numeric round key.\n",
    "\n",
    "    Guards against files whose round keys are stored lexicographically ('0', '1', '10', '2', ...),\n",
    "    which would scramble the cumulative averages. Falls back to stored order if a key is not an integer.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        return [round_data[k] for k in sorted(round_data, key=int)]\n",
    "    except ValueError:\n",
    "        return list(round_data.values())\n",
    "\n",
    "\n",
    "def own_question_scores(games_by_id):\n",
    "    \"\"\"Return {player: [one list per game of the player's scores on the questions they set]}.\n",
    "\n",
    "    The score for a round is round['scores'][player][player], taken from every round in which the\n",
    "    player is a task setter; games where the player set no question add no list.\n",
    "    \"\"\"\n",
    "    per_player = {p: [] for g in games_by_id.values() for p in g['player_data']['player_tags']}\n",
    "    for g in games_by_id.values():\n",
    "        rounds = _rounds_in_order(g['round_data'])\n",
    "        for player in g['player_data']['player_tags']:\n",
    "            own = [rnd['scores'][player][player]\n",
    "                   for rnd in rounds\n",
    "                   if player in [q['task_setter_player_tag'] for q in rnd['questions']]]\n",
    "            if own:\n",
    "                per_player[player].append(own)\n",
    "    return per_player\n",
    "\n",
    "\n",
    "cum_av_score_data_with_strats = {a_s: own_question_scores(game_data[a_s]) for a_s in aug_strats}\n",
    "\n",
    "n_rows = (len(aug_strats) + N_COLS - 1) // N_COLS\n",
    "fig, axes = plt.subplots(n_rows, N_COLS, figsize=(6 * N_COLS, 5 * n_rows), sharey=True, squeeze=False)\n",
    "axes = axes.flatten()\n",
    "\n",
    "for idx, a_s in enumerate(aug_strats):\n",
    "    ax = axes[idx]\n",
    "    for model, score_lists in cum_av_score_data_with_strats[a_s].items():\n",
    "        base_model = get_base_model_name(model)\n",
    "        color = colors.get(base_model)\n",
    "        valid_score_lists = [scores for scores in score_lists if scores]\n",
    "        if not valid_score_lists:\n",
    "            continue\n",
    "        # Truncate every game to the shortest one so the per-game curves can be averaged point-wise.\n",
    "        min_len = min(len(scores) for scores in valid_score_lists)\n",
    "        cum_avgs = np.array([np.cumsum(scores[:min_len]) / np.arange(1, min_len + 1)\n",
    "                             for scores in valid_score_lists])\n",
    "        mean_curve = np.nanmean(cum_avgs, axis=0)\n",
    "        std_curve = np.nanstd(cum_avgs, axis=0)\n",
    "\n",
    "        x_main = np.arange(min_len)\n",
    "        ax.plot(x_main, mean_curve, label=base_model, color=color)\n",
    "        ax.fill_between(x_main, mean_curve - std_curve, mean_curve + std_curve, alpha=0.2, color=color)\n",
    "\n",
    "        # Carry the final mean forward as a dotted line up to MAX_QUESTIONS.\n",
    "        if min_len < MAX_QUESTIONS:\n",
    "            x_dotted = np.arange(min_len, MAX_QUESTIONS)\n",
    "            ax.plot(x_dotted, np.full(len(x_dotted), mean_curve[-1]), linestyle='dotted', color=color)\n",
    "\n",
    "    ax.set_xlabel('Question Index', fontsize=20)\n",
    "    ax.set_title(f'{a_s}', fontsize=20)\n",
    "    ax.grid(True)\n",
    "    if idx % N_COLS == 0:\n",
    "        ax.set_ylabel('Cumulative Average Score', fontsize=20)\n",
    "    if idx == 0:\n",
    "        handles = [mpatches.Patch(color=colors[bm], label=bm) for bm in base_models]\n",
    "        ax.legend(handles=handles, fontsize=9)\n",
    "\n",
    "# Hide subplots left over in the grid\n",
    "for ax in axes[len(aug_strats):]:\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.suptitle('Cumulative Average Score on Own Questions (All Aug Strats)', fontsize=24)\n",
    "plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
