{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0cf98c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "# Load the game database generated using skate-tournament/\n",
    "df = pd.read_csv('latest-game-database.csv')\n",
    "filtered_df = df[df.apply(lambda row: row[row['task_setter']] > 0.55, axis=1)]\n",
    "\n",
    "# List of players\n",
    "players = []\n",
    "\n",
    "name_mapping = {} # key: player_name, value: name you want to see in plots.\n",
    "colors = {} # key: player_name, value: color you want to see in plots."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c75cdd06",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "# Get the list of model columns (excluding 'task_setter')\n",
    "model_columns = ['gpt-4o',\n",
    " 'claude-3-5-sonnet-20241022',\n",
    " 'gemini-2.0-flash',\n",
    " 'claude-3-haiku-20240307',\n",
    " 'claude-3-5-haiku-20241022',\n",
    " 'claude-sonnet-4-20250514',]\n",
    "\n",
    "# Initialize an empty DataFrame to store the results\n",
    "results_matrix = pd.DataFrame(index=model_columns, columns=model_columns)\n",
    "\n",
    "# Calculate the metric for each model and task_setter combination\n",
    "for task_setter_model in model_columns:\n",
    "    # Filter the DataFrame for questions set by the current task_setter_model\n",
    "    df_filtered = df[df['task_setter'] == task_setter_model]\n",
    "\n",
    "    for current_model in model_columns:\n",
    "        # Average score of the current_model on questions set by task_setter_model\n",
    "        avg_score_current_model = df_filtered[current_model].mean()\n",
    "\n",
    "        # Average score of all other models on questions set by task_setter_model\n",
    "        other_models = [m for m in model_columns if m != current_model]\n",
    "        if other_models:\n",
    "            avg_score_other_models = df_filtered[other_models].mean().mean()\n",
    "        else:\n",
    "            avg_score_other_models = 0 # No other models to compare\n",
    "\n",
    "        # Compute the difference\n",
    "        difference = avg_score_current_model - avg_score_other_models\n",
    "        results_matrix.loc[current_model, task_setter_model] = difference\n",
    "\n",
    "# Convert the results_matrix values to float for heatmap\n",
    "results_matrix = results_matrix.astype(float)\n",
    "\n",
    "\n",
    "\n",
    "# Initialize an empty DataFrame to store the results\n",
    "results_matrix = pd.DataFrame(index=model_columns, columns=model_columns)\n",
    "\n",
    "# Calculate the metric for each model and task_setter combination\n",
    "for task_setter_model in model_columns:\n",
    "    # Filter the DataFrame for questions set by the current task_setter_model\n",
    "    df_filtered = filtered_df[filtered_df['task_setter'] == task_setter_model]\n",
    "\n",
    "    for current_model in model_columns:\n",
    "        # Average score of the current_model on questions set by task_setter_model\n",
    "        avg_score_current_model = df_filtered[current_model].mean()\n",
    "\n",
    "        # Average score of all other models on questions set by task_setter_model\n",
    "        other_models = [m for m in model_columns if m != current_model]\n",
    "        if other_models:\n",
    "            avg_score_other_models = df_filtered[other_models].mean().mean()\n",
    "        else:\n",
    "            avg_score_other_models = 0 # No other models to compare\n",
    "\n",
    "        # Compute the difference\n",
    "        difference = avg_score_current_model - avg_score_other_models\n",
    "        results_matrix.loc[current_model, task_setter_model] = difference\n",
    "\n",
    "# Convert the results_matrix values to float for heatmap\n",
    "results_matrix_filtered = results_matrix.astype(float)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4ef352c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# A mapping for renaming models for better plot readability\n",
    "name_map = {\n",
    "    'gpt-4o': 'gpt-4o',\n",
    "    'claude-3-5-sonnet-20241022': 'sonnet-3-5',\n",
    "    'gemini-2.0-flash': 'gemini-2.0-flash',\n",
    "    'claude-3-haiku-20240307': 'haiku-3',\n",
    "    'claude-3-5-haiku-20241022': 'haiku-3-5',\n",
    "    'claude-sonnet-4-20250514': 'sonnet-4'\n",
    "}\n",
    "\n",
    "def create_heatmap(df: pd.DataFrame, title: str, filename: str):\n",
    "    \"\"\"\n",
    "    Generates a heatmap from a DataFrame, highlighting the largest element\n",
    "    in each row and replacing NaN values with \"no questions\" at cell level.\n",
    "    If an entire row or column is empty (all NaNs), its label will be\n",
    "    changed to \"no data\".\n",
    "\n",
    "    Args:\n",
    "        df (pd.DataFrame): The input DataFrame containing numerical data.\n",
    "                           Rows and columns will be used as labels.\n",
    "        title (str): The title for the heatmap plot.\n",
    "        filename (str): The filename (e.g., 'my_heatmap.png') to save the plot.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(12, 8)) # Increased figure size for better readability with larger fonts\n",
    "\n",
    "    # Rename the rows and columns using name_map for better readability\n",
    "    df_mapped = df.rename(index=name_map, columns=name_map)\n",
    "\n",
    "\n",
    "    # Create a copy for modification of labels without altering original df_mapped for computations\n",
    "    df_plot = df_mapped.copy()\n",
    "\n",
    "    dummy_value = -1\n",
    "    df_plot_filled = df_plot.fillna(dummy_value)\n",
    "\n",
    "\n",
    "    # Create an annotation DataFrame with formatted strings\n",
    "    # Numbers are formatted to two decimal places, NaN values become \"no questions\"\n",
    "    annot_df = df_plot.applymap(lambda x: '{:.2f}'.format(x) if pd.notna(x) else 'no valid qs')\n",
    "\n",
    "    # Get the column index (label) of the maximum value in each row from the *original* mapped data.\n",
    "    # We use df_mapped for max_indices calculation because 'no data' labels might interfere\n",
    "    # and we only want to highlight actual numerical rows.\n",
    "    max_indices = df_mapped.idxmax(axis=1)\n",
    "\n",
    "    # Plot the heatmap\n",
    "    ax = sns.heatmap(\n",
    "        df_plot_filled,\n",
    "        annot=annot_df,\n",
    "        fmt=\"\",\n",
    "        cmap='coolwarm',\n",
    "        linewidths=.5,\n",
    "        linecolor='black',\n",
    "        cbar=False,  # Remove the colorbar}\n",
    "        annot_kws={\"fontsize\": 16, \"color\": \"black\"},\n",
    "    )\n",
    "\n",
    "\n",
    "    # Highlight the largest element in each row\n",
    "    # Iterate over original mapped labels to correctly find original numerical index for highlight\n",
    "    for i, row_label_original in enumerate(df_mapped.index):\n",
    "        max_col_label = max_indices.loc[row_label_original]\n",
    "        \n",
    "        # Only highlight if a maximum value was found (i.e., row is not all NaNs)\n",
    "        # and if the row itself was not entirely 'no data' (as it's now renamed)\n",
    "        if pd.notna(max_col_label) and not df_mapped.loc[row_label_original].isna().all():\n",
    "            j = df_mapped.columns.get_loc(max_col_label) # Get the numerical column index from original mapped df\n",
    "            rect = plt.Rectangle((j, i), 1, 1, fill=False, edgecolor='black', lw=5) # Thinner border\n",
    "            ax.add_patch(rect)\n",
    "\n",
    "    #plt.title(title, fontsize=20) # Added fontsize for the main title\n",
    "    plt.xticks(rotation=45, ha='right', fontsize=22) # Increased x-axis label font size\n",
    "    plt.yticks(rotation=0, fontsize=22)             # Increased y-axis label font size\n",
    "    plt.xlabel('Asking Model', fontsize=25)  # Increased x-axis label font size\n",
    "    plt.ylabel('Answering Model', fontsize=25)  # Increased y-axis label font\n",
    "    plt.tight_layout()                 # Adjust plot to prevent labels from overlapping\n",
    "    plt.show()                        # Close the plot to free up memory\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Create heatmaps using the enhanced function\n",
    "create_heatmap(results_matrix, 'Unfiltered Results Heatmap', 'unfiltered_heatmap_nicer.png')\n",
    "create_heatmap(results_matrix_filtered, 'Filtered Results Heatmap', 'filtered_heatmap_nicer.png')\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
