{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import torch as t\n",
    "import dictionary_learning.dictionary_learning.utils as utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Other checkpoints are kept as comments; only the uncommented assignment takes effect.\n",
    "# model_name = \"Qwen/Qwen2.5-Coder-7B-Instruct\"\n",
    "# model_name = \"google/gemma-2-2b\"\n",
    "model_name = \"EleutherAI/pythia-70m-deduped\"\n",
    "\n",
    "# Load the checkpoint with bfloat16 weights.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", torch_dtype=t.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: inspect a single block (the attribute path depends on the architecture).\n",
    "# print(model.model.layers[1])\n",
    "\n",
    "# Checkpoint name, architecture class name, then the full module tree, one per line.\n",
    "print(model.name_or_path, model.config.architectures[0], model, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncate_model(model: AutoModelForCausalLM, layer: int) -> AutoModelForCausalLM:\n",
    "    \"\"\"Drop every block after `layer` and replace the output head with an identity.\n",
    "\n",
    "    Adapted from tilde-research/activault\n",
    "    https://github.com/tilde-research/activault/blob/db6d1e4e36c2d3eb4fdce79e72be94f387eccee1/pipeline/setup.py#L74\n",
    "    This provides significant memory savings by deleting all layers that aren't needed for the given layer.\n",
    "    You should probably test this before using it.\n",
    "\n",
    "    The model is modified in place and the same object is returned.\n",
    "\n",
    "    Args:\n",
    "        model: Loaded causal LM whose `config.architectures[0]` is one of\n",
    "            Qwen2ForCausalLM, Gemma2ForCausalLM or GPTNeoXForCausalLM.\n",
    "        layer: Zero-based index of the last block to keep.\n",
    "\n",
    "    Returns:\n",
    "        The truncated `model`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the architecture has no truncation rule, or if `layer`\n",
    "            does not index an existing block.\n",
    "    \"\"\"\n",
    "    import gc\n",
    "\n",
    "    # architecture name -> (attribute holding the block stack, attribute holding the output head)\n",
    "    layouts = {\n",
    "        \"Qwen2ForCausalLM\": (\"model\", \"lm_head\"),\n",
    "        \"Gemma2ForCausalLM\": (\"model\", \"lm_head\"),\n",
    "        \"GPTNeoXForCausalLM\": (\"gpt_neox\", \"embed_out\"),\n",
    "    }\n",
    "\n",
    "    total_params_before = sum(p.numel() for p in model.parameters())\n",
    "    print(f\"Model parameters before truncation: {total_params_before:,}\")\n",
    "\n",
    "    # `architectures` may be missing or None; treat that like any other unsupported model.\n",
    "    architectures = getattr(model.config, \"architectures\", None) or [None]\n",
    "    if architectures[0] not in layouts:\n",
    "        raise ValueError(f\"Please add truncation for model {model.name_or_path}\")\n",
    "    backbone_attr, head_attr = layouts[architectures[0]]\n",
    "    backbone = getattr(model, backbone_attr)\n",
    "\n",
    "    # Reject indices the slice below would silently misread: layer=-1 would keep\n",
    "    # zero blocks, and an index past the end would keep everything yet still drop the head.\n",
    "    num_layers = len(backbone.layers)\n",
    "    if not 0 <= layer < num_layers:\n",
    "        raise ValueError(f\"layer must be in [0, {num_layers - 1}], got {layer}\")\n",
    "\n",
    "    # Keep blocks 0..layer; the discarded blocks lose their last reference here.\n",
    "    backbone.layers = backbone.layers[: layer + 1]\n",
    "\n",
    "    # Assigning a module under the same name replaces the old head.\n",
    "    setattr(model, head_attr, t.nn.Identity())\n",
    "\n",
    "    total_params_after = sum(p.numel() for p in model.parameters())\n",
    "    print(f\"Model parameters after truncation: {total_params_after:,}\")\n",
    "\n",
    "    # Collect the unreferenced modules, then release cached CUDA blocks.\n",
    "    gc.collect()\n",
    "    t.cuda.empty_cache()\n",
    "\n",
    "    return model\n",
    "\n",
    "\n",
    "# CUDA memory before truncation: bytes held by tensors vs. bytes reserved by the allocator.\n",
    "print(f\"{t.cuda.memory_allocated() / 1e9} GB actual\")\n",
    "print(f\"{t.cuda.memory_reserved() / 1e9} GB in cache\")\n",
    "\n",
    "# Keep blocks 0..5 and drop the rest.\n",
    "model = truncate_model(model, layer=5)\n",
    "\n",
    "# Same two readings after truncation, for comparison.\n",
    "print(f\"{t.cuda.memory_allocated() / 1e9} GB actual\")\n",
    "print(f\"{t.cuda.memory_reserved() / 1e9} GB in cache\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
