{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bbf9da59-adc3-41e7-81df-af854ec8c4e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import gc\n",
    "import os\n",
    "from inspect import signature, _ParameterKind\n",
    "\n",
    "import numpy as np\n",
    "from torch import optim, nn, Tensor\n",
    "from torch.nn import functional as F\n",
    "import torch\n",
    "import wandb\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "import transformers\n",
    "import lightning as L\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor\n",
    "from lightning.pytorch.callbacks.early_stopping import EarlyStopping\n",
    "from lightning.pytorch.loggers import WandbLogger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e3751427-cb5b-4cab-908f-55b914114d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Allow reduced-precision float32 matmuls on GPUs with compute capability >= 8.\n",
    "# is_available() is checked first because get_device_capability() raises on a\n",
    "# machine without a usable CUDA device.\n",
    "if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:\n",
    "    torch.set_float32_matmul_precision('medium')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "55659a78-bcbf-4978-a901-56790bd447ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /home/XXXXXXXX/.netrc\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the API key from the environment instead of embedding it in the notebook.\n",
    "# If WANDB_API_KEY is unset, key=None leaves the choice of credentials to wandb's own login flow.\n",
    "wandb.login(key=os.environ.get('WANDB_API_KEY'), relogin=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c1462d7e-2b6c-44bd-a2f3-f4727b145036",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LitUnbalancedModel(L.LightningModule):\n",
    "    '''\n",
    "    Causal LM trained from a fresh config with an unbalanced next-token loss.\n",
    "\n",
    "    The model is built from the `model_name` config (no pretrained weights are\n",
    "    loaded) and all of its parameters are optimized. Predictions made at even\n",
    "    positions (0, 2, ...) and at odd positions (1, 3, ...) are scored separately,\n",
    "    and the training loss is `alpha * even_loss + odd_loss`.\n",
    "\n",
    "    Args:\n",
    "        model_name: Hugging Face model id whose config defines the architecture.\n",
    "        lr: Learning rate passed to Adam (peak of the schedule).\n",
    "        num_warmup_steps: Warmup steps of the cosine schedule.\n",
    "        alpha: Weight applied to the even-position loss.\n",
    "        num_training_steps: Total steps of the cosine schedule; pass None to\n",
    "            use `trainer.estimated_stepping_batches` instead.\n",
    "    '''\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_name='gpt2',\n",
    "        lr=6e-4,\n",
    "        num_warmup_steps=1000,\n",
    "        alpha=0.01,\n",
    "        num_training_steps=9200,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.model_name = model_name\n",
    "        self.lr = lr\n",
    "        self.num_warmup_steps = num_warmup_steps\n",
    "        self.alpha = alpha\n",
    "        self.num_training_steps = num_training_steps\n",
    "        # from_config builds the architecture only; weights are freshly initialized.\n",
    "        config = AutoConfig.from_pretrained(model_name)\n",
    "        self.model = AutoModelForCausalLM.from_config(config=config)\n",
    "        # Record the constructor arguments in self.hparams and in checkpoints.\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "    def forward(self, batch):\n",
    "        '''\n",
    "        Run the LM on `batch` and attach `even_loss`, `odd_loss` and `loss` to the output.\n",
    "\n",
    "        `batch` must provide `input_ids` and `attention_mask`. Logits at position t\n",
    "        are scored against the token at position t + 1.\n",
    "        '''\n",
    "        out = self.model.forward(\n",
    "            input_ids=batch['input_ids'],\n",
    "            attention_mask=batch['attention_mask'],\n",
    "            # Nothing is generated here, so do not build the key/value cache.\n",
    "            use_cache=False,\n",
    "        )\n",
    "        input_ids = batch['input_ids']\n",
    "        # cross_entropy expects the class dimension second, so move vocab ahead of sequence.\n",
    "        logits = out.logits.transpose(1, 2)\n",
    "        # Even positions predict the odd-indexed tokens. Stopping at -1 drops the\n",
    "        # last position (it has no next token), which keeps logits and targets the\n",
    "        # same length for odd sequence lengths as well as even ones.\n",
    "        out['even_loss'] = F.cross_entropy(logits[..., :-1:2], input_ids[..., 1::2])\n",
    "        # Odd positions predict the even-indexed tokens from index 2 onwards.\n",
    "        out['odd_loss'] = F.cross_entropy(logits[..., 1:-1:2], input_ids[..., 2::2])\n",
    "        # NOTE(review): attention_mask is not applied to either loss, so padded\n",
    "        # positions (if any) are scored too; confirm the dataset has no padding.\n",
    "        out['loss'] = self.alpha * out.even_loss + out.odd_loss\n",
    "        return out\n",
    "\n",
    "    def _log_loss(self, out, prefix):\n",
    "        '''Log every entry of `out` whose key contains 'loss' as `<prefix>_<key>`.'''\n",
    "        for k in out.keys():\n",
    "            if 'loss' in k:\n",
    "                self.log(prefix + '_' + k, out[k].item())\n",
    "\n",
    "    def _shared_step(self, batch, prefix):\n",
    "        '''Forward `batch`, log its losses under `prefix`, and return the combined loss.'''\n",
    "        out = self.forward(batch)\n",
    "        self._log_loss(out, prefix)\n",
    "        return out.loss\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        '''Compute and log the training losses; returns the loss to backpropagate.'''\n",
    "        return self._shared_step(batch, 'train')\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        '''Compute and log the validation losses (`val_loss` is what the callbacks monitor).'''\n",
    "        return self._shared_step(batch, 'val')\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        '''Compute and log the test losses.'''\n",
    "        return self._shared_step(batch, 'test')\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        '''Adam over all model parameters with a per-step warmup + cosine LR schedule.'''\n",
    "        optimizer = optim.Adam(\n",
    "            params=self.model.parameters(),\n",
    "            lr=self.lr,\n",
    "        )\n",
    "        num_training_steps = self.num_training_steps\n",
    "        if num_training_steps is None:\n",
    "            # Let Lightning derive the schedule length from the trainer and dataloader.\n",
    "            num_training_steps = self.trainer.estimated_stepping_batches\n",
    "        scheduler = transformers.get_cosine_schedule_with_warmup(\n",
    "            optimizer=optimizer,\n",
    "            num_warmup_steps=self.num_warmup_steps,\n",
    "            num_training_steps=num_training_steps,\n",
    "        )\n",
    "        return (\n",
    "            [optimizer],\n",
    "            [{\"scheduler\": scheduler, \"interval\": \"step\"}]\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "05f17a5a-fe37-44f4-8eff-304aa478867d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# W&B project, and the run name that also prefixes checkpoint filenames.\n",
    "PROJ = 'XXXXXXXX_FUTURE_GPT2'\n",
    "NAME = 'GPT2-UNBALANCED'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f2fde467-b571-4cc0-8eda-c6313ddc4cbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory holding the train/ and val/ splits read by load_from_disk.\n",
    "DATA_DIR = '/home/XXXXXXXX/msmarco_GPT2_64tokens_full'\n",
    "BATCH_SIZE = 512\n",
    "# Never request more loader workers than this machine has CPUs.\n",
    "N_CPUS = os.cpu_count() or 1\n",
    "\n",
    "train = datasets.load_from_disk(os.path.join(DATA_DIR, 'train')).with_format('torch')\n",
    "val = datasets.load_from_disk(os.path.join(DATA_DIR, 'val')).with_format('torch')\n",
    "# NOTE(review): the training loader is not shuffled (DataLoader default); confirm the split is pre-shuffled.\n",
    "train_loader = DataLoader(train, batch_size=BATCH_SIZE, num_workers=min(251, N_CPUS))\n",
    "val_loader = DataLoader(val, batch_size=BATCH_SIZE, num_workers=min(111, N_CPUS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b624a4b6-91d3-403c-b0f6-5bd1ffe8ac37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Send metrics to Weights & Biases; log_model=False keeps checkpoints on local disk only.\n",
    "wandb_logger = WandbLogger(project=PROJ, name=NAME, log_model=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4e2009aa-4900-480c-899d-e3cd1bedacbd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "# Log the learning rate produced by the scheduler.\n",
    "lr_monitor = LearningRateMonitor()\n",
    "# Keep only the checkpoint with the lowest val_loss.\n",
    "# The filename uses `{step}`: that is the key ModelCheckpoint fills with the\n",
    "# global step, whereas a key it does not know (such as `{global_step}`) is rendered as 0.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    dirpath=\"/home/XXXXXXXX/checkpoints\",\n",
    "    filename=NAME + \"_{step}_{val_loss:.2f}\",\n",
    "    every_n_epochs=1,\n",
    "    save_top_k=1,\n",
    "    monitor='val_loss',\n",
    "    mode='min',\n",
    ")\n",
    "# Stop if val_loss fails to improve for 10 validation checks, or exceeds 15.\n",
    "early_stop_callback = EarlyStopping(\n",
    "    monitor='val_loss',\n",
    "    divergence_threshold=15,\n",
    "    min_delta=0.00,\n",
    "    patience=10,\n",
    "    verbose=False,\n",
    "    mode='min',\n",
    ")\n",
    "# One epoch, validating every 10% of it.\n",
    "trainer = L.Trainer(\n",
    "    fast_dev_run=False,\n",
    "    logger=wandb_logger,\n",
    "    val_check_interval=.1,\n",
    "    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],\n",
    "    max_epochs=1,\n",
    "    enable_progress_bar=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "56c2c6e2-7175-4a63-b6e1-ee5fb64489c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mXXXXXXXX\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.16.4 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20240307_004353-ltzcv2je</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2/runs/ltzcv2je' target=\"_blank\">GPT2-UNBALANCED</a></strong> to <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2/runs/ltzcv2je' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_GPT2/runs/ltzcv2je</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n"
     ]
    }
   ],
   "source": [
    "# Build the module with its default hyperparameters, then have W&B watch the\n",
    "# wrapped Hugging Face model (log='all' selects everything watch can record).\n",
    "model = LitUnbalancedModel()\n",
    "wandb_logger.watch(model=model.model, log='all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "340c982a-8565-4248-8ca3-781f96c00506",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory /home/XXXXXXXX/checkpoints exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name  | Type            | Params\n",
      "------------------------------------------\n",
      "0 | model | GPT2LMHeadModel | 124 M \n",
      "------------------------------------------\n",
      "124 M     Trainable params\n",
      "0         Non-trainable params\n",
      "124 M     Total params\n",
      "497.759   Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "45c495920f8940deb01fd2c728307dcf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IOPub message rate exceeded.\n",
      "The Jupyter server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--ServerApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "ServerApp.rate_limit_window=3.0 (secs)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Launch training; validation uses val_loader at the interval configured on the trainer.\n",
    "trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "635ff2b4-e128-4a17-9ab1-240de8f62bb2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3080e077-2c57-4ca9-b0c1-eef7fe2ee294",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
