{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f41e557e-f922-43a8-9a9f-d871cf541028",
   "metadata": {},
   "outputs": [],
   "source": [
    "# need to do this before transformer imports\n",
    "import os\n",
    "os.environ['HF_HOME'] = '/workspace/cache/huggingface/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d850e51f-fb53-45ff-b3b0-b163ef66f148",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "from evals.utils import *\n",
    "from models.bigram_model import *\n",
    "from models.mlp_model import *\n",
    "from models.future_model import *\n",
    "from data.utils import get_tokenizer\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn\n",
    "from itertools import islice\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import Dataset\n",
    "from torch import nn\n",
    "\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import gc\n",
    "from glob import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b802590f-f27a-4bbb-9703-48798e60d51b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f620a08b-258e-4e8d-bde6-4d8a0191e21d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ed6e48c9-7eb9-4f07-a971-8f9c58a07b1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_llama = datasets.load_from_disk(f'/workspace/corpus/msmarco/msmarco_LLAMA2_64tokens_1m').with_format('torch', device=torch.device('cuda'))\n",
    "ds_mistral = datasets.load_from_disk(f'/workspace/corpus/msmarco/msmarco_MISTRAL_64tokens_1m').with_format('torch', device=torch.device('cuda'))\n",
    "val = {\n",
    "    'LLAMA2': ds_llama['val'],\n",
    "    'MISTRAL': ds_mistral['val'],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "710a75f5-b512-4c0f-9bf4-961cd7dd3b26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLAMA2\n",
      "/workspace/checkpoints/LLAMA2-NECK-SWEEP_20240117-035543-FBa93_base_model_name-daryl149/llama-2-7b-hf_hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-mlp_epoch=00-val_self_loss=3.62.ckpt\n",
      "LLAMA2-CHAT\n",
      "/workspace/checkpoints/LLAMA2-NECK-SWEEP_20240117-234516-0ca77_base_model_name-daryl149/llama-2-7b-chat-hf_hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-mlp_epoch=00-val_self_loss=3.28.ckpt\n",
      "LLAMA2-RAND\n",
      "/workspace/checkpoints/LLAMA2-RAND-INIT_20240129-180219-4a19d_dataset_name-msmarco_hidden_idxs-31_hidden_lb-0_neck_cls-mlp_pretrained-0_token_lb-0_epoch=00-val_self_loss=10.33.ckpt\n",
      "MISTRAL\n",
      "/workspace/checkpoints/MISTRAL-NECK-SWEEP_20231231-100744-dDbEa_hidden_idxs-31_hidden_lb-0_token_lb-0_neck_cls-mlp_epoch=00-val_self_loss=4.31.ckpt\n"
     ]
    }
   ],
   "source": [
    "pre = '/workspace/checkpoints/'\n",
    "layer = 31\n",
    "cls = 'mlp'\n",
    "ckpts = {\n",
    "    'LLAMA2': glob(pre + f'LLAMA2-NECK-SWEEP_*daryl149/llama-2-7b-hf_hidden_idxs-{layer}_hidden_lb-0_token_lb-0_neck_cls-{cls}*')[0],\n",
    "    'LLAMA2-CHAT': glob(pre + f'LLAMA2-NECK-SWEEP_*daryl149/llama-2-7b-chat-hf_hidden_idxs-{layer}_hidden_lb-0_token_lb-0_neck_cls-{cls}*')[0],\n",
    "    'LLAMA2-RAND': glob(pre + f'LLAMA2-RAND-INIT*_hidden_idxs-{layer}_hidden_lb-0_neck_cls-{cls}*')[0],\n",
    "    'MISTRAL': glob(pre + f'MISTRAL-NECK-SWEEP*hidden_idxs-{layer}_hidden_lb-0_token_lb-0_neck_cls-{cls}*')[0],\n",
    "}\n",
    "for k, v in ckpts.items():\n",
    "    print(k)\n",
    "    print(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "88149dfa-3076-4c95-9f08-d748b1ee7c7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LLAMA2\n",
      "LLAMA2-CHAT\n",
      "LLAMA2-RAND\n",
      "MISTRAL\n"
     ]
    }
   ],
   "source": [
    "necks = {}\n",
    "for k in ckpts:\n",
    "    print(k)\n",
    "    model = LitFutureModelWithNeck.load_from_checkpoint(ckpts[k], strict=False, pretrained=False)\n",
    "    necks[k] = model.future_neck\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1f8300b8-2389-4b72-a10e-64c7ac3a5e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# consistent ordering\n",
    "models = ['LLAMA2', 'LLAMA2-CHAT', 'LLAMA2-RAND', 'MISTRAL']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "87dd5899-2b6c-42a1-81aa-4334d6f564e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7b0c044f-bf7d-4f50-9f85-d01717817773",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BASE LLAMA2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e1c7417bda98401583f0b924f2f93b3f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NECK LLAMA2\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'MLPNeck' object has no attribute 'items'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[18], line 14\u001b[0m\n\u001b[1;32m     12\u001b[0m model\u001b[38;5;241m.\u001b[39mfuture_neck \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mfuture_necks[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdefault\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m     13\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 14\u001b[0m     neck_losses \u001b[38;5;241m=\u001b[39m [model\u001b[38;5;241m.\u001b[39m_compute_loss(batch)\u001b[38;5;241m.\u001b[39mself_loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m islice(\u001b[38;5;28miter\u001b[39m(loader), \u001b[38;5;241m2\u001b[39m)]\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLOSS\u001b[39m\u001b[38;5;124m'\u001b[39m, np\u001b[38;5;241m.\u001b[39mmean(neck_losses))\n\u001b[1;32m     16\u001b[0m base_losses\u001b[38;5;241m.\u001b[39mappend(np\u001b[38;5;241m.\u001b[39mmean(neck_losses))\n",
      "Cell \u001b[0;32mIn[18], line 14\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     12\u001b[0m model\u001b[38;5;241m.\u001b[39mfuture_neck \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mfuture_necks[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdefault\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m     13\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m---> 14\u001b[0m     neck_losses \u001b[38;5;241m=\u001b[39m [\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_compute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mself_loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m islice(\u001b[38;5;28miter\u001b[39m(loader), \u001b[38;5;241m2\u001b[39m)]\n\u001b[1;32m     15\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mLOSS\u001b[39m\u001b[38;5;124m'\u001b[39m, np\u001b[38;5;241m.\u001b[39mmean(neck_losses))\n\u001b[1;32m     16\u001b[0m base_losses\u001b[38;5;241m.\u001b[39mappend(np\u001b[38;5;241m.\u001b[39mmean(neck_losses))\n",
      "File \u001b[0;32m/workspace/FutureGPT2/src/models/future_model.py:165\u001b[0m, in \u001b[0;36mLitFutureModel._compute_loss\u001b[0;34m(self, batch, do_profile)\u001b[0m\n\u001b[1;32m    156\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m'''\u001b[39;00m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;124;03mReturns dict of:\u001b[39;00m\n\u001b[1;32m    158\u001b[0m \u001b[38;5;124;03m    base_loss: D_KL(base, orig) (only if freeze_base=self)\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    162\u001b[0m \u001b[38;5;124;03m    total_loss: base_loss + kappa * self_loss (this is the one used for training)\u001b[39;00m\n\u001b[1;32m    163\u001b[0m \u001b[38;5;124;03m'''\u001b[39;00m\n\u001b[1;32m    164\u001b[0m loss \u001b[38;5;241m=\u001b[39m dotdict()\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdo_profile\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdo_profile\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    166\u001b[0m base_logits \u001b[38;5;241m=\u001b[39m output\u001b[38;5;241m.\u001b[39mlogits\n\u001b[1;32m    167\u001b[0m loss\u001b[38;5;241m.\u001b[39mbase_loss \u001b[38;5;241m=\u001b[39m output\u001b[38;5;241m.\u001b[39mloss\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/workspace/FutureGPT2/src/models/future_model.py:135\u001b[0m, in \u001b[0;36mLitFutureModel.forward\u001b[0;34m(self, batch, do_profile)\u001b[0m\n\u001b[1;32m    133\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_log_prof(prof, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124min_embed\u001b[39m\u001b[38;5;124m'\u001b[39m, batch_size)\n\u001b[1;32m    134\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m maybe_profile(do_profile) \u001b[38;5;28;01mas\u001b[39;00m prof:\n\u001b[0;32m--> 135\u001b[0m     future_outs \u001b[38;5;241m=\u001b[39m {k: v(hidden_states, tokens) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfuture_neck\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitems\u001b[49m()}\n\u001b[1;32m    136\u001b[0m \u001b[38;5;66;03m# print_mem('E')\u001b[39;00m\n\u001b[1;32m    137\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_log_prof(prof, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mneck\u001b[39m\u001b[38;5;124m'\u001b[39m, batch_size)\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1695\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1693\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m   1694\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1695\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'MLPNeck' object has no attribute 'items'"
     ]
    }
   ],
   "source": [
    "from itertools import islice \n",
    "\n",
    "losses = []\n",
    "for k in models:\n",
    "    print('BASE', k)\n",
    "    model = LitFutureModelWithNeck.load_from_checkpoint(ckpts[k], strict=False)\n",
    "    base_losses = []\n",
    "    loader = DataLoader(val[k.split('-')[0]], batch_size=128)\n",
    "    for kn in models:\n",
    "        print('NECK', kn)\n",
    "        model.future_necks = nn.ModuleDict({'default': necks[kn]})\n",
    "        model.future_neck = model.future_necks['default']\n",
    "        with torch.no_grad():\n",
    "            neck_losses = [model._compute_loss(batch).self_loss.item() for batch in islice(iter(loader), 2)]\n",
    "        print('LOSS', np.mean(neck_losses))\n",
    "        base_losses.append(np.mean(neck_losses))\n",
    "    losses.append(base_losses)\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eea32405-e07b-497e-ade4-0411b52e2e39",
   "metadata": {},
   "outputs": [],
   "source": [
    "losses"
   ]
  },
  {
   "cell_type": "raw",
   "id": "3db75145-0665-4e85-b1e3-7dd374e79cf2",
   "metadata": {},
   "source": [
    "[[3.624929666519165, 4.150340795516968, 32.17119598388672, 10.323686599731445],\n",
    " [3.5302734375, 3.2802786827087402, 26.60939598083496, 10.336112022399902],\n",
    " [10.039473533630371,\n",
    "  10.013981819152832,\n",
    "  4.227360725402832,\n",
    "  10.367734909057617],\n",
    " [69.1989860534668, 81.19962310791016, 323.8224182128906, 10.382081031799316]]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
