{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd033590-b90d-4fb3-8c71-6bc75d8e9610",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer\n",
    "import torch\n",
    "import os\n",
    "import numpy as np\n",
    "from torch.utils.flop_counter import FlopCounterMode\n",
    "from sae_lens import SAE\n",
    "from stitching.sae_utils import BaseSAE, topk_activation\n",
    "import functools\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d911f7f0-ad26-483d-80f5-1a54b1163b95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to CPU when no GPU is present: FLOP counts depend only on tensor shapes, not on the device.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a652685-06b5-4b73-83f8-6e02cec8a567",
   "metadata": {},
   "source": [
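    "# SAE FLOP count\n",
    "\n",
    "Approach: measure the FLOPs of a single training step with `FlopCounterMode`, then multiply by the number of training steps.\n",
    "\n",
    "Random dummy activations are sufficient, because the FLOP count depends only on tensor shapes, not on their values.\n",
    "\n",
    "Recorded estimates with 32k (32,768) features, batch size 4,096 and 100k training steps. These settings differ from the current values in the configuration cells below:\n",
    "\n",
    "- 70m (`model_dim = 512`): `6.871948e+16`\n",
    "- 160m (`model_dim = 768`): `1.030792e+17`"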
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a80e7e47-bbca-4465-af21-137052bceafd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAE shapes used for the FLOP measurement.\n",
    "# The recorded 70m / 160m estimates used model_dim 512 / 768, batch_size 4_096 and feature_dim 32_768.\n",
    "model_dim = 1_024\n",
    "batch_size = 2_048\n",
    "feature_dim = 65_536"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53f5f973-7e89-4fae-9c98-68280129154e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random stand-in activations: only their shape matters for FLOP counting.\n",
    "dummy_data = torch.randn(batch_size, model_dim, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cfba087-e15c-416f-956b-f831c3219145",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TopK activation with k=32 (stitching.sae_utils.topk_activation).\n",
    "# For a ReLU SAE use `torch.nn.functional.relu` instead; `torch.functional` has no `relu`.\n",
    "activation_fn = functools.partial(topk_activation, k=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70832f2e-b591-4f7b-9c85-fe538faa4632",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random SAE parameters, created in the same order as BaseSAE's positional arguments.\n",
    "# Judging by the shapes these are presumably (W_enc, W_dec, b_enc, b_dec) - TODO confirm against BaseSAE.\n",
    "_param_shapes = [(model_dim, feature_dim), (feature_dim, model_dim), (feature_dim,), (model_dim,)]\n",
    "_init_params = [torch.randn(shape, device=device, requires_grad=True) for shape in _param_shapes]\n",
    "skeleton = BaseSAE(*_init_params, activation_fn, apply_b_dec=False, requires_grad=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7558d06-3b22-4510-ba2e-c37dd2c3bb7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam makes the measured step mirror a real training step; its elementwise update is not tallied by FlopCounterMode.\n",
    "optim = torch.optim.Adam(params=skeleton.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8041a00-f823-4c21-9266-83d731a41927",
   "metadata": {},
   "outputs": [],
   "source": [
    "# display=False: totals are read via get_total_flops() instead of printing a per-module table on exit.\n",
    "flop_counter = FlopCounterMode(\n",
    "    display=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6082468-67c6-4470-b75d-1b5af932cd4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with flop_counter:\n",
    "    # Forward: encode, reconstruct, and sum the squared reconstruction error.\n",
    "    acts = skeleton.encode(dummy_data)\n",
    "    residual = skeleton.decode(acts) - dummy_data\n",
    "    mse = residual.pow(2).sum()\n",
    "    # A ReLU SAE would add an L1 sparsity term here, e.g. loss = mse + acts.abs().sum()\n",
    "    loss = mse\n",
    "    # Backward pass plus optimizer update completes one training step.\n",
    "    loss.backward()\n",
    "    optim.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa6db3bc-9ebe-4d70-8506-4a3cbc6e0e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FLOPs of one SAE training step, in scientific notation.\n",
    "format(flop_counter.get_total_flops(), 'e')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac3a318d-cae2-4fcd-a6e1-31f674e2be8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_iters = 120_000  # number of SAE training steps\n",
    "sae_total_flops = training_iters * flop_counter.get_total_flops()\n",
    "format(sae_total_flops, 'e')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cf5d448-2ce7-4014-8bd8-97dd8fa7ffe9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "057ef046-b2e2-4322-9eec-14bd4d75884e",
   "metadata": {},
   "source": [
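    "# Stitch FLOP count\n",
    "\n",
    "Same approach for the linear stitching maps `P_up` / `P_down`. One training step (four projections, summed squared-error loss, backward pass, Adam update) is measured and then multiplied by the number of steps.\n",
    "\n",
    "Recorded estimate with 35k training steps, batch size 10 and sequence length 512:\n",
    "\n",
    "- 70m to 160m (`modelA_dim = 512`, `modelB_dim = 768`): `1.409286e+15`"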
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec41093c-e2fb-4cff-a51c-908fa0e2def5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stitch shapes: each model's dummy activations are (batch_size, n_seq, model dim).\n",
    "# The recorded 70m -> 160m estimate used modelA_dim = 512 and modelB_dim = 768.\n",
    "batch_size, n_seq = 10, 512\n",
    "modelA_dim, modelB_dim = 768, 1_024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e2a696d-4bec-4794-922f-f4af3fe347d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random stand-in activations for model A and model B (drawn in that order).\n",
    "dummy_data_A, dummy_data_B = (\n",
    "    torch.randn(batch_size, n_seq, dim, device=device) for dim in (modelA_dim, modelB_dim)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bc10b4f-bab7-48c9-a71a-d83fd122bafb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear stitching maps: P_up goes from model A's width to model B's, P_down the reverse.\n",
    "P_up, P_down = (\n",
    "    torch.nn.Linear(d_in, d_out, bias=True, device=device)\n",
    "    for d_in, d_out in ((modelA_dim, modelB_dim), (modelB_dim, modelA_dim))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27378199-7e09-43ec-8750-2e2ee35c00bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One optimizer over both maps' parameters (P_up first, then P_down).\n",
    "stitch_params = [*P_up.parameters(), *P_down.parameters()]\n",
    "optim = torch.optim.Adam(stitch_params, lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43b391fd-2f39-4b6d-9365-d0f550cd11b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh counter for the stitch measurement; display=False keeps the per-module table from printing.\n",
    "flop_counter = FlopCounterMode(\n",
    "    display=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "426563e8-c82a-4b89-9c26-205cade4e4a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "with flop_counter:\n",
    "    # Forward: both cross-model maps, then the round trips through the opposite map.\n",
    "    pred_B = P_up(dummy_data_A)\n",
    "    pred_A = P_down(dummy_data_B)\n",
    "    inv_pred_A = P_down(pred_B)\n",
    "    inv_pred_B = P_up(pred_A)\n",
    "\n",
    "    # Sum of squared errors for the direct and round-trip predictions.\n",
    "    loss = (\n",
    "        (pred_B - dummy_data_B).pow(2).sum()\n",
    "        + (pred_A - dummy_data_A).pow(2).sum()\n",
    "        + (inv_pred_A - dummy_data_A).pow(2).sum()\n",
    "        + (inv_pred_B - dummy_data_B).pow(2).sum()\n",
    "    )\n",
    "    loss.backward()\n",
    "    optim.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41079064-e2eb-4c24-9f63-239b77e6ec4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FLOPs of one stitch training step, in scientific notation.\n",
    "format(flop_counter.get_total_flops(), 'e')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d0e75db-1f3c-4ea3-83f1-f8a855ea2d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_iters = 36_000  # stitch training steps (the recorded estimate used 35k)\n",
    "stitch_total_flops = training_iters * flop_counter.get_total_flops()\n",
    "format(stitch_total_flops, 'e')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53f862e5-8dff-4b66-8e76-5ac097f3feab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
