{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "# Results file to analyse, resolved relative to the kernel's working directory.\n",
    "log_name = \"group1_results.json\"\n",
    "\n",
    "# Read the file as UTF-8 explicitly so that parsing does not depend on the\n",
    "# platform's default text encoding.\n",
    "with open(log_name, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the results logged by the `sae_stats_collection.py` script, there is a hyperparameters key plus, in this case, 32 separate autoencoders that we evaluated."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Total results: {len(data)}\")\n",
    "print()\n",
    "\n",
    "# Pick an entry that carries filter statistics to use as the running example.\n",
    "# There is no early exit, so the last such entry in `data` wins.\n",
    "example_key = \"\"\n",
    "for name, entry in data.items():\n",
    "    if \"syntax\" in entry.keys():\n",
    "        example_key = name\n",
    "\n",
    "print(f\"Example key: {example_key}\")\n",
    "\n",
    "# Each heading is followed by the keys found at the given path inside the\n",
    "# example entry; an empty path means the entry itself.\n",
    "key_listings = (\n",
    "    (\"\\nTop level keys:\", ()),\n",
    "    (\"\\nSyntax keys:\", (\"syntax\",)),\n",
    "    (\"\\nBoard keys:\", (\"board\",)),\n",
    ")\n",
    "for heading, path in key_listings:\n",
    "    print(heading)\n",
    "    section = data[example_key]\n",
    "    for part in path:\n",
    "        section = section[part]\n",
    "    for name in section.keys():\n",
    "        print(name)\n",
    "\n",
    "example = data[example_key]\n",
    "\n",
    "print(\"\\nExample statistics:\")\n",
    "print(example[\"syntax\"][\"find_num_indices\"])\n",
    "print(example[\"board\"][\"board_to_piece_state\"])\n",
    "\n",
    "print(\"\\nEval results:\")\n",
    "print(example[\"eval_results\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have three syntax filters: `find_num_indices`, `find_spaces_indices`, and `find_dots_indices`. We also have four board state filters: `board_to_piece_state`, `board_to_threat_state`, `board_to_check_state`, and `board_to_pin_state`. I added the cell below to find the average board state filter match count and the average syntax filter match count. A simple average probably isn't the best way of doing this; it is just the first thing I tried."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derived entries that this cell writes back into `data`. They are skipped when\n",
    "# collecting filter names so that a re-run averages over the real filters only\n",
    "# instead of also averaging over its own earlier output.\n",
    "SYNTAX_AVERAGE_KEY = \"syntax_average\"\n",
    "BOARD_AVERAGE_KEY = \"board_average\"\n",
    "\n",
    "\n",
    "def _filter_names(section, average_key):\n",
    "    \"\"\"Return the filter names in `section`, excluding the derived `average_key` entry.\"\"\"\n",
    "    return [func for func in section.keys() if func != average_key]\n",
    "\n",
    "\n",
    "def _average_stats(section, filter_names, template):\n",
    "    \"\"\"Average the statistics named in `template` across `filter_names`.\n",
    "\n",
    "    Args:\n",
    "        section: Mapping of filter name -> mapping of statistic name -> value.\n",
    "        filter_names: Filters to average over; their count is the divisor.\n",
    "        template: Mapping of statistic name -> 0.0 naming the statistics to keep.\n",
    "\n",
    "    Returns:\n",
    "        A new dict with the mean of each templated statistic. Statistics that a\n",
    "        filter does not report contribute nothing to the sum but still count in\n",
    "        the divisor.\n",
    "    \"\"\"\n",
    "    totals = template.copy()\n",
    "    for func in filter_names:\n",
    "        for var, value in section[func].items():\n",
    "            if var in totals:\n",
    "                totals[var] += value\n",
    "    for var in totals:\n",
    "        totals[var] /= len(filter_names)\n",
    "    return totals\n",
    "\n",
    "\n",
    "# The statistic names are taken from the first filter of the example entry only.\n",
    "syntax_keys = {}\n",
    "for func in _filter_names(data[example_key][\"syntax\"], SYNTAX_AVERAGE_KEY)[:1]:\n",
    "    for var in data[example_key][\"syntax\"][func].keys():\n",
    "        syntax_keys[var] = 0.0\n",
    "\n",
    "board_keys = {}\n",
    "for func in _filter_names(data[example_key][\"board\"], BOARD_AVERAGE_KEY)[:1]:\n",
    "    for var, value in data[example_key][\"board\"][func].items():\n",
    "        # List- and dict-valued statistics cannot be summed as scalars, so skip them.\n",
    "        if type(value) == list or type(value) == dict:\n",
    "            continue\n",
    "        board_keys[var] = 0.0\n",
    "\n",
    "print(\"Syntax keys: \", syntax_keys)\n",
    "print(\"Board keys: \", board_keys)\n",
    "\n",
    "for key in data.keys():\n",
    "    # Entries without a \"syntax\" section carry no filter statistics.\n",
    "    if \"syntax\" not in data[key].keys():\n",
    "        continue\n",
    "\n",
    "    syntax_section = data[key][\"syntax\"]\n",
    "    board_section = data[key][\"board\"]\n",
    "\n",
    "    syntax_dict = _average_stats(\n",
    "        syntax_section, _filter_names(syntax_section, SYNTAX_AVERAGE_KEY), syntax_keys\n",
    "    )\n",
    "    board_dict = _average_stats(\n",
    "        board_section, _filter_names(board_section, BOARD_AVERAGE_KEY), board_keys\n",
    "    )\n",
    "\n",
    "    # Store the averages next to the per-filter statistics for later cells.\n",
    "    syntax_section[SYNTAX_AVERAGE_KEY] = syntax_dict\n",
    "    board_section[BOARD_AVERAGE_KEY] = board_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All 32 of the sparse autoencoders share the same keys and nested key structure. For example, there's `['syntax']['find_dots_indices']['syntax_match_idx_count']`. The cell below creates `vars`, which collapses the nesting, so that this entry becomes `find_dots_indices_syntax_match_idx_count`. This makes it easy to look for correlations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps each autoencoder key to a flat {\"<filter>_<statistic>\": value} dict.\n",
    "vars = {}\n",
    "\n",
    "for ae_name, results in data.items():\n",
    "    # Entries without a \"syntax\" section carry no filter statistics.\n",
    "    if \"syntax\" not in results.keys():\n",
    "        continue\n",
    "\n",
    "    flat = {}\n",
    "    vars[ae_name] = flat\n",
    "\n",
    "    # Syntax statistics are copied as they are.\n",
    "    for func, stats in results[\"syntax\"].items():\n",
    "        for stat_name, stat_value in stats.items():\n",
    "            flat[func + \"_\" + stat_name] = stat_value\n",
    "\n",
    "    # Board statistics are kept only when they are plain ints or floats.\n",
    "    for func, stats in results[\"board\"].items():\n",
    "        for stat_name, stat_value in stats.items():\n",
    "            if type(stat_value) in (int, float):\n",
    "                flat[func + \"_\" + stat_name] = stat_value\n",
    "\n",
    "    # Evaluation metrics keep their own names, without a prefix.\n",
    "    for metric, metric_value in results[\"eval_results\"].items():\n",
    "        flat[f\"{metric}\"] = metric_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the flattened variable names of the example autoencoder.\n",
    "example_vars = vars[example_key]\n",
    "print(f\"We have {len(example_vars)} variables for each example\\n\")\n",
    "\n",
    "for var_name in example_vars:\n",
    "    print(var_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have 43 unique variables per SAE. This makes plots crowded. We make `simple_vars`, which filters out most of the above keys."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _keep_in_simple_vars(var):\n",
    "    \"\"\"Return True when the flattened variable name `var` belongs in `simple_vars`.\n",
    "\n",
    "    Names containing \"syntax\" or \"board\" are kept only when they are match\n",
    "    statistics of one of the averaged entries and are not \"nonzero\" or\n",
    "    \"dim_count\" statistics. Any other name is kept unless it contains \"find\"\n",
    "    or \"loss_original\".\n",
    "    \"\"\"\n",
    "    if \"syntax\" in var or \"board\" in var:\n",
    "        is_excluded_stat = \"nonzero\" in var or \"dim_count\" in var\n",
    "        is_match_stat = \"syntax_match\" in var or \"pattern_match\" in var\n",
    "        is_average = \"syntax_average\" in var or \"board_average\" in var\n",
    "        return (not is_excluded_stat) and is_match_stat and is_average\n",
    "    return \"find\" not in var and \"loss_original\" not in var\n",
    "\n",
    "\n",
    "simple_vars = {\n",
    "    ae_name: {\n",
    "        var: value for var, value in flat.items() if _keep_in_simple_vars(var)\n",
    "    }\n",
    "    for ae_name, flat in vars.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show every simplified variable of the example autoencoder with its value.\n",
    "example_simple_vars = simple_vars[example_key]\n",
    "print(f\"We now have {len(example_simple_vars)} variables for each example\\n\")\n",
    "\n",
    "for var_name in example_simple_vars:\n",
    "    print(var_name, example_simple_vars[var_name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For each of several configuration groups, like expansion factor 4 or layer 0, the next two cells find the SAE with the maximum and the SAE with the minimum score for a key of interest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key_of_interest = \"board_average_pattern_match_count\"\n",
    "# Alternative: key_of_interest = \"syntax_average_syntax_match_idx_count\"\n",
    "\n",
    "# Each group is selected by a substring test against the autoencoder key.\n",
    "group_names = [\n",
    "    \"_\",  # any\n",
    "    \"ef16\",\n",
    "    \"ef8\",\n",
    "    \"ef4\",\n",
    "    \"layer=0\",\n",
    "    \"layer=5\",\n",
    "]\n",
    "\n",
    "# Start from -inf rather than 0 so that a group whose best score is zero or\n",
    "# negative still reports its best autoencoder. A group that matches no key\n",
    "# keeps an empty name and a count of -inf.\n",
    "max_matching_dict = {\n",
    "    name: {\"name\": \"\", \"count\": float(\"-inf\")} for name in group_names\n",
    "}\n",
    "\n",
    "for key in simple_vars.keys():\n",
    "    for name, best in max_matching_dict.items():\n",
    "        if name in key and simple_vars[key][key_of_interest] > best[\"count\"]:\n",
    "            best[\"count\"] = simple_vars[key][key_of_interest]\n",
    "            best[\"name\"] = key\n",
    "\n",
    "for name, best in max_matching_dict.items():\n",
    "    print(name, best)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key_of_interest = \"board_average_pattern_match_count\"\n",
    "# Alternative: key_of_interest = \"syntax_average_syntax_match_idx_count\"\n",
    "\n",
    "# Each group is selected by a substring test against the autoencoder key.\n",
    "group_names = [\n",
    "    \"_\",  # any\n",
    "    \"ef16\",\n",
    "    \"ef8\",\n",
    "    \"ef4\",\n",
    "    \"layer=0\",\n",
    "    \"layer=5\",\n",
    "]\n",
    "\n",
    "# Start from +inf rather than the magic value 1e6 so that scores of 1e6 or more\n",
    "# are not silently ignored. A group that matches no key keeps an empty name\n",
    "# and a count of inf.\n",
    "min_matching_dict = {\n",
    "    name: {\"name\": \"\", \"count\": float(\"inf\")} for name in group_names\n",
    "}\n",
    "\n",
    "for key in simple_vars.keys():\n",
    "    for name, best in min_matching_dict.items():\n",
    "        if name in key and simple_vars[key][key_of_interest] < best[\"count\"]:\n",
    "            best[\"count\"] = simple_vars[key][key_of_interest]\n",
    "            best[\"name\"] = key\n",
    "\n",
    "for name, best in min_matching_dict.items():\n",
    "    print(name, best)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we print out stats of any autoencoder we are interested in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Autoencoder to inspect; any other key of `data` can be swapped in, e.g.\n",
    "# 'autoencoders/ef8/ef=8_lr=1e-04_l1=1e-04_layer=5/'.\n",
    "ae_of_interest = \"autoencoders/ef4_20k_resample/ef=4_lr=1e-03_l1=1e-01_layer=5/\"\n",
    "key_of_interest = \"pattern_match_count\"\n",
    "\n",
    "ae_stats = data[ae_of_interest]\n",
    "\n",
    "# These read the averaged entries stored under 'board_average' / 'syntax_average'.\n",
    "board_average = ae_stats[\"board\"][\"board_average\"]\n",
    "print(\"Average board match count: \", board_average[key_of_interest])\n",
    "syntax_average = ae_stats[\"syntax\"][\"syntax_average\"]\n",
    "print(\"Average syntax match count: \", syntax_average[\"syntax_match_idx_count\"])\n",
    "\n",
    "for metric, metric_value in ae_stats[\"eval_results\"].items():\n",
    "    print(metric, metric_value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And now we just have a bunch of plots of our data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# One row per autoencoder, one column per variable in `simple_vars`.\n",
    "df = pd.DataFrame.from_dict(simple_vars, orient=\"index\")\n",
    "\n",
    "# Pairwise correlation between the columns.\n",
    "correlation_matrix = df.corr()\n",
    "print(correlation_matrix)\n",
    "\n",
    "# Heatmap of the full matrix, drawn without per-cell annotations.\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "sns.heatmap(correlation_matrix, annot=False, fmt=\".2f\", cmap=\"coolwarm\", ax=ax)\n",
    "ax.set_title(\"Correlation Matrix of Reported Results\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "# Rebuild the frame and its correlation matrix from `simple_vars`.\n",
    "df = pd.DataFrame.from_dict(simple_vars, orient=\"index\")\n",
    "correlation_matrix = df.corr()\n",
    "\n",
    "# The two averaged match counts become the heatmap columns.\n",
    "selected_vars = [\n",
    "    \"syntax_average_syntax_match_idx_count\",\n",
    "    \"board_average_pattern_match_count\",\n",
    "]\n",
    "\n",
    "# Keep every variable as a row but only the selected variables as columns.\n",
    "filtered_correlation_matrix = correlation_matrix[selected_vars]\n",
    "\n",
    "# A tall, narrow figure suits the many-rows-by-two-columns layout.\n",
    "fig, ax = plt.subplots(figsize=(5, 10))\n",
    "sns.heatmap(\n",
    "    filtered_correlation_matrix,\n",
    "    annot=True,\n",
    "    fmt=\".2f\",\n",
    "    cmap=\"coolwarm\",\n",
    "    yticklabels=correlation_matrix.index,\n",
    "    ax=ax,\n",
    ")\n",
    "ax.set_title(\"Correlation Matrix for Selected Variables\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key1 = \"l2_loss\"\n",
    "key2 = \"pattern_match_count\"\n",
    "\n",
    "# Only entries with a \"syntax\" section contribute a point.\n",
    "ae_entries = [entry for entry in data.values() if \"syntax\" in entry.keys()]\n",
    "\n",
    "# x: evaluation metric, y: statistic of the board_to_piece_state filter.\n",
    "x = [entry[\"eval_results\"][key1] for entry in ae_entries]\n",
    "y = [entry[\"board\"][\"board_to_piece_state\"][key2] for entry in ae_entries]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x, y)\n",
    "ax.set_title(f\"{key1} vs. {key2}\")\n",
    "ax.set_xlabel(key1)\n",
    "ax.set_ylabel(key2)\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key1 = \"loss_reconstructed\"\n",
    "key2 = \"pattern_match_count\"\n",
    "\n",
    "# Only entries with a \"syntax\" section contribute a point.\n",
    "ae_entries = [entry for entry in data.values() if \"syntax\" in entry.keys()]\n",
    "\n",
    "# x: evaluation metric, y: statistic of the board_to_piece_state filter.\n",
    "x = [entry[\"eval_results\"][key1] for entry in ae_entries]\n",
    "y = [entry[\"board\"][\"board_to_piece_state\"][key2] for entry in ae_entries]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x, y)\n",
    "ax.set_title(f\"{key1} vs. {key2}\")\n",
    "ax.set_xlabel(key1)\n",
    "ax.set_ylabel(key2)\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key1 = \"syntax_match_idx_count\"\n",
    "key2 = \"pattern_match_count\"\n",
    "\n",
    "# Only entries with a \"syntax\" section contribute a point.\n",
    "ae_entries = [entry for entry in data.values() if \"syntax\" in entry.keys()]\n",
    "\n",
    "# x: statistic of the find_num_indices filter, y: statistic of the\n",
    "# board_to_piece_state filter.\n",
    "x = [entry[\"syntax\"][\"find_num_indices\"][key1] for entry in ae_entries]\n",
    "y = [entry[\"board\"][\"board_to_piece_state\"][key2] for entry in ae_entries]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x, y)\n",
    "ax.set_title(f\"{key1} vs. {key2}\")\n",
    "ax.set_xlabel(key1)\n",
    "ax.set_ylabel(key2)\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key1 = \"l0\"\n",
    "key2 = \"pattern_match_count\"\n",
    "\n",
    "# Only entries with a \"syntax\" section contribute a point.\n",
    "ae_entries = [entry for entry in data.values() if \"syntax\" in entry.keys()]\n",
    "\n",
    "# x: evaluation metric, y: statistic of the board_to_piece_state filter.\n",
    "x = [entry[\"eval_results\"][key1] for entry in ae_entries]\n",
    "y = [entry[\"board\"][\"board_to_piece_state\"][key2] for entry in ae_entries]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x, y)\n",
    "ax.set_title(f\"{key1} vs. {key2}\")\n",
    "ax.set_xlabel(key1)\n",
    "ax.set_ylabel(key2)\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key1 = \"l0\"\n",
    "key2 = \"syntax_match_idx_count\"\n",
    "\n",
    "# Only entries with a \"syntax\" section contribute a point.\n",
    "ae_entries = [entry for entry in data.values() if \"syntax\" in entry.keys()]\n",
    "\n",
    "# x: evaluation metric, y: statistic of the find_num_indices filter.\n",
    "x = [entry[\"eval_results\"][key1] for entry in ae_entries]\n",
    "y = [entry[\"syntax\"][\"find_num_indices\"][key2] for entry in ae_entries]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x, y)\n",
    "ax.set_title(f\"{key1} vs. {key2}\")\n",
    "ax.set_xlabel(key1)\n",
    "ax.set_ylabel(key2)\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "circuits",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
