{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cd5ce645-efa0-4cf4-be81-d70fc11b62b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# need to do this before transformer imports\n",
    "# import os\n",
    "# os.environ['HF_HOME'] = '/workspace/cache/huggingface/'\n",
    "\n",
    "import os\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "from models.myopic_model import *\n",
    "from data.utils import get_tokenizer\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn\n",
    "from itertools import islice\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import Dataset\n",
    "from torch import nn\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor\n",
    "from lightning.pytorch.callbacks.early_stopping import EarlyStopping\n",
    "from lightning.pytorch.loggers import WandbLogger\n",
    "\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import gc\n",
    "from glob import glob\n",
    "import numpy as np\n",
    "import copy\n",
    "import wandb\n",
    "from tqdm import tqdm\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f477a7f-b45a-4762-bd68-a59f7b23833d",
   "metadata": {},
   "outputs": [],
   "source": [
    "NAME = 'GPT2-MYOPIC-H-FROMORIG'\n",
    "PROJ = 'XXXXXXXX_FUTURE_GPT2'\n",
    "model_name = 'gpt2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1630f1e7-4d0e-43ae-9610-23e06a79a525",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /home/XXXXXXXX/.netrc\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wandb.login(key='XXXXXXXX', relogin=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d0c79a54-15f2-4e7b-97d4-c17a2b7760a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.get_device_capability()[0] >= 8:\n",
    "    torch.set_float32_matmul_precision('high')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "53cbc9a2-362f-4360-b0bb-4175e06cc237",
   "metadata": {},
   "outputs": [],
   "source": [
    "train = datasets.load_from_disk('/workspace/corpus/msmarco/msmarco_GPT2_64tokens_full/train').with_format('torch')\n",
    "val = datasets.load_from_disk('/workspace/corpus/msmarco/msmarco_GPT2_64tokens_full/val').with_format('torch')\n",
    "train_loader = DataLoader(train, batch_size=256)\n",
    "val_loader = DataLoader(val, batch_size=256)\n",
    "#loaders = {\n",
    "#    split: DataLoader(dataset[split], batch_size=128)\n",
    "#    for split in ['train', 'val', 'test']\n",
    "#}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ed65c2cf-e92d-451c-8e4a-1e922ecfb227",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb_logger = WandbLogger(\n",
    "    name=NAME,\n",
    "    project=PROJ,\n",
    "    log_model=False,   # Only save checkpoints locally\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d56dc71a-a8a4-4729-8146-7ecdd677a127",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "lr_callback = LearningRateMonitor()\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    dirpath=\"/workspace/checkpoints\",\n",
    "    filename=NAME + \"_{val_myopic_loss:.2f}\",\n",
    "    every_n_epochs=1,\n",
    "    save_top_k=1,\n",
    "    monitor='val_myopic_loss',\n",
    "    mode='min',\n",
    ")\n",
    "early_stop_callback = EarlyStopping(\n",
    "    monitor='val_myopic_loss',\n",
    "    divergence_threshold=15,\n",
    "    min_delta=0.00,\n",
    "    patience=100000,\n",
    "    verbose=False,\n",
    "    mode='min',\n",
    ")\n",
    "trainer = L.Trainer(\n",
    "    fast_dev_run=False,\n",
    "    logger=wandb_logger,\n",
    "    val_check_interval=.1,\n",
    "    callbacks=[checkpoint_callback, lr_callback, early_stop_callback],\n",
    "    max_epochs=1,\n",
    "    enable_progress_bar=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a9b75db6-9e65-45b9-aeab-114153ec0d3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config = AutoConfig.from_pretrained(model_name)\n",
    "orig_model = AutoModelForCausalLM.from_config(config=config)\n",
    "# orig_state = torch.load('/workspace/checkpoints/GPT2-MYOPIC-NO-ORIG_val_myopic_loss=5.34.ckpt')\n",
    "orig_state = torch.load('/workspace/checkpoints/GPT2-MSMARCO-COSINE_global_step=9099.0_val_loss=3.28.ckpt')\n",
    "orig_model.load_state_dict(\n",
    "    {'.'.join(k.split('.')[1:]): v for k, v in orig_state['state_dict'].items()} #if 'myopic_model' in k}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f7d6772e-9a5d-4d73-95b3-c6d8000fa293",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = AutoConfig.from_pretrained(model_name)\n",
    "myopic_model = AutoModelForCausalLM.from_config(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1bdddd4c-b4df-4cd2-ba8a-faadc6983b81",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mXXXXXXXX\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.16.5 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20240330_013127-xfbv3mf3</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2/runs/xfbv3mf3' target=\"_blank\">GPT2-MYOPIC-H-FROMORIG</a></strong> to <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2/runs/xfbv3mf3' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2/runs/xfbv3mf3</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n"
     ]
    }
   ],
   "source": [
    "model = LitMyopicModel(\n",
    "    myopic_model=myopic_model,\n",
    "    orig_model=orig_model,    # set to None (default) for cutgrad training [use own detached hidden state or kv]\n",
    "    #orig_model=myopic_model,\n",
    "    #orig_model=None,\n",
    "    loss_type = 'myopic_loss',\n",
    "    from_kv=False,\n",
    ")\n",
    "wandb_logger.watch(model.myopic_model, log='all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb8b2bff-4f89-43e9-83d1-b9d99eea33a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "\n",
      "  | Name         | Type            | Params\n",
      "-------------------------------------------------\n",
      "0 | myopic_model | GPT2LMHeadModel | 124 M \n",
      "1 | orig_model   | GPT2LMHeadModel | 124 M \n",
      "-------------------------------------------------\n",
      "124 M     Trainable params\n",
      "124 M     Non-trainable params\n",
      "248 M     Total params\n",
      "995.518   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NUM TRAINING STEPS 18201\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2159af57c90f4f7fb19ce4ddd4991293",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.fit(\n",
    "    model=model,\n",
    "    train_dataloaders=train_loader,\n",
    "    val_dataloaders=val_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5096758-0aba-40f5-b9c1-34320f6288c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.is_grad_enabled()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74990bf3-f0ba-4760-bb4a-0732fe1d0702",
   "metadata": {},
   "outputs": [],
   "source": [
    "myopic_model.transformer.h[0].attn.c_attn?"
   ]
  },
  {
   "cell_type": "raw",
   "id": "b8b8185b-bea7-4539-9da2-e46659053dca",
   "metadata": {},
   "source": [
    "past_key[0,0,0,:10]\n",
    "\n",
    "tensor([ 0.2033,  0.4537,  0.5270, -0.6482, -0.6826,  0.5862,  1.6485,  0.1103,\n",
    "         0.4344,  0.3957], device='cuda:0')\n",
    "\n",
    " key[0,0,0,:10]\n",
    "\n",
    "tensor([ 0.4089, -0.1765,  0.3309, -0.8672, -0.9072,  0.0784,  0.6817, -0.0922,\n",
    "         0.2347,  0.7114], device='cuda:0')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d378287-0e94-45d1-a772-a61a5c88b288",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
