{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from circuits.utils import (\n",
    "    get_feature,\n",
    "    get_ae_bundle,\n",
    "    AutoEncoderBundle,\n",
    "    get_first_n_dataset_rows,\n",
    "    collect_activations_batch,\n",
    ")\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import torch\n",
    "from jaxtyping import Int, Float, jaxtyped\n",
    "from beartype import beartype\n",
    "from torch import Tensor\n",
    "import einops\n",
    "\n",
    "import importlib\n",
    "import circuits.chess_utils as chess_utils\n",
    "importlib.reload(chess_utils)\n",
    "from circuits.chess_utils import config_lookup, get_num_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: where the autoencoder and model live, and how much data to use.\n",
    "device = \"cuda\"\n",
    "model_path = \"../models/\"\n",
    "autoencoder_path = \"../autoencoders/group0/ef=4_lr=1e-03_l1=1e-01_layer=5/\"\n",
    "n_inputs = 200\n",
    "batch_size = 25\n",
    "\n",
    "# Load the pickled dataset, then move every entry except \"pgn_strings\" to the device.\n",
    "with open(\"data.pkl\", \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "for key in data:\n",
    "    if key == \"pgn_strings\":\n",
    "        continue\n",
    "    data[key] = data[key].to(device)\n",
    "\n",
    "ae_bundle = get_ae_bundle(autoencoder_path, device, data, batch_size, model_path)\n",
    "pgn_strings = data[\"pgn_strings\"]\n",
    "\n",
    "# One index per entry of the autoencoder dictionary.\n",
    "features = torch.arange(ae_bundle.dictionary_size, device=device)\n",
    "num_features = features.shape[0]\n",
    "\n",
    "# Enough inputs must be available, and they must split into whole batches.\n",
    "assert n_inputs <= len(pgn_strings)\n",
    "assert n_inputs % batch_size == 0\n",
    "\n",
    "n_iters = n_inputs // batch_size\n",
    "results = dict()\n",
    "\n",
    "# Activation thresholds to evaluate, and the board-state functions to aggregate.\n",
    "thresholds = [0.0, 0.5]\n",
    "custom_functions = [chess_utils.board_to_piece_state, chess_utils.board_to_pin_state]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# # Example setup (assuming activations_FBL is already defined)\n",
    "# n_features, batch_size, context_length = activations_FBL.shape\n",
    "# n_thresholds = len(thresholds)\n",
    "# thresholds_tensor = torch.tensor(thresholds, device=device).view(1, 1, 1, -1)  # Reshape for broadcasting\n",
    "\n",
    "# # Expand activations to match the thresholds tensor for broadcasting\n",
    "# activations_expanded = repeat(activations_FBL, 'F B L -> F B L T', T=n_thresholds)\n",
    "\n",
    "# # Vectorized thresholding\n",
    "# active_indices_FBLT = activations_expanded > thresholds_tensor\n",
    "\n",
    "# # Compute active counts for all thresholds using einops\n",
    "# active_counts_FT = reduce(active_indices_FBLT, 'F B L T -> F T', 'sum')\n",
    "# off_counts_FT = reduce(~active_indices_FBLT, 'F B L T -> F T', 'sum')\n",
    "\n",
    "# # Now you have the counts of active and inactive indices for each feature at each threshold\n",
    "# print(active_counts_FT.shape)  # Shape: (n_features, n_thresholds)\n",
    "# print(off_counts_FT.shape)     # Shape: (n_features, n_thresholds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# start = 0\n",
    "# end = 25\n",
    "# for custom_function in custom_functions:\n",
    "#     # on_tracker_FTRRC = results[custom_function.__name__]['on']\n",
    "#     # off_tracker_FTRRC = results[custom_function.__name__]['off']\n",
    "\n",
    "#     boards_BLRRC = data[custom_function.__name__][start:end]\n",
    "#     print(boards_BLRRC.shape)\n",
    "\n",
    "#     # Force CUDA synchronization before measuring memory usage\n",
    "#     torch.cuda.synchronize()\n",
    "#     memory_before = torch.cuda.memory_allocated()\n",
    "\n",
    "#     boards_TBLRRC = einops.repeat(boards_BLRRC, 'B L R1 R2 C -> T B L R1 R2 C', T=1000)\n",
    "#     boards_TBLRRC += 0.0001  # Minor operation to force physical instantiation\n",
    "#     print(boards_TBLRRC.shape)\n",
    "\n",
    "#     # Force CUDA synchronization again to ensure all operations are complete\n",
    "#     torch.cuda.synchronize()\n",
    "#     memory_after = torch.cuda.memory_allocated()\n",
    "\n",
    "#     print(f\"Memory usage before: {memory_before} bytes\")\n",
    "#     print(f\"Memory usage after: {memory_after} bytes\")\n",
    "#     print(f\"Increase in memory: {memory_after - memory_before} bytes\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape (T, 1, 1, 1): a leading threshold axis plus three singleton axes, so that\n",
    "# comparing against (F, B, L) activations broadcasts to a (T, F, B, L) mask.\n",
    "thresholds_T = torch.tensor(thresholds, device=device).view(-1, 1, 1, 1)\n",
    "print(thresholds_T.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimension key (from https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd):\n",
    "# F = features; inside the feature-minibatch loop, F is the minibatch size instead\n",
    "# B = batch_size\n",
    "# L = seq length (context length)\n",
    "# T = thresholds\n",
    "# R = rows (or cols)\n",
    "# C = classes for one hot encoding\n",
    "#\n",
    "# For each entry of custom_functions this cell fills results[name] with:\n",
    "#   'on'        (T, F, R, R, C): sum of data[name] over positions where activation > threshold\n",
    "#   'off'       (T, F, R, R, C): sum of data[name] over all other positions\n",
    "#   'on_count'  (T, F): number of (input, position) pairs where activation > threshold\n",
    "#   'off_count' (T, F): number of all other (input, position) pairs\n",
    "\n",
    "\n",
    "# NOTE: this setup repeats the earlier configuration cell so this cell can run on its own.\n",
    "autoencoder_path = \"../autoencoders/group0/ef=4_lr=1e-03_l1=1e-01_layer=5/\"\n",
    "batch_size = 25\n",
    "n_inputs = 200\n",
    "device = \"cuda\"\n",
    "model_path = \"../models/\"\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load data.pkl from a trusted source.\n",
    "with open(\"data.pkl\", \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "for key in data:\n",
    "    if key != \"pgn_strings\":\n",
    "        data[key] = data[key].to(device)\n",
    "\n",
    "ae_bundle = get_ae_bundle(autoencoder_path, device, data, batch_size, model_path)\n",
    "pgn_strings = data[\"pgn_strings\"]\n",
    "\n",
    "features = torch.arange(0, ae_bundle.dictionary_size, device=device)\n",
    "num_features = len(features)\n",
    "\n",
    "assert len(pgn_strings) >= n_inputs\n",
    "assert n_inputs % batch_size == 0\n",
    "\n",
    "n_iters = n_inputs // batch_size\n",
    "results = {}\n",
    "\n",
    "custom_functions = [chess_utils.board_to_piece_state, chess_utils.board_to_pin_state]\n",
    "thresholds = [0.0, 0.5]\n",
    "\n",
    "# Shape (T, 1, 1, 1) so that comparing against (F, B, L) activations broadcasts to (T, F, B, L).\n",
    "thresholds_T = torch.tensor(thresholds, device=device).view(-1, 1, 1, 1)\n",
    "\n",
    "feature_batch_size = 2\n",
    "# Ceiling division: a trailing partial minibatch must still be processed, otherwise the last\n",
    "# (num_features % feature_batch_size) features would never be accumulated.\n",
    "num_feature_iters = (num_features + feature_batch_size - 1) // feature_batch_size\n",
    "\n",
    "# Allocate zeroed accumulators for every custom function.\n",
    "for custom_function in custom_functions:\n",
    "    config = config_lookup[custom_function.__name__]\n",
    "    num_classes = get_num_classes(config)\n",
    "\n",
    "    on_tracker_TFRRC = torch.zeros(\n",
    "        len(thresholds), num_features, config.num_rows, config.num_cols, num_classes, device=device\n",
    "    )\n",
    "    on_counter_TF = torch.zeros(len(thresholds), num_features, device=device)\n",
    "\n",
    "    results[custom_function.__name__] = {\n",
    "        'on': on_tracker_TFRRC,\n",
    "        'off': on_tracker_TFRRC.clone(),\n",
    "        'on_count': on_counter_TF,\n",
    "        'off_count': on_counter_TF.clone(),\n",
    "    }\n",
    "\n",
    "for i in tqdm(range(n_iters)):\n",
    "    start = i * batch_size\n",
    "    end = (i + 1) * batch_size\n",
    "    inputs_BL = data['pgn_strings'][start:end]\n",
    "\n",
    "    all_activations_FBL, encoded_inputs = collect_activations_batch(\n",
    "        ae_bundle.model, ae_bundle.submodule, ae_bundle.context_length, inputs_BL, ae_bundle.ae, features\n",
    "    ) # activations: (features, batch_size, context_length)\n",
    "\n",
    "    # For thousands of features, this would be many GB of memory. So, we minibatch.\n",
    "    for feature in range(num_feature_iters):\n",
    "        f_start = feature * feature_batch_size\n",
    "        # Clamp so the final minibatch may be smaller than feature_batch_size.\n",
    "        f_end = min((feature + 1) * feature_batch_size, num_features)\n",
    "        f_batch_size = f_end - f_start\n",
    "\n",
    "        activations_FBL = all_activations_FBL[f_start:f_end] #NOTE: Now F == f_batch_size\n",
    "\n",
    "        # Broadcasting (F, B, L) against (T, 1, 1, 1) yields one boolean mask per threshold.\n",
    "        active_indices_TFBL = activations_FBL > thresholds_T\n",
    "\n",
    "        active_counts_TF = einops.reduce(active_indices_TFBL, 'T F B L -> T F', 'sum')\n",
    "        off_counts_TF = einops.reduce(~active_indices_TFBL, 'T F B L -> T F', 'sum')\n",
    "\n",
    "        # The masks do not depend on the custom function, so build them once per minibatch.\n",
    "        # Trailing singleton axes let them broadcast over the (R, R, C) board axes.\n",
    "        active_mask_TFBL111 = active_indices_TFBL[:, :, :, :, None, None, None]\n",
    "        off_mask_TFBL111 = ~active_mask_TFBL111\n",
    "\n",
    "        for custom_function in custom_functions:\n",
    "            function_results = results[custom_function.__name__]\n",
    "            on_tracker_TFRRC = function_results['on']\n",
    "            off_tracker_TFRRC = function_results['off']\n",
    "\n",
    "            boards_BLRRC = data[custom_function.__name__][start:end]\n",
    "            boards_TFBLRRC = einops.repeat(\n",
    "                boards_BLRRC, 'B L R1 R2 C -> T F B L R1 R2 C', F=f_batch_size, T=len(thresholds)\n",
    "            )\n",
    "\n",
    "            active_boards_sum_TFRRC = einops.reduce(\n",
    "                boards_TFBLRRC * active_mask_TFBL111, 'T F B L R1 R2 C -> T F R1 R2 C', 'sum'\n",
    "            )\n",
    "            off_boards_sum_TFRRC = einops.reduce(\n",
    "                boards_TFBLRRC * off_mask_TFBL111, 'T F B L R1 R2 C -> T F R1 R2 C', 'sum'\n",
    "            )\n",
    "\n",
    "            # In-place slice updates mutate the tensors stored in results directly,\n",
    "            # so no write-back into the dict is needed.\n",
    "            on_tracker_TFRRC[:, f_start:f_end, :, :, :] += active_boards_sum_TFRRC\n",
    "            off_tracker_TFRRC[:, f_start:f_end, :, :, :] += off_boards_sum_TFRRC\n",
    "\n",
    "            function_results['on_count'][:, f_start:f_end] += active_counts_TF\n",
    "            function_results['off_count'][:, f_start:f_end] += off_counts_TF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release unused cached GPU memory held by PyTorch's caching allocator.\n",
    "# Tensors that are still referenced (e.g. data, results) are not freed by this call.\n",
    "torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "circuits",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
