{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bbf9da59-adc3-41e7-81df-af854ec8c4e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import gc\n",
    "import os\n",
    "from inspect import signature, _ParameterKind\n",
    "\n",
    "import datasets\n",
    "import lightning as L\n",
    "import numpy as np\n",
    "import torch\n",
    "import transformers\n",
    "import wandb\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor\n",
    "from lightning.pytorch.callbacks.early_stopping import EarlyStopping\n",
    "from lightning.pytorch.loggers import WandbLogger\n",
    "from matplotlib import pyplot as plt\n",
    "from torch import optim, nn, Tensor\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e3751427-cb5b-4cab-908f-55b914114d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relax float32 matmul precision only on GPUs with compute capability >= 8.\n",
    "# Check for CUDA first: get_device_capability() raises when no GPU is usable,\n",
    "# which would abort a Restart & Run All on a CPU-only machine.\n",
    "if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:\n",
    "    torch.set_float32_matmul_precision('high')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "55659a78-bcbf-4978-a901-56790bd447ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /home/XXXXXXXX/.netrc\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Never keep an API key in the notebook: read it from the environment.\n",
    "# If WANDB_API_KEY is unset, key=None leaves credential handling to wandb.\n",
    "wandb.login(key=os.environ.get('WANDB_API_KEY'), relogin=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c1462d7e-2b6c-44bd-a2f3-f4727b145036",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LitGPTModel(L.LightningModule):\n",
    "    '''\n",
    "    Causal language model built from a config only and trained end to end.\n",
    "\n",
    "    The architecture comes from the config of `model_name`; the network is\n",
    "    created with `from_config`, and all of its parameters are handed to the\n",
    "    optimizer. Before every forward pass one fixed token is written into\n",
    "    each sequence (see POTATO_POSITION / POTATO_TOKEN_ID).\n",
    "\n",
    "    Args:\n",
    "        model_name: Hugging Face model id whose config defines the network.\n",
    "        lr: learning rate passed to Adam.\n",
    "        num_warmup_steps: warm-up steps of the cosine schedule.\n",
    "        num_training_steps: total optimizer steps of the cosine schedule.\n",
    "            Pass None to use `trainer.estimated_stepping_batches`.\n",
    "    '''\n",
    "\n",
    "    # Experiment-specific overwrite applied in forward(): this token id is\n",
    "    # written at this sequence position (the id was tagged POTATO by the author).\n",
    "    POTATO_POSITION = 9\n",
    "    POTATO_TOKEN_ID = 21219\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_name='gpt2',\n",
    "        lr=6e-4,\n",
    "        num_warmup_steps=1000,\n",
    "        num_training_steps=9200,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.model_name = model_name\n",
    "        self.lr = lr\n",
    "        self.num_warmup_steps = num_warmup_steps\n",
    "        self.num_training_steps = num_training_steps\n",
    "        config = AutoConfig.from_pretrained(model_name)\n",
    "        self.model = AutoModelForCausalLM.from_config(config=config)\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "    def forward(self, batch):\n",
    "        '''\n",
    "        Run the language model on `batch` and return its output object.\n",
    "\n",
    "        `batch` must provide `input_ids` and `attention_mask`; the (modified)\n",
    "        input ids are also used as labels, so the output carries `.loss`.\n",
    "        '''\n",
    "        # TODO: REMOVE THIS! (experiment-specific token overwrite)\n",
    "        # Work on a copy so the batch owned by the caller is left untouched.\n",
    "        input_ids = batch['input_ids'].clone()\n",
    "        input_ids[:, self.POTATO_POSITION] = self.POTATO_TOKEN_ID\n",
    "        # NOTE(review): positions with attention_mask == 0 are not excluded\n",
    "        # from the labels - confirm the dataset contains no padding.\n",
    "        return self.model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=batch['attention_mask'],\n",
    "            labels=input_ids,\n",
    "            # No generation happens here, so a key/value cache is pure overhead.\n",
    "            use_cache=False,\n",
    "        )\n",
    "\n",
    "    def _compute_and_log_loss(self, batch, metric_name, **log_kwargs):\n",
    "        '''Compute the loss for `batch`, log it as `metric_name` and return it.'''\n",
    "        loss = self(batch).loss\n",
    "        # Log the detached tensor; loss.item() would force a host sync per step.\n",
    "        self.log(metric_name, loss.detach(), **log_kwargs)\n",
    "        return loss\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        '''Return the training loss; also logs `train_loss` and `global_step`.'''\n",
    "        loss = self._compute_and_log_loss(batch, 'train_loss', on_step=True)\n",
    "        # Logged as a metric so it can be used in the checkpoint filename.\n",
    "        self.log('global_step', self.trainer.global_step)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        '''Return the validation loss and log it as `val_loss`.'''\n",
    "        return self._compute_and_log_loss(batch, 'val_loss')\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        '''Return the test loss and log it as `test_loss`.'''\n",
    "        return self._compute_and_log_loss(batch, 'test_loss')\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        '''Adam over all model parameters with a warm-up + cosine schedule.'''\n",
    "        optimizer = optim.Adam(params=self.model.parameters(), lr=self.lr)\n",
    "        num_training_steps = self.num_training_steps\n",
    "        if num_training_steps is None:\n",
    "            num_training_steps = self.trainer.estimated_stepping_batches\n",
    "        scheduler = transformers.get_cosine_schedule_with_warmup(\n",
    "            optimizer=optimizer,\n",
    "            num_warmup_steps=self.num_warmup_steps,\n",
    "            num_training_steps=num_training_steps,\n",
    "        )\n",
    "        # The schedule is expressed in optimizer steps, so step it every batch.\n",
    "        return (\n",
    "            [optimizer],\n",
    "            [{'scheduler': scheduler, 'interval': 'step'}]\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "05f17a5a-fe37-44f4-8eff-304aa478867d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# W&B project and run name (NAME also prefixes the checkpoint filenames).\n",
    "PROJ = 'XXXXXXXX_FUTURE_GPT2'\n",
    "NAME = 'GPT2-MSMARCO-POTATO'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f2fde467-b571-4cc0-8eda-c6313ddc4cbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the pre-tokenised dataset and loader settings, kept in one place.\n",
    "DATA_DIR = '/workspace/corpus/msmarco/msmarco_GPT2_64tokens_full'\n",
    "BATCH_SIZE = 512\n",
    "# Up to 96 loader workers, but never more than the machine has CPU cores.\n",
    "NUM_WORKERS = min(96, os.cpu_count() or 1)\n",
    "\n",
    "train = datasets.load_from_disk(f'{DATA_DIR}/train').with_format('torch')\n",
    "val = datasets.load_from_disk(f'{DATA_DIR}/val').with_format('torch')\n",
    "# NOTE(review): the training loader is not shuffled - confirm this is intended.\n",
    "train_loader = DataLoader(train, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)\n",
    "val_loader = DataLoader(val, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8080f468-9a8e-4aec-ab92-7c19afc9858f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4659264"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of examples in the training split.\n",
    "len(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b624a4b6-91d3-403c-b0f6-5bd1ffe8ac37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# log_model=False: checkpoints are only saved locally, not uploaded to W&B.\n",
    "wandb_logger = WandbLogger(project=PROJ, name=NAME, log_model=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4e2009aa-4900-480c-899d-e3cd1bedacbd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "# Keep only the best checkpoint, ranked by validation loss.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    monitor='val_loss',\n",
    "    mode='min',\n",
    "    save_top_k=1,\n",
    "    every_n_epochs=1,\n",
    "    dirpath='/workspace/checkpoints',\n",
    "    filename=NAME + '_{global_step}_{val_loss:.2f}',\n",
    ")\n",
    "# Stop on divergence (val_loss past 15) or after 10 checks without improvement.\n",
    "early_stop_callback = EarlyStopping(\n",
    "    monitor='val_loss',\n",
    "    mode='min',\n",
    "    min_delta=0.00,\n",
    "    patience=10,\n",
    "    divergence_threshold=15,\n",
    "    verbose=False,\n",
    ")\n",
    "lr_monitor = LearningRateMonitor()\n",
    "\n",
    "# A single epoch, with validation run at every 0.1 of the epoch.\n",
    "trainer = L.Trainer(\n",
    "    max_epochs=1,\n",
    "    val_check_interval=.1,\n",
    "    logger=wandb_logger,\n",
    "    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],\n",
    "    enable_progress_bar=True,\n",
    "    fast_dev_run=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "56c2c6e2-7175-4a63-b6e1-ee5fb64489c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mXXXXXXXX\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.16.4 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20240312_161057-p4stpqt8</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2/runs/p4stpqt8' target=\"_blank\">GPT2-MSMARCO-POTATO</a></strong> to <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2/runs/p4stpqt8' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2/runs/p4stpqt8</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n"
     ]
    }
   ],
   "source": [
    "# Instantiate with the default hyperparameters of LitGPTModel.\n",
    "model = LitGPTModel()\n",
    "# Have W&B watch only the `wpe` submodule of the transformer (log='all').\n",
    "wandb_logger.watch(model.model.transformer.wpe, log='all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "340c982a-8565-4248-8ca3-781f96c00506",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name  | Type            | Params\n",
      "------------------------------------------\n",
      "0 | model | GPT2LMHeadModel | 124 M \n",
      "------------------------------------------\n",
      "124 M     Trainable params\n",
      "0         Non-trainable params\n",
      "124 M     Total params\n",
      "497.759   Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "970f9eec82a941538d3d0679ed721cdf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...\n"
     ]
    }
   ],
   "source": [
    "# Fit the model on the training loader, validating on the validation loader.\n",
    "trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "635ff2b4-e128-4a17-9ab1-240de8f62bb2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3080e077-2c57-4ca9-b0c1-eef7fe2ee294",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
