{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3dc2851a-fd97-4f32-9964-104c970cbd58",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "\n",
    "import numpy as np\n",
    "from torch import optim, nn, Tensor\n",
    "from torch.nn import functional as F\n",
    "import torch\n",
    "import wandb\n",
    "from transformers import GPT2Config, GPT2Model\n",
    "import transformers\n",
    "import lightning as L\n",
    "from inspect import signature, _ParameterKind\n",
    "import copy\n",
    "import gc\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from matplotlib import pyplot as plt\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor\n",
    "from lightning.pytorch.callbacks.early_stopping import EarlyStopping\n",
    "from lightning.pytorch.loggers import WandbLogger\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "from itertools import repeat, product\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib\n",
    "\n",
    "from models.regression_model import *\n",
    "from data.synthetic import *\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "efa064b7-3cba-4436-a5b5-2bea4fb97163",
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_data_params = {\n",
    "    'seq_len': 64,\n",
    "}\n",
    "free_data_params = {\n",
    "    'conv': 10,\n",
    "    'shift': 1,\n",
    "    'p': 1,\n",
    "}\n",
    "data_params = {**fixed_data_params, **free_data_params}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "093fc7a6-2f76-4c75-89d1-f92f3615ac32",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = DataLoader(\n",
    "    SyntheticSeqDataset(size=50_000, **data_params), \n",
    "    batch_size=512, \n",
    "    #num_workers=20,\n",
    "    drop_last=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3cee7c7f-e1c8-47b4-8ee1-9fb379b31f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt = '/workspace/checkpoints/SYNTH-GPT2-MYOPIC-COS_BETA0_conv-10_shift-1_p-1_lr-0.001_n_embd-128_n_head-2_n_layer-3_activation_function-relu_global_step=58593.0_train_loss=1.24.ckpt'\n",
    "model = LitGPT2RegModel.load_from_checkpoint(ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "41291e26-98c4-4aad-8406-f93c468c7146",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7f63d81d0700>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "9c53831f-06ae-492d-bfa3-892e7c60e50d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "087d64eb69434a2983ea8c48bbed9eb9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/97 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "loss_fn = nn.HuberLoss()\n",
    "loss = 0\n",
    "total = 0\n",
    "for batch in tqdm(iter(ds)):\n",
    "    loss += loss_fn(torch.zeros(batch['y'].shape), batch['y']).item()\n",
    "    total += 1\n",
    "loss /= total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "e835785a-e530-4678-a070-1d85ad284815",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.2569321160463942"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "cd11762d-1d5a-4017-8b80-b4fed606bdc5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(4.7274)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.var(batch['y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b83fb3a2-440e-4549-a168-fd80a479ee4a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(2.3649)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.mean(0.5 * batch['y']**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a79f54ca-a0e6-4650-ae16-762ac159b3d6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
