{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9c00c716-e60b-4c8e-abb8-b6835228c89f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/pkg_resources/__init__.py:116: PkgResourcesDeprecationWarning: 1.12.1-git20200711.33e2d80-dfsg1-0.6 is an invalid version and will not be supported in a future release\n",
      "  warnings.warn(\n",
      "/usr/lib/python3/dist-packages/pkg_resources/__init__.py:116: PkgResourcesDeprecationWarning: 0.1.43ubuntu1 is an invalid version and will not be supported in a future release\n",
      "  warnings.warn(\n",
      "/usr/lib/python3/dist-packages/pkg_resources/__init__.py:116: PkgResourcesDeprecationWarning: 1.1build1 is an invalid version and will not be supported in a future release\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import wandb\n",
    "import models.future_model\n",
    "from lightning.pytorch.loggers import WandbLogger\n",
    "import os\n",
    "from lightning import Trainer\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor\n",
    "from models.necks import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2ea9a427-5ed9-49d7-a45a-728ab11e1ceb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2 so edited project modules are\n",
    "# re-imported automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "07c243cc-217b-43f1-ae96-64ce8f13bf05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Factory passed to LitFutureModel further down: takes (h, v) and returns an LSTMNeck.\n",
    "# NOTE(review): LSTMNeck presumably comes from the `models.necks` star import, and the\n",
    "# meaning of the positional arguments 1, True, 128, 2 is defined by its signature -- confirm there.\n",
    "FatNeck = lambda h, v: LSTMNeck(h, v, 1, True, 128, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c646fa13-ba51-4f6b-8c97-1f44205a68d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /home/XXXXXXXX/.netrc\n"
     ]
    }
   ],
   "source": [
    "# Authenticate with Weights & Biases using the API key taken from the environment\n",
    "# (raises KeyError if WANDB_API_KEY is unset); relogin=True requests a fresh login.\n",
    "wandb.login(\n",
    "    key=os.environ['WANDB_API_KEY'],\n",
    "    relogin=True,\n",
    ")\n",
    "\n",
    "# Lightning logger for this run; log_model='all' is handed to WandbLogger unchanged.\n",
    "wandb_logger = WandbLogger(project='XXXXXXXX_FutureGPT2', name='base_run', log_model='all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dc184083-bb53-4677-83c8-1e92409c757b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the pre-tokenized dataset on disk and the batch size shared by all loaders.\n",
    "DATASET_PATH = '/corpus/msmarco/msmarco_GPT2_64tokens_100k'\n",
    "BATCH_SIZE = 128\n",
    "# Optional cap on the number of validation examples. None keeps the full 'val' split;\n",
    "# set an int to validate on only the first N examples when validation is too slow.\n",
    "VAL_SUBSET_SIZE = None\n",
    "\n",
    "# Load the dataset once and reuse it for every split (items are returned as torch tensors).\n",
    "dataset = datasets.load_from_disk(DATASET_PATH).with_format('torch')\n",
    "loaders = {\n",
    "    split: DataLoader(dataset[split], batch_size=BATCH_SIZE)\n",
    "    for split in ['train', 'val', 'test']\n",
    "}\n",
    "\n",
    "# use small val set; o/w too slow\n",
    "# NOTE(review): with VAL_SUBSET_SIZE = None validation runs on the full 'val' split,\n",
    "# because no smaller dataset is available at a separate path here -- pick a size if needed.\n",
    "if VAL_SUBSET_SIZE is not None:\n",
    "    val_subset = dataset['val'].select(range(min(VAL_SUBSET_SIZE, len(dataset['val']))))\n",
    "    loaders['val'] = DataLoader(val_subset, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b4e10b58-19f5-47ec-b4d4-9bf3a52d7d82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Callback that records the learning rate during training.\n",
    "lr_callback = LearningRateMonitor()\n",
    "# Checkpoint policy: write a checkpoint under /checkpoints/ every 10 training steps.\n",
    "# NOTE(review): every_n_train_steps=10 looks like a debugging value (10_000 is commented out\n",
    "# below); together with log_model='all' on the W&B logger this may produce many uploads -- confirm.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    dirpath=\"/checkpoints/\",\n",
    "    # can't zero pad global_step b/c the logger casts it to a float >:(((\n",
    "    # NOTE(review): the filename template assumes `train_total_loss` is a metric logged by the model -- confirm.\n",
    "    filename=\"gpt2-test_run-{epoch:02d}-{global_step:.0f}-{train_total_loss:.2f}\",\n",
    "    #every_n_epochs=0.1,\n",
    "    #every_n_train_steps=10_000,\n",
    "    every_n_train_steps=10,\n",
    "    #save_last=True,\n",
    "    # -1 keeps every checkpoint instead of pruning to the best k\n",
    "    save_top_k=-1,\n",
    "    # 'global_step' is logged from the model's training_step (visible in the recorded fit output)\n",
    "    monitor='global_step',\n",
    "    mode='max',\n",
    "    save_on_train_epoch_end=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "de04e6f3-081d-4639-a668-65f705c12f57",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/pkg_resources/__init__.py:116: PkgResourcesDeprecationWarning: 1.12.1-git20200711.33e2d80-dfsg1-0.6 is an invalid version and will not be supported in a future release\n",
      "  warnings.warn(\n",
      "/usr/lib/python3/dist-packages/pkg_resources/__init__.py:116: PkgResourcesDeprecationWarning: 0.1.43ubuntu1 is an invalid version and will not be supported in a future release\n",
      "  warnings.warn(\n",
      "/usr/lib/python3/dist-packages/pkg_resources/__init__.py:116: PkgResourcesDeprecationWarning: 1.1build1 is an invalid version and will not be supported in a future release\n",
      "  warnings.warn(\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "# Trainer wired to the W&B logger and to the checkpoint / learning-rate callbacks.\n",
    "# max_epochs is not set here; the recorded fit output shows Lightning falling back to 1000 epochs.\n",
    "trainer = Trainer(\n",
    "    #fast_dev_run=True,\n",
    "    logger=wandb_logger,\n",
    "    # default_root_dir='..',\n",
    "    # a float value is interpreted as a fraction of the training epoch between validation runs\n",
    "    val_check_interval=0.1,\n",
    "    callbacks=[checkpoint_callback, lr_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7c78be7c-c5de-4ff2-b751-36ebf51f6afb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mXXXXXXXX\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.16.0 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.12"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20231110_013921-bvz047e2</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FutureGPT2/runs/bvz047e2' target=\"_blank\">base_run</a></strong> to <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FutureGPT2' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FutureGPT2' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FutureGPT2</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FutureGPT2/runs/bvz047e2' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FutureGPT2/runs/bvz047e2</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the Lightning module with the FatNeck factory defined above.\n",
    "model = models.future_model.LitFutureModel(FatNeck)\n",
    "# Ask the W&B logger to watch only the neck submodule; log='all' is passed through to watch().\n",
    "wandb_logger.watch(model.future_neck, log='all')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7349056e-6f07-4b06-a633-6e845b298bfd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.\n",
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:630: Checkpoint directory /checkpoints/ exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.\n",
      "\n",
      "  | Name        | Type             | Params\n",
      "-------------------------------------------------\n",
      "0 | base_model  | GPT2LMHeadModel  | 124 M \n",
      "1 | future_neck | LSTMNeck         | 1.1 M \n",
      "2 | loss_func   | CrossEntropyLoss | 0     \n",
      "-------------------------------------------------\n",
      "1.1 M     Trainable params\n",
      "124 M     Non-trainable params\n",
      "125 M     Total params\n",
      "502.096   Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |                                                                                            …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4d4b1a7b7c6244a189577836df2d972a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |                                                                                                   …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:211: You called `self.log('global_step', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                                 …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e68a62a89cd24718946efa3cd61e32dc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |                                                                                                 …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...\n"
     ]
    }
   ],
   "source": [
    "# Start training on the 'train' loader, validating against the 'val' loader.\n",
    "trainer.fit(\n",
    "    model=model,\n",
    "    train_dataloaders=loaders['train'],\n",
    "    val_dataloaders=loaders['val'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "eae957dc-85ac-4f91-a199-27706f372f17",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'np' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[10], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;66;03m#neck = model.future_neck.weight.data\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m neck \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/checkpoints/neck.npy\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m      4\u001b[0m plt\u001b[38;5;241m.\u001b[39mimshow(np\u001b[38;5;241m.\u001b[39mlog(neck))\n",
      "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)\n",
      "wandb: ERROR Error while calling W&B API: failed to find run XXXXXXXX_FutureGPT2/bvz047e2 (<Response [404]>)\n",
      "[rank: 0] Received SIGTERM: 15\n",
      "wandb: ERROR Error while calling W&B API: run XXXXXXXX_FutureGPT2/bvz047e2 not found during createRunFiles (<Response [404]>)\n",
      "wandb: ERROR Error while calling W&B API: run XXXXXXXX_FutureGPT2/bvz047e2 not found during createRunFiles (<Response [404]>)\n",
      "wandb: ERROR Error while calling W&B API: run XXXXXXXX_FutureGPT2/bvz047e2 not found during createRunFiles (<Response [404]>)\n",
      "wandb: ERROR Error while calling W&B API: run XXXXXXXX_FutureGPT2/bvz047e2 not found during createRunFiles (<Response [404]>)\n",
      "wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# numpy must be imported here: this cell uses np, and importing it only in a later cell\n",
    "# makes this cell fail with NameError on a fresh kernel (Restart & Run All).\n",
    "import numpy as np\n",
    "\n",
    "# Load the saved neck array from disk and display its element-wise natural log as an image.\n",
    "#neck = model.future_neck.weight.data\n",
    "neck = np.load('/checkpoints/neck.npy')\n",
    "plt.imshow(np.log(neck))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d4ce045-932a-4d1b-8339-c80329c4d0c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the same array without the log transform.\n",
    "plt.imshow(neck)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd17fc94-0f7e-4553-aa27-a2d3ed0d90c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Echo the array loaded from /checkpoints/neck.npy as the cell output.\n",
    "neck"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d1fc879-4407-4ce8-945a-e41861b4cd9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "#np.save('/checkpoints/neck.npy', neck.cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c320b423-ae63-4556-ba5c-ae10355f78b6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17c1de02-a81b-479b-af53-0be4a2933dff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore a LitFutureModel from a checkpoint file on disk.\n",
    "ckpt_model = models.future_model.LitFutureModel.load_from_checkpoint(\n",
    "    '/home/XXXXXXXX/FutureGPT2/src/XXXXXXXX_FutureGPT2/8wd5g5pv/checkpoints/epoch=4-step=3125.ckpt'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51b2eea8-84cc-48c6-8e5e-c43a7f79d35e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the neck weights of the restored model.\n",
    "# NOTE(review): assumes the checkpointed neck exposes a `.weight` parameter; the neck trained\n",
    "# above is an LSTMNeck, which may not have one -- confirm which neck type this checkpoint holds.\n",
    "ckpt_model.future_neck.weight.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e996b3-a794-4d1f-842f-9e5081c21f0e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
