{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Upload SAEs from Weights & Biases to HuggingFace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This script uploads trained SAEs from Weights & Biases run artifacts to a HuggingFace model repo. Each run gets its own folder containing the weights, a config JSON and a tensor of sparsity values. The artifacts are first downloaded to the local `scripts/artifacts` folder and then uploaded to HuggingFace. Finally, the artifacts downloaded in this run are deleted from `scripts/artifacts` (pre-existing artifacts in this folder are left unchanged).\n",
    "\n",
    "To run this script you'll need to:\n",
    "- Create a HuggingFace model repo to upload to (or use an existing one).\n",
    "- Get a HuggingFace [write access token](https://huggingface.co/docs/hub/en/security-tokens) for the repo.\n",
    "- Set the repo ID and token in the cell below, along with the name of the W&B project you want to transfer.\n",
    "\n",
    "Note: only finished runs will be transferred. You can change this or add extra filtering in the cell beginning with the comment \"more filtering on W&B runs can be done here\"."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Script variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb_project_name = \"YOUR-WANDB-PROJECT\"\n",
    "hf_repo_id = \"YOUR-HF-REPO\"\n",
    "hf_token = \"YOUR-HF-TOKEN\"  # do not upload to github!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## W&B downloads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get runs from W&B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "api = wandb.Api()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# more filtering on W&B runs can be done here\n",
    "runs = api.runs(wandb_project_name)\n",
    "completed_runs = [run for run in runs if run.state == \"finished\"]\n",
    "sorted_by_layer = sorted(completed_runs, key=lambda x: x.config[\"hook_layer\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### W&B helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_endstr(run):\n",
    "    d_sae = int(run.config[\"d_in\"]) * int(run.config[\"expansion_factor\"])\n",
    "    hook_point = run.config[\"hook_name\"]\n",
    "    return \"_\".join([hook_point, str(d_sae)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_model(artifact, model_endstr):\n",
    "    return artifact.name.split(\":\")[0].endswith(model_endstr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sparsity_endstr(run):\n",
    "    model_endstr = get_model_endstr(run)\n",
    "    return model_endstr + \"_log_feature_sparsity\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_sparsity(artifact, sparsity_endstr):\n",
    "    return artifact.name.split(\":\")[0].endswith(sparsity_endstr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_artifact(run):\n",
    "    \"\"\"Return the first logged artifact whose name ends with the run's model suffix, or None.\"\"\"\n",
    "    model_endstr = get_model_endstr(run)\n",
    "    for a in run.logged_artifacts():\n",
    "        if is_model(a, model_endstr):\n",
    "            return a\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sparsity_artifact(run):\n",
    "    \"\"\"Return the first logged artifact whose name ends with the run's sparsity suffix, or None.\"\"\"\n",
    "    sparsity_endstr = get_sparsity_endstr(run)\n",
    "    for a in run.logged_artifacts():\n",
    "        if is_sparsity(a, sparsity_endstr):\n",
    "            return a\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def download_model_and_sparsity(run):\n",
    "    print(f\"Downloading model & sparsity for {run.config['hook_name']}\")\n",
    "    model_artifact = get_model_artifact(run)\n",
    "    model_path = model_artifact.download()\n",
    "    sparsity_artifact = get_sparsity_artifact(run)\n",
    "    sparsity_artifact.download(root=model_path)\n",
    "    return model_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download from W&B (quite slow)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_paths = [download_model_and_sparsity(run) for run in sorted_by_layer]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HF uploads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import HfApi\n",
    "\n",
    "# Note: this rebinds `api`, which held the wandb.Api() client in the cells above\n",
    "api = HfApi()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def upload_to_hf(model_path):\n",
    "    # Use the artifact folder name as the folder name inside the HF repo\n",
    "    repo_path = model_path.split('/')[-1]\n",
    "    api.upload_folder(\n",
    "        folder_path=model_path,\n",
    "        repo_id=hf_repo_id,\n",
    "        path_in_repo=repo_path,\n",
    "        token=hf_token,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model_path in model_paths:\n",
    "    print(f\"Uploading {model_path}\")\n",
    "    upload_to_hf(model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleanup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model_path in model_paths:\n",
    "    # First remove the directory contents (assumes the folder contains only files)\n",
    "    for f in os.scandir(model_path):\n",
    "        os.remove(f)\n",
    "\n",
    "    # Then remove the (now empty) directory\n",
    "    os.rmdir(model_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sae-lens-QXoybXMy-py3.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
