{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "045643f1-ae99-4c92-a2be-374a766ce69a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "\n",
    "import numpy as np\n",
    "from torch import optim, nn, Tensor\n",
    "from torch.nn import functional as F\n",
    "import torch\n",
    "import wandb\n",
    "from transformers import GPT2Config, GPT2Model\n",
    "import transformers\n",
    "import lightning as L\n",
    "from inspect import signature, _ParameterKind\n",
    "import copy\n",
    "import gc\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from matplotlib import pyplot as plt\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor\n",
    "from lightning.pytorch.callbacks.early_stopping import EarlyStopping\n",
    "from lightning.pytorch.loggers import WandbLogger\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "from itertools import repeat\n",
    "\n",
    "from models.regression_model import *\n",
    "from data.synthetic import *\n",
    "from models.myopic_model import to_myopic_gpt2\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "11b19d00-23c6-44cd-92f1-290f6e87792b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable 'high' float32 matmul precision only on GPUs with compute capability >= 8.\n",
    "# Check is_available() first: get_device_capability() raises on a machine without CUDA.\n",
    "if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:\n",
    "    torch.set_float32_matmul_precision('high')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "184a5259-1a02-4257-8c01-cc34ea3fc7d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic-data settings, unpacked into SyntheticSeqDataset by the training cell.\n",
    "# The free_* dictionaries are also embedded, in insertion order, in the W&B run name.\n",
    "fixed_data_params = {'seq_len': 64}\n",
    "free_data_params = {'conv': 10, 'shift': 1, 'p': 0.3}\n",
    "data_params = dict(fixed_data_params, **free_data_params)\n",
    "\n",
    "# Model settings, unpacked into LitGPT2RegModel by the training cell.\n",
    "fixed_model_params = {\n",
    "    'attn_pdrop': 0, 'embd_pdrop': 0, 'resid_pdrop': 0,  # all *_pdrop values are zero\n",
    "    'n_inner': None,  # MLP inner size; None=4x\n",
    "    'n_positions': 64,  # NOTE(review): same value as data seq_len -- presumably must stay in sync; confirm\n",
    "    'in_dim': 2,\n",
    "}\n",
    "free_model_params = {\n",
    "    'lr': 1e-3,\n",
    "    'n_embd': 8,\n",
    "    'n_head': 1,  # previously tried: 'n_head': 2\n",
    "    'n_layer': 2,\n",
    "    'activation_function': 'relu',\n",
    "}\n",
    "model_params = dict(fixed_model_params, **free_model_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b8c16c8c-a966-47a0-a087-ed16a323db88",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /home/XXXXXXXX/.netrc\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Take the API key from the WANDB_API_KEY environment variable instead of embedding\n",
    "# it in the notebook, so the secret is never saved or shared with the file.\n",
    "# NOTE(review): when the variable is unset, key is None and relogin=True presumably\n",
    "# makes wandb prompt for a key -- confirm this is the desired fallback.\n",
    "wandb.login(key=os.environ.get('WANDB_API_KEY'), relogin=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "359f5fe0-2dd9-4249-b0dd-f737dbd843bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#val = DataLoader(SyntheticSeqDataset(size=10000, **data_params), batch_size=512, num_workers=95)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "54d1dd71-3dfd-4984-9785-2b77bc279feb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mXXXXXXXX\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.16.5 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.16.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20240329_221507-xsi1xxcc</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_SYNTHETIC/runs/xsi1xxcc' target=\"_blank\">SYNTH-GPT2-COS_conv-10_shift-1_p-0.3_lr-0.001_n_embd-8_n_head-1_n_layer-2_activation_function-relu</a></strong> to <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_SYNTHETIC' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_SYNTHETIC' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_SYNTHETIC</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_SYNTHETIC/runs/xsi1xxcc' target=\"_blank\">https://wandb.ai/XXXXXXXX/XXXXXXXX_FUTURE_SYNTHETIC/runs/xsi1xxcc</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: logging graph, to disable use `wandb.watch(log_graph=False)`\n",
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n",
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory /workspace/checkpoints exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "\n",
      "  | Name    | Type      | Params\n",
      "--------------------------------------\n",
      "0 | model   | GPT2Model | 2.3 K \n",
      "1 | loss_fn | HuberLoss | 0     \n",
      "--------------------------------------\n",
      "2.3 K     Trainable params\n",
      "0         Non-trainable params\n",
      "2.3 K     Total params\n",
      "0.009     Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NUM TRAINING STEPS 58594\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b4588d9a4d4455c82e774aa46352e67",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=1` reached.\n"
     ]
    }
   ],
   "source": [
    "# Cap DataLoader workers at the local core count; 95 (the original setting) stays the upper bound.\n",
    "NUM_WORKERS = min(95, os.cpu_count() or 1)\n",
    "train = DataLoader(\n",
    "    SyntheticSeqDataset(size=30_000_000, **data_params),\n",
    "    batch_size=512,\n",
    "    num_workers=NUM_WORKERS,\n",
    ")\n",
    "#train = DataLoader(SyntheticSeqDataset(size=10_000, **data_params), batch_size=512, num_workers=NUM_WORKERS)\n",
    "BETA=0.0  # only referenced by the commented-out to_myopic_gpt2 call below\n",
    "\n",
    "# Run name = fixed prefix + every free data/model hyperparameter rendered as 'key-value'.\n",
    "NAME = '_'.join(\n",
    "    ['SYNTH-GPT2-COS'] +\n",
    "    [f'{k}-{v}' for k, v in {**free_data_params, **free_model_params}.items()]\n",
    ")\n",
    "PROJ = 'XXXXXXXX_FUTURE_SYNTHETIC'\n",
    "wandb_logger = WandbLogger(\n",
    "    name=NAME,\n",
    "    project=PROJ,\n",
    "    log_model=False,   # Only save checkpoints locally\n",
    ")\n",
    "lr_monitor = LearningRateMonitor()\n",
    "# Keep only the single best checkpoint, ranked by the lowest logged 'train_loss'.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    dirpath=\"/workspace/checkpoints\",\n",
    "    filename=NAME + \"_{global_step}_{train_loss:.2f}\",\n",
    "    every_n_epochs=1,\n",
    "    save_top_k=1,\n",
    "    monitor='train_loss',\n",
    "    mode='min',\n",
    ")\n",
    "# With patience=100000 and max_epochs=1 the patience limit is effectively never reached,\n",
    "# so this callback mainly acts as a divergence guard (train_loss above 10000).\n",
    "early_stop_callback = EarlyStopping(\n",
    "    monitor='train_loss',\n",
    "    divergence_threshold=10000,\n",
    "    min_delta=0.00,\n",
    "    patience=100000,\n",
    "    verbose=False,\n",
    "    mode='min',\n",
    ")\n",
    "trainer = L.Trainer(\n",
    "    fast_dev_run=False,\n",
    "    logger=wandb_logger,\n",
    "    val_check_interval=1.0,\n",
    "    #check_val_every_n_epoch=5,\n",
    "    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],\n",
    "    max_epochs=1,\n",
    "    enable_progress_bar=True,\n",
    ")\n",
    "model = LitGPT2RegModel(**model_params)\n",
    "wandb_logger.watch(model.model, log='all')\n",
    "# model = to_myopic_gpt2(\n",
    "#     model,\n",
    "#     [None for _ in range(12)],\n",
    "#     beta=BETA,\n",
    "# )\n",
    "try:\n",
    "    trainer.fit(\n",
    "        model=model,\n",
    "        train_dataloaders=train,\n",
    "        #val_dataloaders=val,\n",
    "    )\n",
    "finally:\n",
    "    # Close the W&B run even if training fails or is interrupted; otherwise a re-run\n",
    "    # of this cell makes the new WandbLogger reuse the still-open run and its old name.\n",
    "    wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e87ccfc6-4601-43f9-a53b-a72802c3631f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
