{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nnsight import LanguageModel\n",
    "import torch\n",
    "\n",
    "from dictionary_learning import ActivationBuffer, AutoEncoder\n",
    "from dictionary_learning.evaluation import evaluate\n",
    "from dictionary_learning.trainers.standard import StandardTrainer\n",
    "from dictionary_learning.training import trainSAE\n",
    "from dictionary_learning.utils import hf_dataset_to_generator\n",
    "\n",
    "from circuits.nanogpt_to_hf_transformers import NanogptTokenizer, convert_nanogpt_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when torch can see one; otherwise fall back to the CPU instead\n",
    "# of failing on the first tensor placed on the device.\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Convert the nanoGPT checkpoint, then wrap it in an nnsight LanguageModel.\n",
    "# DEVICE is already a torch.device, so it is passed through unchanged.\n",
    "tokenizer = NanogptTokenizer()\n",
    "converted_model = convert_nanogpt_model(\"lichess_8layers_ckpt_no_optimizer.pt\", DEVICE)\n",
    "model = LanguageModel(converted_model, device_map=DEVICE, tokenizer=tokenizer)\n",
    "\n",
    "# Activations are read from the MLP of the transformer block at index 5 (0-based).\n",
    "submodule = model.transformer.h[5].mlp\n",
    "activation_dim = 512  # output dimension of the MLP\n",
    "dictionary_size = 8 * activation_dim\n",
    "\n",
    "batch_size = 8\n",
    "\n",
    "data = hf_dataset_to_generator(\"redacted/chess_sae_test\", streaming=False)\n",
    "\n",
    "# d_submodule reuses activation_dim so the width is declared in one place only.\n",
    "buffer = ActivationBuffer(\n",
    "    data,\n",
    "    model,\n",
    "    submodule,\n",
    "    n_ctxs=512,\n",
    "    ctx_len=256,\n",
    "    refresh_batch_size=4,\n",
    "    io=\"out\",\n",
    "    d_submodule=activation_dim,\n",
    "    device=DEVICE,\n",
    "    out_batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the autoencoder stored in t1_ae.pt, passing DEVICE as its target device.\n",
    "# AutoEncoder is imported in the notebook's top import cell.\n",
    "ae = AutoEncoder.from_pretrained(\"t1_ae.pt\", device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the loaded autoencoder against the activation buffer.\n",
    "# evaluate is imported in the notebook's top import cell; the result is\n",
    "# consumed as a mapping via .items() when the metrics are printed.\n",
    "eval_results = evaluate(ae, buffer, device=DEVICE, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print every evaluation metric, one name/value pair per line.\n",
    "for metric_name, metric_value in eval_results.items():\n",
    "    print(metric_name, metric_value)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "circuits",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
