{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import pandas as pd\n",
    "from typing import Callable\n",
    "import torch\n",
    "\n",
    "import circuits.eval_sae_as_classifier as eval_sae\n",
    "import circuits.analysis as analysis\n",
    "import circuits.test_board_reconstruction as test_board_reconstruction\n",
    "import circuits.get_eval_results as get_eval_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For multi-GPU evaluation\n",
    "from collections import deque\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "from circuits.utils import to_device\n",
    "\n",
    "# Number of GPUs to spread the autoencoder evaluations across.\n",
    "N_GPUS = 1\n",
    "# Pool of free devices: a worker pops one before using a GPU and appends it back when finished.\n",
    "RESOURCE_STACK = deque(\"cuda:{}\".format(gpu_index) for gpu_index in range(N_GPUS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize_dataframe(custom_functions: list[Callable]) -> pd.DataFrame:\n",
    "    \"\"\"Create an empty results DataFrame with one column per recorded metric.\n",
    "\n",
    "    Columns are the per-autoencoder fields followed, for each custom function in\n",
    "    order, by the per-function metrics prefixed with the function's `__name__`\n",
    "    (e.g. `<name>_best_f1_score_per_square`).\n",
    "\n",
    "    Args:\n",
    "        custom_functions: Functions whose metrics will be recorded per autoencoder.\n",
    "\n",
    "    Returns:\n",
    "        An empty DataFrame with the combined columns.\n",
    "    \"\"\"\n",
    "    shared_columns = [\n",
    "        \"autoencoder_group_path\",\n",
    "        \"autoencoder_path\",\n",
    "        \"reconstruction_file\",\n",
    "        \"trainer_class\",\n",
    "        \"sae_class\",\n",
    "        \"eval_sae_n_inputs\",\n",
    "        \"eval_results_n_inputs\",\n",
    "        \"board_reconstruction_n_inputs\",\n",
    "        \"l0\",\n",
    "        \"l1_loss\",\n",
    "        \"l2_loss\",\n",
    "        \"frac_alive\",\n",
    "        \"frac_variance_explained\",\n",
    "        \"cossim\",\n",
    "        \"l2_ratio\",\n",
    "        \"loss_original\",\n",
    "        \"loss_reconstructed\",\n",
    "        \"loss_zero\",\n",
    "        \"frac_recovered\",\n",
    "        \"num_alive_features\",\n",
    "    ]\n",
    "\n",
    "    # Metrics recorded once per custom function.\n",
    "    # Not included because append_results does not fill them: zero/best_percent_active_classifiers,\n",
    "    # zero/best_classifiers_per_token and zero/best_classified_per_token.\n",
    "    per_function_columns = [\n",
    "        \"board_reconstruction_board_count\",\n",
    "        \"num_squares\",\n",
    "        \"best_idx\",\n",
    "        \"zero_L0\",\n",
    "        \"zero_f1_score_per_class\",\n",
    "        \"zero_f1_score_per_square\",\n",
    "        \"best_L0\",\n",
    "        \"best_f1_score_per_class\",\n",
    "        \"best_f1_score_per_square\",\n",
    "        \"zero_num_true_positive_squares\",\n",
    "        \"best_num_true_positive_squares\",\n",
    "        \"zero_num_false_positive_squares\",\n",
    "        \"best_num_false_positive_squares\",\n",
    "        \"zero_multiple_classes\",\n",
    "        \"best_multiple_classes\",\n",
    "        \"zero_num_true_and_false_positive_squares\",\n",
    "        \"best_num_true_and_false_positive_squares\",\n",
    "        \"high_precision_counts_per_T\",\n",
    "        \"high_precision_and_recall_counts_per_T\",\n",
    "    ]\n",
    "\n",
    "    column_names = list(shared_columns)\n",
    "    for func in custom_functions:\n",
    "        prefix = func.__name__\n",
    "        column_names.extend(f\"{prefix}_{suffix}\" for suffix in per_function_columns)\n",
    "\n",
    "    return pd.DataFrame(columns=column_names)\n",
    "\n",
    "def append_results(\n",
    "    eval_results: dict,\n",
    "    aggregate_results: dict,\n",
    "    board_reconstruction_results: dict,\n",
    "    misc_stats: dict | None,\n",
    "    custom_functions: list[Callable],\n",
    "    df: pd.DataFrame,\n",
    "    autoencoder_group_path: str,\n",
    "    autoencoder_path: str,\n",
    "    reconstruction_file: str,\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"Append one row summarising a single autoencoder's evaluation to `df`.\n",
    "\n",
    "    Args:\n",
    "        eval_results: Result of `get_eval_results.get_evals`; read for\n",
    "            `[\"hyperparameters\"][\"n_inputs\"]` and the scalar metrics in `[\"eval_results\"]`.\n",
    "        aggregate_results: Result of `eval_sae.aggregate_statistics`; read for\n",
    "            `trainer_class`, `sae_class` and `[\"hyperparameters\"][\"n_inputs\"]`.\n",
    "        board_reconstruction_results: Result of `test_board_reconstructions`. Per-function\n",
    "            metrics are indexed by position: index 0 is reported as `zero_*`, and the index\n",
    "            with the highest `f1_score_per_square` is reported as `best_*`.\n",
    "        misc_stats: Per-function stats from `analysis.analyze_results_dict`. May be None,\n",
    "            or lack an entry for a function (e.g. when analysis results were loaded from\n",
    "            disk); the matching `*_counts_per_T` columns are then set to None.\n",
    "        custom_functions: Functions whose results are keyed by their `__name__`.\n",
    "        df: Existing results, typically the empty frame from `initialize_dataframe`.\n",
    "        autoencoder_group_path: Stored verbatim in the row.\n",
    "        autoencoder_path: Stored verbatim in the row.\n",
    "        reconstruction_file: Stored verbatim in the row.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame with the row appended (`df` itself is not modified).\n",
    "    \"\"\"\n",
    "    # Scalar metrics copied directly from eval_results[\"eval_results\"].\n",
    "    eval_metric_names = (\n",
    "        \"l0\",\n",
    "        \"l1_loss\",\n",
    "        \"l2_loss\",\n",
    "        \"frac_alive\",\n",
    "        \"frac_variance_explained\",\n",
    "        \"cossim\",\n",
    "        \"l2_ratio\",\n",
    "        \"loss_original\",\n",
    "        \"loss_reconstructed\",\n",
    "        \"loss_zero\",\n",
    "        \"frac_recovered\",\n",
    "    )\n",
    "    # (column suffix, key in board_reconstruction_results[function_name]) for metrics that\n",
    "    # are reported both at index 0 (`zero_` columns) and at the best index (`best_` columns).\n",
    "    indexed_metrics = (\n",
    "        (\"f1_score_per_class\", \"f1_score_per_class\"),\n",
    "        (\"f1_score_per_square\", \"f1_score_per_square\"),\n",
    "        (\"num_true_positive_squares\", \"num_true_positive_squares\"),\n",
    "        (\"num_false_positive_squares\", \"num_false_positive_squares\"),\n",
    "        (\"multiple_classes\", \"num_multiple_classes\"),\n",
    "        (\"num_true_and_false_positive_squares\", \"num_true_and_false_positive_squares\"),\n",
    "    )\n",
    "    misc_stat_names = (\"high_precision_counts_per_T\", \"high_precision_and_recall_counts_per_T\")\n",
    "\n",
    "    # Initialize the new row with constant fields\n",
    "    new_row = {\n",
    "        \"autoencoder_group_path\": autoencoder_group_path,\n",
    "        \"autoencoder_path\": autoencoder_path,\n",
    "        \"reconstruction_file\": reconstruction_file,\n",
    "        \"trainer_class\": aggregate_results[\"trainer_class\"],\n",
    "        \"sae_class\": aggregate_results[\"sae_class\"],\n",
    "        \"eval_sae_n_inputs\": aggregate_results[\"hyperparameters\"][\"n_inputs\"],\n",
    "        \"eval_results_n_inputs\": eval_results[\"hyperparameters\"][\"n_inputs\"],\n",
    "        \"board_reconstruction_n_inputs\": board_reconstruction_results[\"hyperparameters\"][\"n_inputs\"],\n",
    "    }\n",
    "    for metric_name in eval_metric_names:\n",
    "        new_row[metric_name] = eval_results[\"eval_results\"][metric_name]\n",
    "    new_row[\"num_alive_features\"] = board_reconstruction_results[\"alive_features\"].shape[0]\n",
    "\n",
    "    active_per_token = board_reconstruction_results[\"active_per_token\"]\n",
    "    misc_stats = misc_stats or {}\n",
    "\n",
    "    for custom_function in custom_functions:\n",
    "        function_name = custom_function.__name__\n",
    "        function_results = board_reconstruction_results[function_name]\n",
    "        best_idx = function_results[\"f1_score_per_square\"].argmax()\n",
    "\n",
    "        # Add the custom fields to the new row (same column order as before)\n",
    "        new_row[f\"{function_name}_board_reconstruction_board_count\"] = function_results[\"num_boards\"]\n",
    "        new_row[f\"{function_name}_num_squares\"] = function_results[\"num_squares\"]\n",
    "        new_row[f\"{function_name}_best_idx\"] = best_idx.item()\n",
    "        new_row[f\"{function_name}_zero_L0\"] = active_per_token[0].item()\n",
    "        new_row[f\"{function_name}_best_L0\"] = active_per_token[best_idx].item()\n",
    "        for column_suffix, result_key in indexed_metrics:\n",
    "            values = function_results[result_key]\n",
    "            new_row[f\"{function_name}_zero_{column_suffix}\"] = values[0].item()\n",
    "            new_row[f\"{function_name}_best_{column_suffix}\"] = values[best_idx].item()\n",
    "\n",
    "        # Missing misc stats (e.g. analysis loaded from disk) become None instead of raising\n",
    "        function_misc_stats = misc_stats.get(function_name) or {}\n",
    "        for stat_name in misc_stat_names:\n",
    "            new_row[f\"{function_name}_{stat_name}\"] = function_misc_stats.get(stat_name)\n",
    "        # classifiers_per_token / classified_per_token / percent_active_classifiers are\n",
    "        # available in board_reconstruction_results but are currently not recorded.\n",
    "\n",
    "    new_row_df = pd.DataFrame([new_row])\n",
    "\n",
    "    # An empty `df` (e.g. fresh from initialize_dataframe) is replaced rather than concatenated\n",
    "    if df.empty:\n",
    "        return new_row_df\n",
    "    return pd.concat([df, new_row_df], ignore_index=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Basically, just set `autoencoder_group_paths` and the hyperparameters in the cell below, then run it. If you already ran a stage, for example `eval_sae_as_classifier`, and don't want to run it again, set its flag (here `run_eval_sae`) to False. In that case `eval_sae_n_inputs` must match the previous run, because it is used to compute the location of the saved `.pkl` file that gets loaded instead.\n",
    "\n",
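    "By default, `save_results` is True, so each of the 4 pipeline functions (`get_evals`, `aggregate_statistics`, `analyze_results_dict` and `test_board_reconstructions`) saves a `.pkl` file. Some of the results are also aggregated and formatted into a `results.csv`, written both into each autoencoder's folder and into the group folder. If you already have the results `.pkl` files and just want the csv, set `run_eval_sae`, `run_analysis` and `run_board_reconstruction` to False to load the saved results instead. `run_eval_results` is not checked: eval results are always recomputed because they are quick to collect."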
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "\n",
    "# Re-execute the project modules so edits to their source are picked up without restarting the kernel\n",
    "for _module in (eval_sae, analysis, test_board_reconstruction, get_eval_results):\n",
    "    importlib.reload(_module)\n",
    "\n",
    "import circuits.chess_utils as chess_utils\n",
    "\n",
    "importlib.reload(chess_utils)\n",
    "\n",
    "# NOTE: This script makes a major assumption: all autoencoders in a given group are trained on chess XOR Othello.\n",
    "# That way the dataset is built once per group instead of once per autoencoder.\n",
    "# Alternative groups:\n",
    "# autoencoder_group_paths = [\"../autoencoders/othello_layer5_ef4/\"]\n",
    "# autoencoder_group_paths = [\"../autoencoders/chess_layer5/\"]\n",
    "# autoencoder_group_paths = [\"../autoencoders/chess_layer5_large_sweep/\"]\n",
    "autoencoder_group_paths = [\"../autoencoders/group-2024-05-07/\"]\n",
    "\n",
    "# Paths and batching. Devices are not configured here: they are taken from RESOURCE_STACK.\n",
    "model_path = \"../models/\"\n",
    "batch_size = 100\n",
    "\n",
    "# Number of inputs used by each stage\n",
    "eval_sae_n_inputs = 1000\n",
    "eval_results_n_inputs = 1000\n",
    "board_reconstruction_n_inputs = 1000\n",
    "\n",
    "# Thresholds passed to analysis.analyze_results_dict\n",
    "analysis_high_threshold = 0.95\n",
    "analysis_low_threshold = 0.1\n",
    "analysis_significance_threshold = 10\n",
    "\n",
    "# Eval results are always recomputed (they are quick to collect), so this flag is not checked\n",
    "run_eval_results = True\n",
    "\n",
    "# To skip any of the following steps, set the corresponding variable to False.\n",
    "# The results must have been saved previously; they are then loaded from disk.\n",
    "run_eval_sae = True\n",
    "run_analysis = True\n",
    "run_board_reconstruction = True\n",
    "\n",
    "mask = True\n",
    "save_results = True\n",
    "use_separate_test_set = True\n",
    "\n",
    "# The dataset must be larger than eval_results_n_inputs or we reach the end of the data stream,\n",
    "# so it is doubled when eval_results_n_inputs is (tied for) the largest request.\n",
    "largest_n_inputs = max(eval_sae_n_inputs, eval_results_n_inputs, board_reconstruction_n_inputs)\n",
    "dataset_size = 2 * largest_n_inputs if largest_n_inputs == eval_results_n_inputs else largest_n_inputs\n",
    "\n",
    "import os\n",
    "\n",
    "for autoencoder_group_path in autoencoder_group_paths:\n",
    "    othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)\n",
    "\n",
    "    indexing_functions = eval_sae.get_recommended_indexing_functions(othello)\n",
    "    indexing_function = indexing_functions[0]\n",
    "\n",
    "    custom_functions = eval_sae.get_recommended_custom_functions(othello)\n",
    "    # Example custom functions (this replaces the recommended functions above)\n",
    "    custom_functions = eval_sae.get_all_chess_functions(othello)\n",
    "\n",
    "    model_name = eval_sae.get_model_name(othello)\n",
    "\n",
    "    # If True, precompute everything and store it in VRAM. Faster, but far higher memory usage\n",
    "    # If True, VRAM scales with batch size and n_inputs\n",
    "    # If False, VRAM scales with batch size only\n",
    "    precompute = True\n",
    "\n",
    "    # The datasets are built once per group and shared by every autoencoder in it.\n",
    "    # The borrowed GPU is returned even if construction fails, so re-running this cell\n",
    "    # does not pop from an empty RESOURCE_STACK.\n",
    "    device = RESOURCE_STACK.pop()\n",
    "    try:\n",
    "        print(\"Constructing statistics aggregation dataset\")\n",
    "        train_data = eval_sae.construct_dataset(\n",
    "            othello, custom_functions, dataset_size, split=\"train\", device=device, models_path=model_path, precompute_dataset=precompute\n",
    "        )\n",
    "        if use_separate_test_set:\n",
    "            print(\"Constructing test dataset\")\n",
    "            test_data = eval_sae.construct_dataset(\n",
    "                othello, custom_functions, dataset_size, split=\"test\", device=device, models_path=model_path, precompute_dataset=precompute\n",
    "            )\n",
    "        else:\n",
    "            test_data = train_data\n",
    "    finally:\n",
    "        RESOURCE_STACK.append(device)\n",
    "        del device\n",
    "\n",
    "    folders = eval_sae.get_nested_folders(autoencoder_group_path)\n",
    "\n",
    "    def full_eval_pipeline(autoencoder_path):\n",
    "        \"\"\"Evaluate one autoencoder and return a one-row results DataFrame.\n",
    "\n",
    "        Eval results are always computed. Every other stage is run or, when its `run_...`\n",
    "        flag is False, loaded from the .pkl saved by a previous run. The row is also written\n",
    "        to `results.csv` inside `autoencoder_path`.\n",
    "        \"\"\"\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "        df = initialize_dataframe(custom_functions)\n",
    "\n",
    "        # For debugging\n",
    "        # if \"ef=4_lr=1e-03_l1=1e-01_layer=5\" not in autoencoder_path:\n",
    "        #     return df\n",
    "\n",
    "        # Grab a GPU off the stack to use; the `finally` below puts it back even if a stage raises\n",
    "        device = RESOURCE_STACK.pop()\n",
    "        try:\n",
    "            # If this is set, everything below should be reproducible\n",
    "            # Then we can just save results from 1 run, make optimizations, and check that the results are the same\n",
    "            # The determinism is only needed for getting activations from the activation buffer for finding alive features\n",
    "            # NOTE(review): require=\"sharedmem\" makes joblib run workers as threads sharing torch's global RNG,\n",
    "            # so this is presumably only reproducible with N_GPUS == 1 - confirm before relying on it.\n",
    "            torch.manual_seed(0)\n",
    "            eval_results = get_eval_results.get_evals(\n",
    "                autoencoder_path,\n",
    "                eval_results_n_inputs,\n",
    "                device,\n",
    "                model_path,\n",
    "                model_name,\n",
    "                to_device(train_data.copy(), device),\n",
    "                othello=othello,\n",
    "                save_results=save_results,\n",
    "            )\n",
    "\n",
    "            expected_aggregation_output_location = eval_sae.get_output_location(\n",
    "                autoencoder_path, n_inputs=eval_sae_n_inputs, indexing_function=indexing_function\n",
    "            )\n",
    "\n",
    "            if run_eval_sae:\n",
    "                print(\"Aggregating\", autoencoder_path)\n",
    "                aggregation_results = eval_sae.aggregate_statistics(\n",
    "                    custom_functions=custom_functions,\n",
    "                    autoencoder_path=autoencoder_path,\n",
    "                    n_inputs=eval_sae_n_inputs,\n",
    "                    batch_size=batch_size,\n",
    "                    device=device,\n",
    "                    model_path=model_path,\n",
    "                    model_name=model_name,\n",
    "                    data=to_device(train_data.copy(), device),\n",
    "                    indexing_function=indexing_function,\n",
    "                    othello=othello,\n",
    "                    save_results=save_results,\n",
    "                    precomputed=precompute,\n",
    "                )\n",
    "            else:\n",
    "                # Only load .pkl files produced by a previous run: unpickling can execute arbitrary code\n",
    "                with open(expected_aggregation_output_location, \"rb\") as f:\n",
    "                    aggregation_results = pickle.load(f)\n",
    "\n",
    "            expected_feature_labels_output_location = expected_aggregation_output_location.replace(\n",
    "                \"results.pkl\", \"feature_labels.pkl\"\n",
    "            )\n",
    "            if run_analysis:\n",
    "                feature_labels, misc_stats = analysis.analyze_results_dict(\n",
    "                    aggregation_results,\n",
    "                    output_path=expected_feature_labels_output_location,\n",
    "                    device=device,\n",
    "                    high_threshold=analysis_high_threshold,\n",
    "                    low_threshold=analysis_low_threshold,\n",
    "                    significance_threshold=analysis_significance_threshold,\n",
    "                    verbose=False,\n",
    "                    print_results=False,\n",
    "                    save_results=save_results,\n",
    "                    mask=mask,\n",
    "                )\n",
    "            else:\n",
    "                with open(expected_feature_labels_output_location, \"rb\") as f:\n",
    "                    feature_labels = pickle.load(f)\n",
    "                # Only the feature labels are loaded from disk here, so misc_stats would otherwise be\n",
    "                # unbound. Record the per-function counts append_results reads as missing (None).\n",
    "                misc_stats = {\n",
    "                    custom_function.__name__: {\n",
    "                        \"high_precision_counts_per_T\": None,\n",
    "                        \"high_precision_and_recall_counts_per_T\": None,\n",
    "                    }\n",
    "                    for custom_function in custom_functions\n",
    "                }\n",
    "\n",
    "            expected_reconstruction_output_location = expected_aggregation_output_location.replace(\n",
    "                \"results.pkl\", \"reconstruction.pkl\"\n",
    "            )\n",
    "\n",
    "            if run_board_reconstruction:\n",
    "                print(\"Testing board reconstruction\")\n",
    "                board_reconstruction_results = test_board_reconstruction.test_board_reconstructions(\n",
    "                    custom_functions=custom_functions,\n",
    "                    autoencoder_path=autoencoder_path,\n",
    "                    feature_labels=feature_labels,\n",
    "                    output_file=expected_reconstruction_output_location,\n",
    "                    n_inputs=board_reconstruction_n_inputs,\n",
    "                    batch_size=batch_size,\n",
    "                    device=device,\n",
    "                    model_name=model_name,\n",
    "                    data=to_device(test_data.copy(), device),\n",
    "                    othello=othello,\n",
    "                    print_results=False,\n",
    "                    save_results=save_results,\n",
    "                    precomputed=precompute,\n",
    "                    mask=mask,\n",
    "                )\n",
    "            else:\n",
    "                with open(expected_reconstruction_output_location, \"rb\") as f:\n",
    "                    board_reconstruction_results = pickle.load(f)\n",
    "\n",
    "            df = append_results(\n",
    "                eval_results,\n",
    "                aggregation_results,\n",
    "                board_reconstruction_results,\n",
    "                misc_stats,\n",
    "                custom_functions,\n",
    "                df,\n",
    "                autoencoder_group_path,\n",
    "                autoencoder_path,\n",
    "                expected_reconstruction_output_location,\n",
    "            )\n",
    "\n",
    "            print(\"Finished\", autoencoder_path)\n",
    "\n",
    "            # Save the dataframe after each autoencoder so we don't lose data if the script crashes\n",
    "            output_file = os.path.join(autoencoder_path, \"results.csv\")\n",
    "            df.to_csv(output_file)\n",
    "        finally:\n",
    "            # Put the GPU back on the stack after we're done (or after a failure)\n",
    "            RESOURCE_STACK.append(device)\n",
    "        return df\n",
    "\n",
    "    dfs = Parallel(n_jobs=N_GPUS, require=\"sharedmem\")(\n",
    "        delayed(full_eval_pipeline)(autoencoder_path) for autoencoder_path in folders\n",
    "    )\n",
    "    # os.path.join works whether or not the group path ends with a separator\n",
    "    pd.concat(dfs, axis=0, ignore_index=True).to_csv(os.path.join(autoencoder_group_path, \"results.csv\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
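    "Example of gathering the top-k contexts for selected autoencoder dimensions. The cell below is commented out; set `autoencoder_path` before uncommenting and running it."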
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import torch\n",
    "# import circuits.chess_interp as chess_interp\n",
    "# importlib.reload(chess_interp)\n",
    "\n",
    "# torch.set_grad_enabled(False)\n",
    "\n",
    "# autoencoder_group_path = autoencoder_group_paths[0]\n",
    "\n",
    "# othello = eval_sae.check_if_autoencoder_is_othello(autoencoder_group_path)\n",
    "\n",
    "# indexing_functions = eval_sae.get_recommended_indexing_functions(othello)\n",
    "# indexing_function = indexing_functions[0]\n",
    "\n",
    "# custom_functions = eval_sae.get_recommended_custom_functions(othello)\n",
    "\n",
    "# model_name = eval_sae.get_model_name(othello)\n",
    "\n",
    "# device = RESOURCE_STACK.pop()\n",
    "# print(\"Constructing evaluation dataset\")\n",
    "# data = eval_sae.construct_dataset(othello, custom_functions, dataset_size, device, models_path=model_path)\n",
    "\n",
    "\n",
    "# dataset_size = dataset_size * 2  # x2 to make sure we have enough data for loss_recovered()\n",
    "\n",
    "\n",
    "# # TODO: set `autoencoder_path`\n",
    "# data, ae_bundle, pgn_strings, encoded_inputs = eval_sae.prep_firing_rate_data(\n",
    "#     autoencoder_path, dataset_size, model_path, model_name, data, device, dataset_size, othello\n",
    "# )\n",
    "\n",
    "# dims = torch.tensor([10], device=device)\n",
    "# chess_interp.examine_dimension_chess(ae_bundle, 100, dims)\n",
    "\n",
    "# RESOURCE_STACK.append(device)\n",
    "# del device"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "circuits",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
