{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import torch\n",
    "import einops\n",
    "from typing import Callable, Optional\n",
    "import math\n",
    "import os\n",
    "\n",
    "from circuits.utils import (\n",
    "    collect_activations_batch,\n",
    "    get_nested_folders,\n",
    "    to_device,\n",
    ")\n",
    "import circuits.eval_sae_as_classifier as eval_sae\n",
    "import circuits.chess_utils as chess_utils\n",
    "import circuits.othello_utils as othello_utils\n",
    "import circuits.test_board_reconstruction as test_board_reconstruction\n",
    "import circuits.othello_engine_utils as othello_engine_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Which sparse autoencoder to analyse ---\n",
    "autoencoder_group_path = \"autoencoders/othello_layer5_ef4/\"\n",
    "autoencoder_folder = \"ef=4_lr=1e-03_l1=6e-02_layer=5/\"\n",
    "autoencoder_path = f\"{autoencoder_group_path}{autoencoder_folder}\"\n",
    "\n",
    "# Result files for this autoencoder; feature_labels_file is loaded from autoencoder_path.\n",
    "feature_labels_file = \"indexing_None_n_inputs_10000_results_feature_labels.pkl\"\n",
    "reconstruction_file = \"indexing_None_n_inputs_10000_results_reconstruction.pkl\"\n",
    "\n",
    "# --- Evaluation size ---\n",
    "n_inputs = 100\n",
    "batch_size = 1\n",
    "\n",
    "# Pinned to CUDA: running on CPU has failed with\n",
    "#   RuntimeError: Unhandled FakeTensor Device Propagation for aten.bmm.default,\n",
    "#   found two different devices cpu:0, cpu\n",
    "# (root cause not yet investigated).\n",
    "device = torch.device(\"cuda\")\n",
    "\n",
    "othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)\n",
    "print(f\"Othello: {othello}\")\n",
    "model_name = eval_sae.get_model_name(othello)\n",
    "\n",
    "# Compact tensor printing; gradients are never needed in this notebook.\n",
    "torch.set_printoptions(precision=2, sci_mode=False)\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled feature labels for this autoencoder.\n",
    "# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.\n",
    "feature_labels_path = autoencoder_path + feature_labels_file\n",
    "with open(feature_labels_path, \"rb\") as labels_file:\n",
    "    feature_labels = to_device(pickle.load(labels_file), device)\n",
    "\n",
    "# Board-state function(s) used to build the dataset and to score reconstructions.\n",
    "custom_functions = [othello_utils.games_batch_to_state_stack_mine_yours_BLRRC]\n",
    "\n",
    "data = eval_sae.construct_dataset(othello, custom_functions, n_inputs, device)\n",
    "data, ae_bundle, pgn_strings, encoded_inputs = eval_sae.prep_firing_rate_data(\n",
    "    autoencoder_path,\n",
    "    batch_size,\n",
    "    \"\",\n",
    "    model_name,\n",
    "    data,\n",
    "    device,\n",
    "    n_inputs,\n",
    "    othello,\n",
    ")\n",
    "# Drop the bundle's buffer (presumably to free memory; confirm nothing below needs it).\n",
    "ae_bundle.buffer = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Activation thresholds (leading dims T, F) and the alive-feature list from the labels.\n",
    "thresholds_TF11 = feature_labels[\"thresholds\"].to(device)\n",
    "alive_features_F = feature_labels[\"alive_features\"].to(device)\n",
    "num_features = len(alive_features_F)\n",
    "T, F, _, _ = thresholds_TF11.shape\n",
    "\n",
    "# Look up the indexing function recorded with the labels; stays None if it is not supported.\n",
    "indexing_name = feature_labels[\"indexing_function\"]\n",
    "supported_indexing = chess_utils.supported_indexing_functions\n",
    "indexing_function = supported_indexing[indexing_name] if indexing_name in supported_indexing else None\n",
    "\n",
    "print(f\"Num alive features: {num_features}\")\n",
    "print(f\"Indexing function: {indexing_function}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = 0\n",
    "end = 3\n",
    "# The additive boards returned by aggregate_feature_labels are TBLRRC (no feature axis),\n",
    "# i.e. summed over all features in a minibatch, so the per-feature counts recorded\n",
    "# below are only correct with exactly one feature per minibatch.\n",
    "feature_batch_size = 1\n",
    "assert feature_batch_size == 1, \"feature_piece_counts_TF needs feature_batch_size == 1\"\n",
    "num_feature_iters = math.ceil(num_features / feature_batch_size)\n",
    "game_of_interest = 0\n",
    "move_of_interest = 30\n",
    "\n",
    "pgn_strings_BL = pgn_strings[start:end]\n",
    "encoded_inputs_BL = torch.tensor(encoded_inputs[start:end]).to(device)\n",
    "\n",
    "results = test_board_reconstruction.initialize_reconstruction_dict(\n",
    "    custom_functions, thresholds_TF11.shape[0], alive_features_F, device\n",
    ")\n",
    "\n",
    "batch_data = eval_sae.get_data_batch(\n",
    "    data, pgn_strings_BL, start, end, custom_functions, device\n",
    ")\n",
    "\n",
    "all_activations_FBL, encoded_token_inputs = collect_activations_batch(\n",
    "    ae_bundle, encoded_inputs_BL, alive_features_F\n",
    ")\n",
    "\n",
    "if indexing_function is not None:\n",
    "    all_activations_FBL, batch_data = eval_sae.apply_indexing_function(\n",
    "        pgn_strings_BL, all_activations_FBL, batch_data, device, indexing_function\n",
    "    )\n",
    "\n",
    "constructed_boards = test_board_reconstruction.initialized_constructed_boards_dict(\n",
    "    custom_functions, batch_data, thresholds_TF11, device\n",
    ")\n",
    "\n",
    "board_fn_name = custom_functions[0].__name__\n",
    "feature_piece_counts_TF = torch.zeros(T, F, device=device)\n",
    "\n",
    "# Materialising boards for thousands of features at once would take many GB, so minibatch over features.\n",
    "for feature_batch_idx in tqdm(range(num_feature_iters), desc=\"feature batches\"):\n",
    "    f_start = feature_batch_idx * feature_batch_size\n",
    "    f_end = min(f_start + feature_batch_size, num_features)\n",
    "\n",
    "    # NOTE: the leading (F) dim of this slice is f_end - f_start, not the total feature count.\n",
    "    activations_FBL = all_activations_FBL[f_start:f_end]\n",
    "\n",
    "    results, additive_boards = test_board_reconstruction.aggregate_feature_labels(\n",
    "        results,\n",
    "        feature_labels,\n",
    "        custom_functions,\n",
    "        activations_FBL,\n",
    "        thresholds_TF11[:, f_start:f_end, :, :],\n",
    "        f_start,\n",
    "        f_end,\n",
    "        device,\n",
    "    )\n",
    "\n",
    "    # Per-threshold sum of this feature's additive board at the game/move of interest.\n",
    "    additive_board_TBLRRC = additive_boards[board_fn_name]\n",
    "    feature_piece_counts_TF[:, f_start] = einops.reduce(\n",
    "        additive_board_TBLRRC[:, game_of_interest, move_of_interest],\n",
    "        \"T R1 R2 C -> T\",\n",
    "        \"sum\",\n",
    "    )\n",
    "\n",
    "    for custom_function in constructed_boards:\n",
    "        constructed_boards[custom_function] += additive_boards[custom_function]\n",
    "\n",
    "results = test_board_reconstruction.compare_constructed_to_true_boards(\n",
    "    results, custom_functions, constructed_boards, batch_data, device\n",
    ")\n",
    "results = test_board_reconstruction.calculate_F1_scores(results, custom_functions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# F1 score per threshold for the (single) board-state function; keep the best threshold.\n",
    "f1_scores = results[custom_functions[0].__name__][\"f1_score\"]\n",
    "best_idx = f1_scores.argmax()\n",
    "print(f1_scores)\n",
    "print(f\"Best threshold: {best_idx}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank features by their summed additive-board value at the best threshold.\n",
    "# The name top_20_features is kept because later cells use it; it holds the top k features.\n",
    "k = 30\n",
    "top_20_features = torch.argsort(feature_piece_counts_TF[best_idx], descending=True)[:k]\n",
    "print(f\"Top {k} features: {top_20_features}\")\n",
    "print(f\"Top {k} feature counts: {feature_piece_counts_TF[best_idx, top_20_features]}\")\n",
    "\n",
    "print(feature_piece_counts_TF[best_idx, top_20_features[0]].item())\n",
    "\n",
    "top_k_labels_FRRC = feature_labels[custom_functions[0].__name__][best_idx, top_20_features]\n",
    "print(f\"Top {k} feature labels shape: {top_k_labels_FRRC.shape}\")\n",
    "\n",
    "top_feature_RRC = top_k_labels_FRRC[0]\n",
    "\n",
    "print(top_feature_RRC.sum().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attribute each board square to the top-k feature that classifies it:\n",
    "#   i >= 0             -> only the i-th ranked feature classifies the square\n",
    "#   non_classified_idx -> no top-k feature classifies it\n",
    "#   conflict_idx       -> two or more top-k features classify it\n",
    "non_classified_idx = -2\n",
    "conflict_idx = -3\n",
    "\n",
    "top_features_FRRC = feature_labels[custom_functions[0].__name__][best_idx, top_20_features[:k]]\n",
    "# Fewer than k features may have been selected, so iterate over what is actually there.\n",
    "num_top_features = top_features_FRRC.shape[0]\n",
    "board_shape = top_features_FRRC.shape[1:3]\n",
    "\n",
    "output_RR = torch.full(board_shape, float(non_classified_idx), device=device)\n",
    "\n",
    "for i in range(num_top_features):\n",
    "    # A feature classifies a square when its label vector there is not all zeros.\n",
    "    classified_RR = ~torch.all(top_features_FRRC[i] == 0, dim=-1)\n",
    "    unassigned_RR = output_RR == non_classified_idx\n",
    "    output_RR[classified_RR & unassigned_RR] = i\n",
    "    output_RR[classified_RR & ~unassigned_RR] = conflict_idx\n",
    "\n",
    "print(output_RR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the top feature's label into a board: argmax over the last axis minus 1,\n",
    "# with squares whose label vector is all zeros marked as non_classified_idx.\n",
    "top_feature_RR = torch.argmax(top_feature_RRC, dim=-1) - 1\n",
    "zero_positions = (top_feature_RRC == 0).all(dim=-1)\n",
    "top_feature_RR[zero_positions] = non_classified_idx\n",
    "\n",
    "print(\"Top feature:\")\n",
    "for board_row in top_feature_RR:\n",
    "    # Unclassified squares print as a blank so the classified pattern stands out.\n",
    "    cells = [' ' if square.item() == non_classified_idx else square.item() for square in board_row]\n",
    "    print(''.join(f'{cell} ' for cell in cells))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# True board at the game/move of interest: argmax over the last axis, shifted by -1.\n",
    "boards_for_game = batch_data[custom_functions[0].__name__][game_of_interest]\n",
    "board_state_RRC = boards_for_game[move_of_interest]\n",
    "board_state_RR = board_state_RRC.argmax(dim=-1).sub(1)\n",
    "print(board_state_RR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "def plot_othello_board(board):\n",
    "    \"\"\"\n",
    "    Plot a 2D Othello board as a colour grid with a labelled colour bar.\n",
    "\n",
    "    Args:\n",
    "    board (torch.Tensor): 2D CPU tensor of square values. Colours: -1 black,\n",
    "                          0 grey (empty), 1 white, non_classified_idx yellow\n",
    "                          (not present in the one-hot vector), -3 red.\n",
    "\n",
    "    Raises:\n",
    "    KeyError: if the board contains a value with no colour assigned.\n",
    "    \"\"\"\n",
    "    color_map = {-1: 'black', 0: 'grey', 1: 'white', non_classified_idx: 'yellow', -3: 'red'}\n",
    "\n",
    "    board_np = board.numpy()\n",
    "    # np.unique returns sorted labels, so searchsorted gives each square's colour index.\n",
    "    unique_labels = np.unique(board_np)\n",
    "    colors = [color_map[label] for label in unique_labels]\n",
    "    cmap = plt.matplotlib.colors.ListedColormap(colors)\n",
    "    board_indices = np.searchsorted(unique_labels, board_np)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    # Pin the colour range to [-0.5, n - 0.5] so index i always maps to colour i and each\n",
    "    # colour-bar tick sits in the middle of its band (also works when n == 1).\n",
    "    image = ax.imshow(board_indices, cmap=cmap, vmin=-0.5, vmax=len(unique_labels) - 0.5)\n",
    "\n",
    "    cbar = fig.colorbar(image, ticks=range(len(unique_labels)))\n",
    "    cbar.ax.set_yticklabels(colors)\n",
    "\n",
    "    # Axis ticks carry no meaning for a game board.\n",
    "    ax.axis('off')\n",
    "    ax.set_title('Othello Board. Grey = Empty, Yellow = Not present in one hot vector')\n",
    "    plt.show()\n",
    "\n",
    "plot_othello_board(board_state_RR.to('cpu'))\n",
    "plot_othello_board(top_feature_RR.to('cpu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw values of the top feature's decoded board, as a CPU tensor.\n",
    "top_feature_RR.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_comparison_othello_board(true_board_RR, recon_board_RR):\n",
    "    \"\"\"\n",
    "    Plot a reconstructed Othello board with the true board overlaid as discs.\n",
    "\n",
    "    Each square's background shows the reconstructed value and the disc in its centre\n",
    "    shows the true value. Squares the reconstruction classifies (value other than\n",
    "    non_classified_idx) get a thick black outline.\n",
    "\n",
    "    Args:\n",
    "    true_board_RR (torch.Tensor): 2D CPU tensor of true square values (-1, 0 or 1).\n",
    "    recon_board_RR (torch.Tensor): 2D CPU tensor of reconstructed values\n",
    "                                   (non_classified_idx, -1, 0 or 1).\n",
    "    \"\"\"\n",
    "    color_map = {non_classified_idx: 'grey', -1: 'red', 0: 'white', 1: 'blue'}\n",
    "    # Sort labels so searchsorted maps each value to the index of its colour in the colormap.\n",
    "    sorted_labels = sorted(color_map)\n",
    "    cmap = plt.matplotlib.colors.ListedColormap([color_map[label] for label in sorted_labels])\n",
    "    board_indices = np.searchsorted(sorted_labels, recon_board_RR.numpy())\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    # Pin the colour range: without vmin/vmax imshow rescales to the data's own min/max,\n",
    "    # so a board that lacks some labels would be drawn in the wrong colours.\n",
    "    ax.imshow(board_indices, cmap=cmap, vmin=-0.5, vmax=len(sorted_labels) - 0.5)\n",
    "\n",
    "    num_rows, num_cols = true_board_RR.shape\n",
    "    for i in range(num_rows):\n",
    "        for j in range(num_cols):\n",
    "            # Disc coloured by the true value, with a black edge.\n",
    "            circle = plt.Circle((j, i), 0.3, color=color_map[true_board_RR[i, j].item()], fill=True)\n",
    "            circle_edges = plt.Circle((j, i), 0.3, color='black', fill=False)\n",
    "            ax.add_artist(circle)\n",
    "            ax.add_artist(circle_edges)\n",
    "            if recon_board_RR[i, j].item() != non_classified_idx:\n",
    "                ax.add_patch(plt.Rectangle((j - 0.5, i - 0.5), 1, 1, fill=False, edgecolor='black', lw=2))\n",
    "\n",
    "    ax.set_title('True Board and Reconstruction by Single Feature')\n",
    "    # plt.savefig('othello_board_comparison.png', dpi=300)\n",
    "    plt.show()\n",
    "\n",
    "plot_comparison_othello_board(board_state_RR.to('cpu'), top_feature_RR.to('cpu'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_feature_attribution_board(board: torch.Tensor, k: int):\n",
    "    \"\"\"\n",
    "    Plot which of the top-k features each board square is attributed to.\n",
    "\n",
    "    Args:\n",
    "    board (torch.Tensor): 2D CPU tensor; a value i >= 0 is the rank of the feature\n",
    "                          attributed to the square, non_classified_idx means no\n",
    "                          feature, and -3 means multiple features.\n",
    "    k (int): number of top features; ranks in `board` must be < k.\n",
    "    \"\"\"\n",
    "    # One viridis colour per feature rank, plus fixed colours for the special values.\n",
    "    gradient_colors = plt.cm.viridis(np.linspace(0, 1, k))\n",
    "    special_colors = {non_classified_idx: 'grey', -3: 'red'}\n",
    "\n",
    "    board_np = board.numpy()\n",
    "    # np.unique returns sorted labels, so searchsorted gives each square's colour index.\n",
    "    unique_labels = np.unique(board_np)\n",
    "    colors = [gradient_colors[int(label)] if label >= 0 else special_colors[int(label)] for label in unique_labels]\n",
    "    cmap = plt.matplotlib.colors.ListedColormap(colors)\n",
    "    board_indices = np.searchsorted(unique_labels, board_np)\n",
    "\n",
    "    tick_labels = [\n",
    "        f'Feature {int(label)}' if label >= 0\n",
    "        else ('No feature' if label == non_classified_idx else 'Multiple features')\n",
    "        for label in unique_labels\n",
    "    ]\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    # Pin the colour range to [-0.5, n - 0.5] so index i maps to colour i and each\n",
    "    # colour-bar tick sits in the middle of its band.\n",
    "    image = ax.imshow(board_indices, cmap=cmap, vmin=-0.5, vmax=len(unique_labels) - 0.5)\n",
    "    cbar = fig.colorbar(image, ticks=range(len(unique_labels)))\n",
    "    cbar.ax.set_yticklabels(tick_labels)\n",
    "\n",
    "    ax.set_title(f'Feature Attribution Board, top {k}')\n",
    "    # plt.savefig(f'feature_attribution_board_top{k}.png', dpi=300)\n",
    "    plt.show()\n",
    "\n",
    "plot_feature_attribution_board(output_RR.to('cpu'), k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "circuits",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
