{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from olmo.model import HFMixinOLMo\n",
    "\n",
    "from huggingface_hub import snapshot_download"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ed244fab0ac94440875c4e5bcfe5e65e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28c3829e40204151ad7fe4c8be05086f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "(…)rcoder_N1.1e08_D3.0e08_C2.0e17/README.md:   0%|          | 0.00/320 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "928e1d85b4774e36b9b541fda5f30d20",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "(…)oder_N1.1e08_D3.0e08_C2.0e17/config.json:   0%|          | 0.00/958 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "129a9e18cfb444c4b58e403536aced57",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/442M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading weights from local directory\n"
     ]
    }
   ],
   "source": [
    "# Staging directory on disk and the run folder to pull from the Hub repo.\n",
    "tmp_dir = \"/n/holylabs/LABS/sham_lab/Everyone/tmp_download\"\n",
    "model_name = \"L2L_starcoder_N1.1e08_D3.0e08_C2.0e17\"\n",
    "\n",
    "# Only files matching this run's subfolder pattern are requested.\n",
    "download_kwargs = dict(\n",
    "    repo_id=\"davidbrandfonbrener/loss-to-loss\",\n",
    "    allow_patterns=f\"{model_name}/*\",\n",
    "    local_dir=tmp_dir,\n",
    ")\n",
    "snapshot_download(**download_kwargs)\n",
    "\n",
    "# Load the weights from the downloaded run folder.\n",
    "checkpoint_dir = f\"{tmp_dir}/{model_name}\"\n",
    "model = HFMixinOLMo.from_pretrained(checkpoint_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "HFMixinOLMo(\n",
       "  (transformer): ModuleDict(\n",
       "    (wte): Embedding(32000, 704)\n",
       "    (emb_drop): Dropout(p=0.0, inplace=False)\n",
       "    (ln_f): LayerNorm()\n",
       "    (blocks): ModuleList(\n",
       "      (0-10): 11 x OLMoSequentialBlock(\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "        (k_norm): LayerNorm()\n",
       "        (q_norm): LayerNorm()\n",
       "        (act): GELU(approximate='none')\n",
       "        (attn_out): Linear(in_features=704, out_features=704, bias=False)\n",
       "        (ff_out): Linear(in_features=2816, out_features=704, bias=False)\n",
       "        (rotary_emb): RotaryEmbedding()\n",
       "        (attn_norm): LayerNorm()\n",
       "        (ff_norm): LayerNorm()\n",
       "        (att_proj): Linear(in_features=704, out_features=2112, bias=False)\n",
       "        (ff_proj): Linear(in_features=704, out_features=2816, bias=False)\n",
       "      )\n",
       "    )\n",
       "    (ff_out_last): Linear(in_features=704, out_features=32000, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Bare last expression: the notebook renders the model's repr as the cell output.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Local model loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the model from a local run directory: read its YAML config, construct\n",
    "# the architecture on CPU, then load the saved weights into it.\n",
    "local_dir = '/n/holylabs/LABS/sham_lab/Everyone/loss-to-loss/L2L_fineweb-edu-100b_N9.0e07_D3.9e09_C2.1e18'\n",
    "\n",
    "# Context manager closes the config file promptly instead of leaking the handle;\n",
    "# explicit UTF-8 avoids depending on the machine's locale encoding.\n",
    "with open(f\"{local_dir}/config.yaml\", encoding=\"utf-8\") as config_file:\n",
    "    config = yaml.safe_load(config_file)\n",
    "\n",
    "# These top-level sections are discarded; only config[\"model\"] is used below.\n",
    "# The None default avoids a KeyError when a run config lacks one of them.\n",
    "config.pop(\"evaluators\", None)\n",
    "config.pop(\"wandb\", None)\n",
    "config[\"model\"][\"init_device\"] = \"cpu\"\n",
    "\n",
    "model = HFMixinOLMo(config[\"model\"])\n",
    "\n",
    "model_path = f\"{local_dir}/model.pt\"\n",
    "# torch.load unpickles the file; weights_only=True restricts unpickling to\n",
    "# tensors and plain containers so a tampered checkpoint cannot run code.\n",
    "state_dict = torch.load(model_path, map_location='cpu', weights_only=True)\n",
    "# Kept as the last expression so the notebook displays the key-matching result.\n",
    "model.load_state_dict(state_dict, assign=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "loss-to-loss",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
