{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Hugging Face downloads have been hitting filename-length limits because the\n",
    "# intermediate download files are named with an 80-character hash. Unfortunately,\n",
    "# the workaround is to shorten our own result filenames.\n",
    "\n",
    "# Substrings removed from result filenames\n",
    "KEYWORDS = [\n",
    "    \"custom_sae_\",\n",
    "    \"AprilUpdate\",\n",
    "    \"EleutherAI_\",\n",
    "    \"google_\",\n",
    "    \"Trainer\",\n",
    "    \"ctx1024\",\n",
    "]\n",
    "\n",
    "# True renames files on disk; set to False to only preview the changes.\n",
    "PERFORM_RENAME = True\n",
    "\n",
    "\n",
    "def shorten_filenames(start_path=\".\", keywords=None, perform_rename=None):\n",
    "    \"\"\"Remove filename keywords from every .json file under start_path.\n",
    "\n",
    "    Only the file name changes; each file stays in its original directory.\n",
    "\n",
    "    Args:\n",
    "        start_path: Root directory to walk recursively.\n",
    "        keywords: Substrings to delete from file names. Defaults to KEYWORDS.\n",
    "        perform_rename: True renames on disk, False only prints a preview.\n",
    "            Defaults to PERFORM_RENAME.\n",
    "\n",
    "    Returns:\n",
    "        Number of files renamed (or that would be renamed in preview mode).\n",
    "        Failed or skipped renames are not counted.\n",
    "    \"\"\"\n",
    "    if keywords is None:\n",
    "        keywords = KEYWORDS\n",
    "    if perform_rename is None:\n",
    "        perform_rename = PERFORM_RENAME\n",
    "\n",
    "    count = 0\n",
    "    max_begin_length = 0\n",
    "    max_end_length = 0\n",
    "\n",
    "    for root, _, files in os.walk(start_path):\n",
    "        for filename in files:\n",
    "            # Only process JSON files\n",
    "            if not filename.endswith(\".json\"):\n",
    "                continue\n",
    "\n",
    "            new_filename = filename\n",
    "            for keyword in keywords:\n",
    "                new_filename = new_filename.replace(keyword, \"\")\n",
    "            if new_filename == filename:\n",
    "                continue\n",
    "\n",
    "            old_path = os.path.join(root, filename)\n",
    "            new_path = os.path.join(root, new_filename)\n",
    "\n",
    "            if perform_rename:\n",
    "                # os.rename silently overwrites an existing target on POSIX, which\n",
    "                # would destroy a different result file that already has this name.\n",
    "                if os.path.exists(new_path):\n",
    "                    print(f\"Error renaming {filename}: {new_filename} already exists\")\n",
    "                    continue\n",
    "                try:\n",
    "                    os.rename(old_path, new_path)\n",
    "                except OSError as e:\n",
    "                    print(f\"Error renaming {filename}: {e}\")\n",
    "                    continue\n",
    "            else:\n",
    "                print(f\"In directory: {root}\")\n",
    "                print(f\"Would rename: {filename}\\n         -> {new_filename}\")\n",
    "                print(\n",
    "                    f\"Old length: {len(filename)}, New length: {len(new_filename)}\\n\"\n",
    "                )\n",
    "\n",
    "            max_begin_length = max(max_begin_length, len(filename))\n",
    "            max_end_length = max(max_end_length, len(new_filename))\n",
    "            count += 1\n",
    "\n",
    "    print(f\"Max filename length before: {max_begin_length}, after: {max_end_length}\")\n",
    "    print(\n",
    "        f\"\\nTotal files {'renamed' if perform_rename else 'that would be renamed'}: {count}\"\n",
    "    )\n",
    "    return count\n",
    "\n",
    "\n",
    "# Scan the current working directory; point target_root elsewhere if needed.\n",
    "target_root = \".\"\n",
    "shorten_filenames(start_path=target_root);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "\n",
    "def organize_files_in_folders(folder_paths):\n",
    "    \"\"\"Move each folder's prefixed JSON result files into per-prefix subfolders.\n",
    "\n",
    "    For every folder in folder_paths, JSON files directly inside it whose names\n",
    "    start with one of `prefixes` are moved to <folder>/<prefix>/<file>.\n",
    "    Missing folders and folders without JSON files are skipped.\n",
    "\n",
    "    Args:\n",
    "        folder_paths: Iterable of directory paths to process.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a folder has JSON files but none start with a prefix.\n",
    "    \"\"\"\n",
    "    # Define the prefixes we're looking for\n",
    "    prefixes = [\n",
    "        \"saebench_pythia-160m-deduped_width-2pow14_date-0108\",\n",
    "        # 'sae_bench_gemma-2-2b_vanilla_width-2pow16_date-1109'\n",
    "    ]\n",
    "\n",
    "    for folder in folder_paths:\n",
    "        if not os.path.exists(folder):\n",
    "            print(f\"Folder {folder} does not exist. Skipping.\")\n",
    "            continue\n",
    "\n",
    "        # List the target folder itself (not the cwd), and build every path with\n",
    "        # os.path.join rather than os.chdir so the working directory is never left\n",
    "        # changed if a move fails, and nested folder paths also work.\n",
    "        files = [f for f in os.listdir(folder) if f.endswith(\".json\")]\n",
    "        if not files:\n",
    "            print(f\"No json files found in {folder}. Skipping.\")\n",
    "            continue\n",
    "\n",
    "        if not any(f.startswith(tuple(prefixes)) for f in files):\n",
    "            raise ValueError(f\"No files with prefix {prefixes} found in {folder}.\")\n",
    "\n",
    "        print(f\"Processing folder: {folder}\")\n",
    "\n",
    "        # Create folders if they don't exist\n",
    "        for prefix in prefixes:\n",
    "            os.makedirs(os.path.join(folder, prefix), exist_ok=True)\n",
    "\n",
    "        # Move files to appropriate folders\n",
    "        for file in files:\n",
    "            for prefix in prefixes:\n",
    "                if file.startswith(prefix):\n",
    "                    shutil.move(os.path.join(folder, file), os.path.join(folder, prefix, file))\n",
    "                    print(f\"Moved {file} to {prefix}/\")\n",
    "                    break\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Evaluation result folders, relative to the current directory\n",
    "    organize_files_in_folders(\n",
    "        [\n",
    "            \"absorption\",\n",
    "            \"autointerp\",\n",
    "            \"core\",\n",
    "            \"scr\",\n",
    "            \"sparse_probing\",\n",
    "            \"tpp\",\n",
    "            \"unlearning\",\n",
    "        ]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def copy_autointerp(source=\"autointerp\", destination=\"autointerp_with_generations\"):\n",
    "    \"\"\"Copy the source results folder to destination as a full backup.\n",
    "\n",
    "    The next cell strips data from the files in source, so this copy preserves\n",
    "    the unstripped originals.\n",
    "\n",
    "    Args:\n",
    "        source: Folder to copy. If it does not exist, a message is printed and\n",
    "            nothing is copied.\n",
    "        destination: Folder to create; must not exist yet.\n",
    "\n",
    "    Raises:\n",
    "        Exception: If destination already exists.\n",
    "        OSError: If the copy itself fails. The error is printed and re-raised, so\n",
    "            a failed backup stops the notebook before any data is stripped.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(source):\n",
    "        print(f\"Source folder '{source}' does not exist. Aborting.\")\n",
    "        return\n",
    "\n",
    "    if os.path.exists(destination):\n",
    "        raise Exception(f\"Destination folder '{destination}' already exists. Aborting.\")\n",
    "\n",
    "    try:\n",
    "        shutil.copytree(source, destination)\n",
    "    except Exception as e:\n",
    "        print(f\"Failed to copy '{source}' to '{destination}': {e}\")\n",
    "        # Re-raise: continuing would let the stripping step run without a backup.\n",
    "        raise\n",
    "    print(f\"Copied '{source}' to '{destination}'.\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Back up autointerp/ before the next cell strips the LLM generations from it\n",
    "    copy_autointerp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Purpose: remove the LLM generations, which make up over 99% of the file size\n",
    "\n",
    "import json\n",
    "\n",
    "\n",
    "def process_json_files(directory, key=\"eval_result_unstructured\"):\n",
    "    \"\"\"Recursively delete a top-level key from every JSON file under directory.\n",
    "\n",
    "    A file is rewritten only when the key is present. The new content is written\n",
    "    to a temporary file that then replaces the original, so an interrupted or\n",
    "    failed write cannot leave a truncated result file behind.\n",
    "\n",
    "    Args:\n",
    "        directory: Root directory to walk.\n",
    "        key: Top-level key to delete; defaults to the autointerp LLM generations.\n",
    "\n",
    "    Returns:\n",
    "        Number of files from which the key was removed.\n",
    "    \"\"\"\n",
    "    count = 0\n",
    "    for root, _, files in os.walk(directory):\n",
    "        for file in files:\n",
    "            if not file.endswith(\".json\"):\n",
    "                continue\n",
    "            filepath = os.path.join(root, file)\n",
    "            try:\n",
    "                with open(filepath, encoding=\"utf-8\") as f:\n",
    "                    data = json.load(f)\n",
    "\n",
    "                # Nothing to strip: leave the file untouched\n",
    "                if not isinstance(data, dict) or key not in data:\n",
    "                    continue\n",
    "                del data[key]\n",
    "\n",
    "                tmp_path = filepath + \".tmp\"\n",
    "                with open(tmp_path, \"w\", encoding=\"utf-8\") as f:\n",
    "                    json.dump(data, f)\n",
    "                os.replace(tmp_path, filepath)\n",
    "                count += 1\n",
    "            except Exception as e:\n",
    "                print(f\"Error processing {filepath}: {str(e)}\")\n",
    "\n",
    "    return count\n",
    "\n",
    "\n",
    "# Strip the generations from every JSON file under autointerp/\n",
    "autointerp_dir = \"autointerp\"\n",
    "files_modified = process_json_files(autointerp_dir)\n",
    "print(f\"\\nCompleted! Modified {files_modified} files.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def copy_autointerp(source=\"core\", destination=\"core_with_feature_statistics\"):\n",
    "    \"\"\"Copy the source results folder to destination as a full backup.\n",
    "\n",
    "    Despite the (historical) name, this cell's defaults back up core/. The next\n",
    "    cell strips data from the files in source, so this copy preserves the\n",
    "    unstripped originals.\n",
    "\n",
    "    Args:\n",
    "        source: Folder to copy. If it does not exist, a message is printed and\n",
    "            nothing is copied.\n",
    "        destination: Folder to create; must not exist yet.\n",
    "\n",
    "    Raises:\n",
    "        Exception: If destination already exists.\n",
    "        OSError: If the copy itself fails. The error is printed and re-raised, so\n",
    "            a failed backup stops the notebook before any data is stripped.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(source):\n",
    "        print(f\"Source folder '{source}' does not exist. Aborting.\")\n",
    "        return\n",
    "\n",
    "    if os.path.exists(destination):\n",
    "        raise Exception(f\"Destination folder '{destination}' already exists. Aborting.\")\n",
    "\n",
    "    try:\n",
    "        shutil.copytree(source, destination)\n",
    "    except Exception as e:\n",
    "        print(f\"Failed to copy '{source}' to '{destination}': {e}\")\n",
    "        # Re-raise: continuing would let the stripping step run without a backup.\n",
    "        raise\n",
    "    print(f\"Copied '{source}' to '{destination}'.\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Back up core/ before the next cell strips the feature statistics from it\n",
    "    copy_autointerp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Purpose: remove the feature statistics, which make up over 99% of the file size\n",
    "\n",
    "\n",
    "\n",
    "def process_json_files(directory, key=\"eval_result_details\"):\n",
    "    \"\"\"Recursively delete a top-level key from every JSON file under directory.\n",
    "\n",
    "    A file is rewritten only when the key is present. The new content is written\n",
    "    to a temporary file that then replaces the original, so an interrupted or\n",
    "    failed write cannot leave a truncated result file behind.\n",
    "\n",
    "    Args:\n",
    "        directory: Root directory to walk.\n",
    "        key: Top-level key to delete; defaults to the core feature statistics.\n",
    "\n",
    "    Returns:\n",
    "        Number of files from which the key was removed.\n",
    "    \"\"\"\n",
    "    count = 0\n",
    "    for root, _, files in os.walk(directory):\n",
    "        for file in files:\n",
    "            if not file.endswith(\".json\"):\n",
    "                continue\n",
    "            filepath = os.path.join(root, file)\n",
    "            try:\n",
    "                with open(filepath, encoding=\"utf-8\") as f:\n",
    "                    data = json.load(f)\n",
    "\n",
    "                # Nothing to strip: leave the file untouched\n",
    "                if not isinstance(data, dict) or key not in data:\n",
    "                    continue\n",
    "                del data[key]\n",
    "\n",
    "                tmp_path = filepath + \".tmp\"\n",
    "                with open(tmp_path, \"w\", encoding=\"utf-8\") as f:\n",
    "                    json.dump(data, f)\n",
    "                os.replace(tmp_path, filepath)\n",
    "                count += 1\n",
    "            except Exception as e:\n",
    "                print(f\"Error processing {filepath}: {str(e)}\")\n",
    "\n",
    "    return count\n",
    "\n",
    "\n",
    "# Strip the feature statistics from every JSON file under core/\n",
    "core_dir = \"core\"\n",
    "files_modified = process_json_files(core_dir)\n",
    "print(f\"\\nCompleted! Modified {files_modified} files.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "local_dir = \".\"\n",
    "# Every results folder, including the two unstripped backup copies\n",
    "folders = [\n",
    "    f\"{local_dir}/{name}\"\n",
    "    for name in (\n",
    "        \"absorption\",\n",
    "        \"autointerp\",\n",
    "        \"core\",\n",
    "        \"scr\",\n",
    "        \"sparse_probing\",\n",
    "        \"tpp\",\n",
    "        \"unlearning\",\n",
    "        \"autointerp_with_generations\",\n",
    "        \"core_with_feature_statistics\",\n",
    "    )\n",
    "]\n",
    "\n",
    "\n",
    "def get_sae_bench_train_tokens(filename, batch_size=2048, total_steps=244140) -> int:\n",
    "    \"\"\"Return the number of training tokens for an SAE Bench result file.\n",
    "\n",
    "    A name containing `step_<N>` yields N * batch_size tokens; a name without\n",
    "    any \"step\" yields total_steps * batch_size (499,998,720 with the defaults).\n",
    "\n",
    "    Only the final path component is inspected, so directory names (e.g. the\n",
    "    `saebench_...` prefix folders) cannot affect the result.\n",
    "\n",
    "    Args:\n",
    "        filename: Path or name of the result file.\n",
    "        batch_size: Tokens per training step.\n",
    "        total_steps: Step count assumed when the name has no step marker.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the name is not an SAE Bench file, or mentions \"step\"\n",
    "            without a parsable `step_<N>` marker.\n",
    "    \"\"\"\n",
    "    name = os.path.basename(filename)\n",
    "    if \"saebench\" not in name:\n",
    "        raise ValueError(\"This function is only for SAE Bench releases\")\n",
    "\n",
    "    match = re.search(r\"step_(\\d+)\", name)\n",
    "    if match:\n",
    "        return int(match.group(1)) * batch_size\n",
    "    if \"step\" in name:\n",
    "        raise ValueError(\"No step match found\")\n",
    "    return total_steps * batch_size\n",
    "\n",
    "\n",
    "def process_file(filename: str):\n",
    "    \"\"\"Add training_tokens to a result file's sae_cfg_dict and save it in place.\n",
    "\n",
    "    The updated JSON is written to a temporary file that then replaces the\n",
    "    original, so a failed write cannot truncate the result file.\n",
    "\n",
    "    Args:\n",
    "        filename: Path to the JSON result file.\n",
    "\n",
    "    Returns:\n",
    "        True if the file was updated, False if it was skipped because of an\n",
    "        error (the error is printed).\n",
    "    \"\"\"\n",
    "    try:\n",
    "        with open(filename, encoding=\"utf-8\") as f:\n",
    "            eval_results = json.load(f)\n",
    "\n",
    "        if not isinstance(eval_results, dict) or \"sae_cfg_dict\" not in eval_results:\n",
    "            raise KeyError(\"sae_cfg_dict not found in the JSON file\")\n",
    "        if not isinstance(eval_results[\"sae_cfg_dict\"], dict):\n",
    "            raise ValueError(\"sae_cfg_dict is not a JSON object\")\n",
    "\n",
    "        eval_results[\"sae_cfg_dict\"][\"training_tokens\"] = get_sae_bench_train_tokens(\n",
    "            filename\n",
    "        )\n",
    "\n",
    "        tmp_path = filename + \".tmp\"\n",
    "        with open(tmp_path, \"w\", encoding=\"utf-8\") as f:\n",
    "            json.dump(eval_results, f, indent=4)\n",
    "        os.replace(tmp_path, filename)\n",
    "\n",
    "        return True\n",
    "    except (OSError, json.JSONDecodeError, KeyError, ValueError) as e:\n",
    "        # Best effort: report and move on so one bad file doesn't abort the batch\n",
    "        print(f\"Error processing file {filename}: {e}\")\n",
    "        return False\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Stamp training_tokens into every JSON result file under `folders`.\n",
    "\n",
    "    Prints the number of JSON files seen and how many were updated.\n",
    "    \"\"\"\n",
    "    total = total_updated = 0\n",
    "\n",
    "    for folder in tqdm(folders, desc=\"Processing folders\"):\n",
    "        if not os.path.exists(folder):\n",
    "            print(f\"Folder {folder} does not exist. Skipping.\")\n",
    "            continue\n",
    "\n",
    "        json_paths = [\n",
    "            os.path.join(root, name)\n",
    "            for root, _, names in os.walk(folder)\n",
    "            for name in names\n",
    "            if name.endswith(\".json\")\n",
    "        ]\n",
    "        total += len(json_paths)\n",
    "        total_updated += sum(1 for path in json_paths if process_file(path))\n",
    "\n",
    "    print(f\"Total files: {total}, Total updated: {total_updated}\")\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Stamp training_tokens into every result file under `folders`\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deliberate stop: keeps \"Run All\" from reaching the upload cell below.\n",
    "raise Exception(\"Stop here\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import HfApi\n",
    "\n",
    "hf_api = HfApi()\n",
    "\n",
    "# Upload only the eval-result JSONs under the current directory to the dataset repo\n",
    "hf_api.upload_large_folder(\n",
    "    repo_id=\"adamkarvonen/new_sae_bench_results\",\n",
    "    repo_type=\"dataset\",\n",
    "    folder_path=\".\",\n",
    "    allow_patterns=\"*eval_results.json\",\n",
    "    ignore_patterns=[\".DS_Store\"],\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
