{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using processed_results.csv created in /circuits/f1_analysis.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "import einops\n",
    "import importlib\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.express as px\n",
    "from matplotlib.colors import Normalize\n",
    "\n",
    "import circuits.analysis as analysis\n",
    "import circuits.eval_sae_as_classifier as eval_sae\n",
    "import circuits.chess_utils as chess_utils\n",
    "import circuits.utils as utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the processed results table, then list every column whose name mentions 'piece'.\n",
    "df = pd.read_csv('processed_results_group14v2_F1above0.9.csv')\n",
    "piece_columns = [column for column in df.columns if 'piece' in column]\n",
    "for piece_column in piece_columns:\n",
    "    print(piece_column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstruction columns (F1 score), ordered: all, basic, 8x8, 1x1\n",
    "reconstruction_cols = [\n",
    "    'best_f1_score_per_square_average',\n",
    "    'board_to_piece_state_best_f1_score_per_square',\n",
    "    'best_f1_score_per_square_only_board_state_average',\n",
    "    'best_f1_score_per_square_only_binary_average',\n",
    "]\n",
    "(\n",
    "    reconstruction_all_col,\n",
    "    reconstruction_basic_col,\n",
    "    reconstruction_8x8_col,\n",
    "    reconstruction_1x1_col,\n",
    ") = reconstruction_cols\n",
    "reconstruction_col_names = [\n",
    "    'All reconstructions',\n",
    "    'Basic reconstructions',\n",
    "    '8x8 reconstructions',\n",
    "    '1x1 reconstructions',\n",
    "]\n",
    "\n",
    "# Coverage columns (custom metric), same ordering as the reconstruction columns\n",
    "coverage_cols = [\n",
    "    'best_custom_metric_per_square_average',\n",
    "    'board_to_piece_state_best_custom_metric_per_square',\n",
    "    'best_custom_metric_per_square_only_board_state_average',\n",
    "    'best_custom_metric_per_square_only_binary_average',\n",
    "]\n",
    "(\n",
    "    coverage_all_col,\n",
    "    coverage_basic_col,\n",
    "    coverage_8x8_col,\n",
    "    coverage_1x1_col,\n",
    ") = coverage_cols\n",
    "coverage_col_names = [\n",
    "    'All coverages',\n",
    "    'Basic coverages',\n",
    "    '8x8 coverages',\n",
    "    '1x1 coverages',\n",
    "]\n",
    "\n",
    "# basic: board to piece state\n",
    "# only board state: [\"board_to_piece_state\", \"board_to_piece_color_state\", \"board_to_threat_state\", \"board_to_legal_moves_state\", \"board_to_pseudo_legal_move_state\"]\n",
    "# only binary: NOT [\"board_to_piece_binary\", \"board_to_piece_color_binary\", \"board_to_threat_binary\", \"board_to_legal_moves_binary\", \"board_to_pseudo_legal_move_binary\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get unique trainer types\n",
    "unique_trainers = df['trainer_class'].unique()\n",
    "\n",
    "# Map every trainer type to a marker shape. Zipping against a short marker list would\n",
    "# silently drop trainer types beyond its length, so index into the palette modulo its\n",
    "# length instead: every trainer type is plotted, and markers only repeat once the\n",
    "# palette is exhausted.\n",
    "marker_shapes = ['o', 's', '^', 'D', 'v', 'P', 'X', '*']\n",
    "trainer_markers = {\n",
    "    trainer: marker_shapes[idx % len(marker_shapes)]\n",
    "    for idx, trainer in enumerate(unique_trainers)\n",
    "}\n",
    "\n",
    "plt.rcParams.update({'font.size': 14})\n",
    "\n",
    "def plot_custom_metric(color_column: str, color_column_name: str, ymin: float = 0.0, save: bool = False, cbar_scale_fixed: bool = False):\n",
    "    \"\"\"Scatter L0 against loss recovered, coloring each point by `color_column`.\n",
    "\n",
    "    Reads the notebook-level `df` and `trainer_markers` and draws one scatter series per\n",
    "    trainer type so that every trainer gets its own marker shape and legend entry.\n",
    "\n",
    "    Args:\n",
    "        color_column: name of the `df` column that drives the point colors.\n",
    "        color_column_name: label for the colorbar and title; also used to build the\n",
    "            file name when `save` is True.\n",
    "        ymin: lower y-axis limit; rows below it are ignored when deriving the\n",
    "            automatic color range.\n",
    "        save: when True, write the figure as a PNG into the working directory.\n",
    "        cbar_scale_fixed: when True, use the fixed color range [0.5, 0.8] instead of\n",
    "            the range spanned by the visible data.\n",
    "    \"\"\"\n",
    "    metric_1 = \"l0\"\n",
    "    metric_1_label = r'$L_0$ (Lower is sparser)'\n",
    "    metric_2 = \"frac_recovered\"\n",
    "    metric_2_label = 'Loss Recovered (Fidelity)'\n",
    "\n",
    "    # create the scatter plot\n",
    "    fig, ax = plt.subplots(figsize=(10, 6))\n",
    "    ax.grid(zorder=0, alpha=0.5)\n",
    "\n",
    "    if cbar_scale_fixed:\n",
    "        norm = Normalize(vmin=0.5, vmax=0.8)\n",
    "    else:\n",
    "        # only rows inside the visible y-range should stretch the color scale\n",
    "        visible_values = df.loc[df[metric_2] >= ymin, color_column]\n",
    "        norm = Normalize(vmin=visible_values.min(), vmax=visible_values.max())\n",
    "\n",
    "    # plot data points for each trainer type separately\n",
    "    for trainer, marker in trainer_markers.items():\n",
    "        # Exact comparison rather than str.contains: substring/regex matching would also\n",
    "        # select rows of any trainer whose name merely contains this one, drawing them\n",
    "        # a second time under the wrong marker.\n",
    "        trainer_data = df[df['trainer_class'] == trainer]\n",
    "        ax.scatter(\n",
    "            trainer_data[metric_1],\n",
    "            trainer_data[metric_2],\n",
    "            c=trainer_data[color_column],\n",
    "            cmap='viridis',\n",
    "            marker=marker,\n",
    "            s=80,\n",
    "            label=trainer,\n",
    "            norm=norm,\n",
    "            zorder=10,\n",
    "        )\n",
    "\n",
    "    # add colorbar; every series shares `norm` and the colormap, so the first one is representative\n",
    "    cbar = fig.colorbar(ax.collections[0], ax=ax)\n",
    "    cbar.set_label(color_column_name)\n",
    "\n",
    "    # set labels and title\n",
    "    ax.set_xlabel(metric_1_label)\n",
    "    ax.set_ylabel(metric_2_label)\n",
    "    ax.set_title(f'{color_column_name} vs Standard Metrics')\n",
    "\n",
    "    # add legend with an opaque frame so points behind it do not show through\n",
    "    lg = ax.legend(title='Trainer Type', loc='lower right')\n",
    "    lg.get_frame().set_alpha(1)\n",
    "\n",
    "    # upper limit sits slightly above 1 so points at exactly 1.0 are not clipped\n",
    "    ax.set_ylim(ymin, 1.001)\n",
    "\n",
    "    # save before showing, while the figure is still open\n",
    "    if save:\n",
    "        fig.savefig(f'{color_column_name.lower().replace(\" \", \"_\")}_vs_standard_metrics.png', dpi=200)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_custom_metric(\n",
    "    color_column=reconstruction_all_col,\n",
    "    color_column_name='All reconstructions',\n",
    "    ymin=0.98,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fix_cbar = False\n",
    "set_ymin = 0.98\n",
    "\n",
    "# one figure per reconstruction metric, all with the same y-axis floor and colorbar mode\n",
    "for metric_column, metric_title in zip(reconstruction_cols, reconstruction_col_names):\n",
    "    plot_custom_metric(\n",
    "        color_column=metric_column,\n",
    "        color_column_name=metric_title,\n",
    "        ymin=set_ymin,\n",
    "        save=False,\n",
    "        cbar_scale_fixed=fix_cbar,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fix_cbar = False\n",
    "set_ymin = 0.98\n",
    "\n",
    "# one figure per coverage metric, all with the same y-axis floor and colorbar mode\n",
    "for metric_column, metric_title in zip(coverage_cols, coverage_col_names):\n",
    "    plot_custom_metric(\n",
    "        color_column=metric_column,\n",
    "        color_column_name=metric_title,\n",
    "        ymin=set_ymin,\n",
    "        save=False,\n",
    "        cbar_scale_fixed=fix_cbar,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# summed per-row difference between the board-state-only F1 average and the overall F1 average\n",
    "board_state_f1 = df['best_f1_score_per_square_only_board_state_average']\n",
    "overall_f1 = df['best_f1_score_per_square_average']\n",
    "(board_state_f1 - overall_f1).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
