{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9bdd2f0b-1a0f-4912-9827-4f2420b5237a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Environment setup must run before the imports below: HF_HOME is set so\n",
    "# libraries that read it at import time see the value, and the working\n",
    "# directory is changed so that (presumably) the project-local `models`\n",
    "# package resolves from the notebook's import path.\n",
    "os.environ['HF_HOME'] = '/workspace/cache/huggingface/'\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "\n",
    "import datasets\n",
    "import lightning as L\n",
    "import torch\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Explicit import instead of `import *`: get_model is the only name this\n",
    "# notebook uses from models.utils, and a wildcard could silently shadow\n",
    "# the names imported above.\n",
    "from models.utils import get_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8c4a59ad-c75f-42be-bbf2-30b52c5684d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LitBigram(L.LightningModule):\n",
    "    \"\"\"Bigram baseline: a trainable linear map between frozen LM layers.\n",
    "\n",
    "    Next-token logits depend on the current token only and are computed as\n",
    "    unembed(linear(embed(token))). The embedding and unembedding layers are\n",
    "    taken from the model returned by get_model and frozen, so `linear` is\n",
    "    the only trainable component.\n",
    "\n",
    "    Args:\n",
    "        model_name: Identifier forwarded to get_model.\n",
    "        lr: Learning rate used by the Adam optimizer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_name, lr=1e-4):\n",
    "        super().__init__()\n",
    "        self.save_hyperparameters()\n",
    "        self.model_name = model_name\n",
    "        self.lr = lr\n",
    "\n",
    "        model = get_model(model_name, precision='32')\n",
    "        # GPT-2 attribute layout. NOTE(review): other architectures appear to\n",
    "        # keep the embedding elsewhere (e.g. model.model.embed_tokens).\n",
    "        self.embed = model.transformer.wte\n",
    "        self.unembed = model.lm_head\n",
    "        # Freeze both pretrained layers; only the linear map is optimized.\n",
    "        for frozen in (self.embed, self.unembed):\n",
    "            for param in frozen.parameters():\n",
    "                param.requires_grad = False\n",
    "\n",
    "        self.linear = nn.Linear(\n",
    "            self.embed.embedding_dim,\n",
    "            self.unembed.in_features,\n",
    "        )\n",
    "\n",
    "    def forward(self, batch):\n",
    "        \"\"\"Return the logits computed from batch['input_ids'].\"\"\"\n",
    "        hidden = self.embed(batch['input_ids'])\n",
    "        return self.unembed(self.linear(hidden))\n",
    "\n",
    "    def _compute_loss(self, batch):\n",
    "        \"\"\"Cross-entropy of the logits at position t against token t+1.\"\"\"\n",
    "        logits = self.forward(batch)\n",
    "        # The last position has no successor and the first token has no\n",
    "        # predecessor, so both are dropped. The transpose moves the last\n",
    "        # (class) axis to dim 1, where CrossEntropyLoss expects it.\n",
    "        return nn.CrossEntropyLoss()(\n",
    "            logits.transpose(1, 2)[:, :, :-1],\n",
    "            batch['input_ids'][:, 1:],\n",
    "        )\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute, log and return the training loss for one batch.\"\"\"\n",
    "        loss = self._compute_loss(batch)\n",
    "        self.log('train_loss', loss)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute, log and return the validation loss for one batch.\"\"\"\n",
    "        loss = self._compute_loss(batch)\n",
    "        # Logged (and shown on the progress bar) rather than printed, so a\n",
    "        # validation pass no longer emits one output line per batch.\n",
    "        self.log('val_loss', loss, prog_bar=True)\n",
    "        return loss\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute, log and return the test loss for one batch.\"\"\"\n",
    "        loss = self._compute_loss(batch)\n",
    "        # Without logging, trainer.test() would have no metric to report.\n",
    "        self.log('test_loss', loss)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Adam over the linear map only, using the configured learning rate.\"\"\"\n",
    "        # Previously `lr` was stored but never passed on, so Adam silently\n",
    "        # ran with its default learning rate.\n",
    "        return optim.Adam(self.linear.parameters(), lr=self.lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fb9e0979-3814-4611-bc2e-342b9c3fba1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "# Keep only the single best checkpoint, ranked by the logged 'val_loss'.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    monitor='val_loss',\n",
    "    mode='min',\n",
    "    save_top_k=1,\n",
    "    every_n_epochs=1,\n",
    "    dirpath='/workspace/checkpoints',\n",
    "    filename='GPT2_BIGRAM_{val_loss:.2f}',\n",
    ")\n",
    "\n",
    "# A fractional val_check_interval runs validation that often within each\n",
    "# training epoch (here every 20% of an epoch).\n",
    "trainer = L.Trainer(\n",
    "    precision='32',\n",
    "    enable_progress_bar=True,\n",
    "    val_check_interval=.2,\n",
    "    callbacks=[checkpoint_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4a37012a-1f01-4db5-be75-9b51c0c7b4d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved dataset from disk and have it yield torch tensors that\n",
    "# are placed directly on the CUDA device.\n",
    "dataset = datasets.load_from_disk(\n",
    "    '/workspace/corpus/msmarco/msmarco_GPT2_64tokens_1m'\n",
    ").with_format('torch', device=torch.device('cuda'))\n",
    "\n",
    "# NOTE(review): the training loader is not shuffled -- confirm this is intended.\n",
    "train = DataLoader(dataset['train'], batch_size=128)\n",
    "val = DataLoader(dataset['val'], batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "dee9ce6e-951f-44ee-a468-939101137d54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the bigram baseline on top of the 'gpt2' model's frozen layers.\n",
    "model = LitBigram(model_name='gpt2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "620b0d67-573b-4498-924d-e895467f31f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.\n",
      "You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "Missing logger folder: /workspace/FutureGPT2/src/lightning_logs\n",
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory /workspace/checkpoints exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name    | Type      | Params\n",
      "--------------------------------------\n",
      "0 | embed   | Embedding | 38.6 M\n",
      "1 | unembed | Linear    | 38.6 M\n",
      "2 | linear  | Linear    | 590 K \n",
      "--------------------------------------\n",
      "590 K     Trainable params\n",
      "38.6 M    Non-trainable params\n",
      "39.2 M    Total params\n",
      "156.752   Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val loss tensor(10.9571, device='cuda:0')\n",
      "val loss tensor(10.9588, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dd2134a5355e41d3924e4802be23dca8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val loss tensor(5.7701, device='cuda:0')\n",
      "val loss tensor(5.7823, device='cuda:0')\n",
      "val loss tensor(5.8474, device='cuda:0')\n",
      "val loss tensor(5.8265, device='cuda:0')\n",
      "val loss tensor(5.7790, device='cuda:0')\n",
      "val loss tensor(5.7716, device='cuda:0')\n",
      "val loss tensor(5.8974, device='cuda:0')\n",
      "val loss tensor(5.7745, device='cuda:0')\n",
      "val loss tensor(5.8029, device='cuda:0')\n",
      "val loss tensor(5.9327, device='cuda:0')\n",
      "val loss tensor(5.8083, device='cuda:0')\n",
      "val loss tensor(5.8447, device='cuda:0')\n",
      "val loss tensor(5.7620, device='cuda:0')\n",
      "val loss tensor(5.8342, device='cuda:0')\n",
      "val loss tensor(5.8364, device='cuda:0')\n",
      "val loss tensor(5.8233, device='cuda:0')\n",
      "val loss tensor(5.8226, device='cuda:0')\n",
      "val loss tensor(5.7739, device='cuda:0')\n",
      "val loss tensor(5.8143, device='cuda:0')\n",
      "val loss tensor(5.8093, device='cuda:0')\n",
      "val loss tensor(5.7526, device='cuda:0')\n",
      "val loss tensor(5.8789, device='cuda:0')\n",
      "val loss tensor(5.8212, device='cuda:0')\n",
      "val loss tensor(5.8671, device='cuda:0')\n",
      "val loss tensor(5.8321, device='cuda:0')\n",
      "val loss tensor(5.8455, device='cuda:0')\n",
      "val loss tensor(5.8132, device='cuda:0')\n",
      "val loss tensor(5.7384, device='cuda:0')\n",
      "val loss tensor(5.8614, device='cuda:0')\n",
      "val loss tensor(5.8189, device='cuda:0')\n",
      "val loss tensor(5.8608, device='cuda:0')\n",
      "val loss tensor(5.8932, device='cuda:0')\n",
      "val loss tensor(5.8206, device='cuda:0')\n",
      "val loss tensor(5.7783, device='cuda:0')\n",
      "val loss tensor(5.6973, device='cuda:0')\n",
      "val loss tensor(5.8152, device='cuda:0')\n",
      "val loss tensor(5.8309, device='cuda:0')\n",
      "val loss tensor(5.9022, device='cuda:0')\n",
      "val loss tensor(5.8863, device='cuda:0')\n",
      "val loss tensor(5.8891, device='cuda:0')\n",
      "val loss tensor(5.8187, device='cuda:0')\n",
      "val loss tensor(5.8185, device='cuda:0')\n",
      "val loss tensor(5.9137, device='cuda:0')\n",
      "val loss tensor(5.8298, device='cuda:0')\n",
      "val loss tensor(5.8086, device='cuda:0')\n",
      "val loss tensor(5.8236, device='cuda:0')\n",
      "val loss tensor(5.8415, device='cuda:0')\n",
      "val loss tensor(5.7562, device='cuda:0')\n",
      "val loss tensor(5.8414, device='cuda:0')\n",
      "val loss tensor(5.8736, device='cuda:0')\n",
      "val loss tensor(5.8178, device='cuda:0')\n",
      "val loss tensor(5.9686, device='cuda:0')\n",
      "val loss tensor(5.7506, device='cuda:0')\n",
      "val loss tensor(5.7391, device='cuda:0')\n",
      "val loss tensor(5.8075, device='cuda:0')\n",
      "val loss tensor(5.8796, device='cuda:0')\n",
      "val loss tensor(5.8127, device='cuda:0')\n",
      "val loss tensor(5.8644, device='cuda:0')\n",
      "val loss tensor(5.8334, device='cuda:0')\n",
      "val loss tensor(5.8239, device='cuda:0')\n",
      "val loss tensor(5.8590, device='cuda:0')\n",
      "val loss tensor(5.7749, device='cuda:0')\n",
      "val loss tensor(5.8121, device='cuda:0')\n",
      "val loss tensor(5.7860, device='cuda:0')\n",
      "val loss tensor(5.8840, device='cuda:0')\n",
      "val loss tensor(5.8203, device='cuda:0')\n",
      "val loss tensor(5.8449, device='cuda:0')\n",
      "val loss tensor(5.8019, device='cuda:0')\n",
      "val loss tensor(5.9705, device='cuda:0')\n",
      "val loss tensor(5.8076, device='cuda:0')\n",
      "val loss tensor(5.7758, device='cuda:0')\n",
      "val loss tensor(5.9459, device='cuda:0')\n",
      "val loss tensor(5.8033, device='cuda:0')\n",
      "val loss tensor(5.7535, device='cuda:0')\n",
      "val loss tensor(5.7088, device='cuda:0')\n",
      "val loss tensor(5.7715, device='cuda:0')\n",
      "val loss tensor(5.8033, device='cuda:0')\n",
      "val loss tensor(5.8117, device='cuda:0')\n",
      "val loss tensor(5.8375, device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val loss tensor(5.7197, device='cuda:0')\n",
      "val loss tensor(5.7392, device='cuda:0')\n",
      "val loss tensor(5.8026, device='cuda:0')\n",
      "val loss tensor(5.7850, device='cuda:0')\n",
      "val loss tensor(5.7253, device='cuda:0')\n",
      "val loss tensor(5.7150, device='cuda:0')\n",
      "val loss tensor(5.8508, device='cuda:0')\n",
      "val loss tensor(5.7329, device='cuda:0')\n",
      "val loss tensor(5.7544, device='cuda:0')\n",
      "val loss tensor(5.8800, device='cuda:0')\n",
      "val loss tensor(5.7651, device='cuda:0')\n",
      "val loss tensor(5.7961, device='cuda:0')\n",
      "val loss tensor(5.7090, device='cuda:0')\n",
      "val loss tensor(5.7806, device='cuda:0')\n",
      "val loss tensor(5.7867, device='cuda:0')\n",
      "val loss tensor(5.7828, device='cuda:0')\n",
      "val loss tensor(5.7797, device='cuda:0')\n",
      "val loss tensor(5.7232, device='cuda:0')\n",
      "val loss tensor(5.7677, device='cuda:0')\n",
      "val loss tensor(5.7513, device='cuda:0')\n",
      "val loss tensor(5.7038, device='cuda:0')\n",
      "val loss tensor(5.8359, device='cuda:0')\n",
      "val loss tensor(5.7611, device='cuda:0')\n",
      "val loss tensor(5.8145, device='cuda:0')\n",
      "val loss tensor(5.7797, device='cuda:0')\n",
      "val loss tensor(5.7943, device='cuda:0')\n",
      "val loss tensor(5.7599, device='cuda:0')\n",
      "val loss tensor(5.6963, device='cuda:0')\n",
      "val loss tensor(5.8095, device='cuda:0')\n",
      "val loss tensor(5.7703, device='cuda:0')\n",
      "val loss tensor(5.8106, device='cuda:0')\n",
      "val loss tensor(5.8453, device='cuda:0')\n",
      "val loss tensor(5.7679, device='cuda:0')\n",
      "val loss tensor(5.7300, device='cuda:0')\n",
      "val loss tensor(5.6486, device='cuda:0')\n",
      "val loss tensor(5.7575, device='cuda:0')\n",
      "val loss tensor(5.7832, device='cuda:0')\n",
      "val loss tensor(5.8456, device='cuda:0')\n",
      "val loss tensor(5.8375, device='cuda:0')\n",
      "val loss tensor(5.8397, device='cuda:0')\n",
      "val loss tensor(5.7786, device='cuda:0')\n",
      "val loss tensor(5.7742, device='cuda:0')\n",
      "val loss tensor(5.8571, device='cuda:0')\n",
      "val loss tensor(5.7817, device='cuda:0')\n",
      "val loss tensor(5.7617, device='cuda:0')\n",
      "val loss tensor(5.7714, device='cuda:0')\n",
      "val loss tensor(5.7932, device='cuda:0')\n",
      "val loss tensor(5.7012, device='cuda:0')\n",
      "val loss tensor(5.7894, device='cuda:0')\n",
      "val loss tensor(5.8253, device='cuda:0')\n",
      "val loss tensor(5.7657, device='cuda:0')\n",
      "val loss tensor(5.9217, device='cuda:0')\n",
      "val loss tensor(5.6999, device='cuda:0')\n",
      "val loss tensor(5.6897, device='cuda:0')\n",
      "val loss tensor(5.7680, device='cuda:0')\n",
      "val loss tensor(5.8243, device='cuda:0')\n",
      "val loss tensor(5.7673, device='cuda:0')\n",
      "val loss tensor(5.8183, device='cuda:0')\n",
      "val loss tensor(5.7926, device='cuda:0')\n",
      "val loss tensor(5.7692, device='cuda:0')\n",
      "val loss tensor(5.8114, device='cuda:0')\n",
      "val loss tensor(5.7104, device='cuda:0')\n",
      "val loss tensor(5.7679, device='cuda:0')\n",
      "val loss tensor(5.7341, device='cuda:0')\n",
      "val loss tensor(5.8404, device='cuda:0')\n",
      "val loss tensor(5.7646, device='cuda:0')\n",
      "val loss tensor(5.7938, device='cuda:0')\n",
      "val loss tensor(5.7481, device='cuda:0')\n",
      "val loss tensor(5.9162, device='cuda:0')\n",
      "val loss tensor(5.7515, device='cuda:0')\n",
      "val loss tensor(5.7231, device='cuda:0')\n",
      "val loss tensor(5.8954, device='cuda:0')\n",
      "val loss tensor(5.7446, device='cuda:0')\n",
      "val loss tensor(5.7014, device='cuda:0')\n",
      "val loss tensor(5.6533, device='cuda:0')\n",
      "val loss tensor(5.7175, device='cuda:0')\n",
      "val loss tensor(5.7463, device='cuda:0')\n",
      "val loss tensor(5.7526, device='cuda:0')\n",
      "val loss tensor(5.8031, device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val loss tensor(5.7024, device='cuda:0')\n",
      "val loss tensor(5.7304, device='cuda:0')\n",
      "val loss tensor(5.7957, device='cuda:0')\n",
      "val loss tensor(5.7654, device='cuda:0')\n",
      "val loss tensor(5.7077, device='cuda:0')\n",
      "val loss tensor(5.7001, device='cuda:0')\n",
      "val loss tensor(5.8311, device='cuda:0')\n",
      "val loss tensor(5.7146, device='cuda:0')\n",
      "val loss tensor(5.7344, device='cuda:0')\n",
      "val loss tensor(5.8719, device='cuda:0')\n",
      "val loss tensor(5.7509, device='cuda:0')\n",
      "val loss tensor(5.7848, device='cuda:0')\n",
      "val loss tensor(5.6929, device='cuda:0')\n",
      "val loss tensor(5.7722, device='cuda:0')\n",
      "val loss tensor(5.7674, device='cuda:0')\n",
      "val loss tensor(5.7718, device='cuda:0')\n",
      "val loss tensor(5.7654, device='cuda:0')\n",
      "val loss tensor(5.6992, device='cuda:0')\n",
      "val loss tensor(5.7456, device='cuda:0')\n",
      "val loss tensor(5.7328, device='cuda:0')\n",
      "val loss tensor(5.6812, device='cuda:0')\n",
      "val loss tensor(5.8188, device='cuda:0')\n",
      "val loss tensor(5.7442, device='cuda:0')\n",
      "val loss tensor(5.7968, device='cuda:0')\n",
      "val loss tensor(5.7663, device='cuda:0')\n",
      "val loss tensor(5.7794, device='cuda:0')\n",
      "val loss tensor(5.7496, device='cuda:0')\n",
      "val loss tensor(5.6828, device='cuda:0')\n",
      "val loss tensor(5.7853, device='cuda:0')\n",
      "val loss tensor(5.7535, device='cuda:0')\n",
      "val loss tensor(5.7901, device='cuda:0')\n",
      "val loss tensor(5.8298, device='cuda:0')\n",
      "val loss tensor(5.7479, device='cuda:0')\n",
      "val loss tensor(5.7121, device='cuda:0')\n",
      "val loss tensor(5.6288, device='cuda:0')\n",
      "val loss tensor(5.7408, device='cuda:0')\n",
      "val loss tensor(5.7706, device='cuda:0')\n",
      "val loss tensor(5.8312, device='cuda:0')\n",
      "val loss tensor(5.8225, device='cuda:0')\n",
      "val loss tensor(5.8204, device='cuda:0')\n",
      "val loss tensor(5.7678, device='cuda:0')\n",
      "val loss tensor(5.7467, device='cuda:0')\n",
      "val loss tensor(5.8435, device='cuda:0')\n",
      "val loss tensor(5.7649, device='cuda:0')\n",
      "val loss tensor(5.7339, device='cuda:0')\n",
      "val loss tensor(5.7542, device='cuda:0')\n",
      "val loss tensor(5.7734, device='cuda:0')\n",
      "val loss tensor(5.6871, device='cuda:0')\n",
      "val loss tensor(5.7724, device='cuda:0')\n",
      "val loss tensor(5.8079, device='cuda:0')\n",
      "val loss tensor(5.7412, device='cuda:0')\n",
      "val loss tensor(5.9110, device='cuda:0')\n",
      "val loss tensor(5.6855, device='cuda:0')\n",
      "val loss tensor(5.6627, device='cuda:0')\n",
      "val loss tensor(5.7416, device='cuda:0')\n",
      "val loss tensor(5.8123, device='cuda:0')\n",
      "val loss tensor(5.7439, device='cuda:0')\n",
      "val loss tensor(5.8015, device='cuda:0')\n",
      "val loss tensor(5.7792, device='cuda:0')\n",
      "val loss tensor(5.7578, device='cuda:0')\n",
      "val loss tensor(5.7910, device='cuda:0')\n",
      "val loss tensor(5.6869, device='cuda:0')\n",
      "val loss tensor(5.7474, device='cuda:0')\n",
      "val loss tensor(5.7142, device='cuda:0')\n",
      "val loss tensor(5.8220, device='cuda:0')\n",
      "val loss tensor(5.7469, device='cuda:0')\n",
      "val loss tensor(5.7822, device='cuda:0')\n",
      "val loss tensor(5.7315, device='cuda:0')\n",
      "val loss tensor(5.8996, device='cuda:0')\n",
      "val loss tensor(5.7378, device='cuda:0')\n",
      "val loss tensor(5.7044, device='cuda:0')\n",
      "val loss tensor(5.8758, device='cuda:0')\n",
      "val loss tensor(5.7310, device='cuda:0')\n",
      "val loss tensor(5.6788, device='cuda:0')\n",
      "val loss tensor(5.6336, device='cuda:0')\n",
      "val loss tensor(5.7033, device='cuda:0')\n",
      "val loss tensor(5.7298, device='cuda:0')\n",
      "val loss tensor(5.7356, device='cuda:0')\n",
      "val loss tensor(5.7812, device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val loss tensor(5.6849, device='cuda:0')\n",
      "val loss tensor(5.7183, device='cuda:0')\n",
      "val loss tensor(5.7804, device='cuda:0')\n",
      "val loss tensor(5.7617, device='cuda:0')\n",
      "val loss tensor(5.6980, device='cuda:0')\n",
      "val loss tensor(5.6940, device='cuda:0')\n",
      "val loss tensor(5.8232, device='cuda:0')\n",
      "val loss tensor(5.7015, device='cuda:0')\n",
      "val loss tensor(5.7292, device='cuda:0')\n",
      "val loss tensor(5.8589, device='cuda:0')\n",
      "val loss tensor(5.7397, device='cuda:0')\n",
      "val loss tensor(5.7678, device='cuda:0')\n",
      "val loss tensor(5.6777, device='cuda:0')\n",
      "val loss tensor(5.7635, device='cuda:0')\n",
      "val loss tensor(5.7524, device='cuda:0')\n",
      "val loss tensor(5.7541, device='cuda:0')\n",
      "val loss tensor(5.7562, device='cuda:0')\n",
      "val loss tensor(5.6919, device='cuda:0')\n",
      "val loss tensor(5.7451, device='cuda:0')\n",
      "val loss tensor(5.7196, device='cuda:0')\n",
      "val loss tensor(5.6711, device='cuda:0')\n",
      "val loss tensor(5.8135, device='cuda:0')\n",
      "val loss tensor(5.7369, device='cuda:0')\n",
      "val loss tensor(5.7955, device='cuda:0')\n",
      "val loss tensor(5.7538, device='cuda:0')\n",
      "val loss tensor(5.7750, device='cuda:0')\n",
      "val loss tensor(5.7385, device='cuda:0')\n",
      "val loss tensor(5.6682, device='cuda:0')\n",
      "val loss tensor(5.7740, device='cuda:0')\n",
      "val loss tensor(5.7388, device='cuda:0')\n",
      "val loss tensor(5.7760, device='cuda:0')\n",
      "val loss tensor(5.8213, device='cuda:0')\n",
      "val loss tensor(5.7412, device='cuda:0')\n",
      "val loss tensor(5.7048, device='cuda:0')\n",
      "val loss tensor(5.6235, device='cuda:0')\n",
      "val loss tensor(5.7320, device='cuda:0')\n",
      "val loss tensor(5.7605, device='cuda:0')\n",
      "val loss tensor(5.8154, device='cuda:0')\n",
      "val loss tensor(5.8068, device='cuda:0')\n",
      "val loss tensor(5.8130, device='cuda:0')\n",
      "val loss tensor(5.7558, device='cuda:0')\n",
      "val loss tensor(5.7346, device='cuda:0')\n",
      "val loss tensor(5.8326, device='cuda:0')\n",
      "val loss tensor(5.7575, device='cuda:0')\n",
      "val loss tensor(5.7256, device='cuda:0')\n",
      "val loss tensor(5.7433, device='cuda:0')\n",
      "val loss tensor(5.7674, device='cuda:0')\n",
      "val loss tensor(5.6738, device='cuda:0')\n",
      "val loss tensor(5.7683, device='cuda:0')\n",
      "val loss tensor(5.7924, device='cuda:0')\n",
      "val loss tensor(5.7329, device='cuda:0')\n",
      "val loss tensor(5.8981, device='cuda:0')\n",
      "val loss tensor(5.6730, device='cuda:0')\n",
      "val loss tensor(5.6483, device='cuda:0')\n",
      "val loss tensor(5.7301, device='cuda:0')\n",
      "val loss tensor(5.8013, device='cuda:0')\n",
      "val loss tensor(5.7287, device='cuda:0')\n",
      "val loss tensor(5.7926, device='cuda:0')\n",
      "val loss tensor(5.7673, device='cuda:0')\n",
      "val loss tensor(5.7458, device='cuda:0')\n",
      "val loss tensor(5.7851, device='cuda:0')\n",
      "val loss tensor(5.6777, device='cuda:0')\n",
      "val loss tensor(5.7409, device='cuda:0')\n",
      "val loss tensor(5.7030, device='cuda:0')\n",
      "val loss tensor(5.8173, device='cuda:0')\n",
      "val loss tensor(5.7398, device='cuda:0')\n",
      "val loss tensor(5.7705, device='cuda:0')\n",
      "val loss tensor(5.7245, device='cuda:0')\n",
      "val loss tensor(5.8836, device='cuda:0')\n",
      "val loss tensor(5.7181, device='cuda:0')\n",
      "val loss tensor(5.6952, device='cuda:0')\n",
      "val loss tensor(5.8634, device='cuda:0')\n",
      "val loss tensor(5.7203, device='cuda:0')\n",
      "val loss tensor(5.6668, device='cuda:0')\n",
      "val loss tensor(5.6241, device='cuda:0')\n",
      "val loss tensor(5.6942, device='cuda:0')\n",
      "val loss tensor(5.7127, device='cuda:0')\n",
      "val loss tensor(5.7224, device='cuda:0')\n",
      "val loss tensor(5.7649, device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "val loss tensor(5.6876, device='cuda:0')\n",
      "val loss tensor(5.7162, device='cuda:0')\n",
      "val loss tensor(5.7833, device='cuda:0')\n",
      "val loss tensor(5.7543, device='cuda:0')\n",
      "val loss tensor(5.6928, device='cuda:0')\n",
      "val loss tensor(5.6882, device='cuda:0')\n",
      "val loss tensor(5.8074, device='cuda:0')\n",
      "val loss tensor(5.7055, device='cuda:0')\n",
      "val loss tensor(5.7195, device='cuda:0')\n",
      "val loss tensor(5.8589, device='cuda:0')\n",
      "val loss tensor(5.7359, device='cuda:0')\n",
      "val loss tensor(5.7686, device='cuda:0')\n",
      "val loss tensor(5.6822, device='cuda:0')\n",
      "val loss tensor(5.7585, device='cuda:0')\n",
      "val loss tensor(5.7517, device='cuda:0')\n",
      "val loss tensor(5.7564, device='cuda:0')\n",
      "val loss tensor(5.7535, device='cuda:0')\n",
      "val loss tensor(5.6923, device='cuda:0')\n",
      "val loss tensor(5.7397, device='cuda:0')\n",
      "val loss tensor(5.7183, device='cuda:0')\n",
      "val loss tensor(5.6676, device='cuda:0')\n",
      "val loss tensor(5.8068, device='cuda:0')\n",
      "val loss tensor(5.7356, device='cuda:0')\n",
      "val loss tensor(5.7885, device='cuda:0')\n",
      "val loss tensor(5.7452, device='cuda:0')\n",
      "val loss tensor(5.7684, device='cuda:0')\n",
      "val loss tensor(5.7357, device='cuda:0')\n",
      "val loss tensor(5.6627, device='cuda:0')\n",
      "val loss tensor(5.7753, device='cuda:0')\n",
      "val loss tensor(5.7353, device='cuda:0')\n",
      "val loss tensor(5.7793, device='cuda:0')\n",
      "val loss tensor(5.8146, device='cuda:0')\n",
      "val loss tensor(5.7279, device='cuda:0')\n",
      "val loss tensor(5.6954, device='cuda:0')\n",
      "val loss tensor(5.6184, device='cuda:0')\n",
      "val loss tensor(5.7361, device='cuda:0')\n",
      "val loss tensor(5.7480, device='cuda:0')\n",
      "val loss tensor(5.8089, device='cuda:0')\n",
      "val loss tensor(5.7992, device='cuda:0')\n",
      "val loss tensor(5.8114, device='cuda:0')\n",
      "val loss tensor(5.7536, device='cuda:0')\n",
      "val loss tensor(5.7365, device='cuda:0')\n",
      "val loss tensor(5.8305, device='cuda:0')\n",
      "val loss tensor(5.7580, device='cuda:0')\n",
      "val loss tensor(5.7152, device='cuda:0')\n",
      "val loss tensor(5.7368, device='cuda:0')\n",
      "val loss tensor(5.7612, device='cuda:0')\n",
      "val loss tensor(5.6765, device='cuda:0')\n",
      "val loss tensor(5.7629, device='cuda:0')\n",
      "val loss tensor(5.7911, device='cuda:0')\n",
      "val loss tensor(5.7289, device='cuda:0')\n",
      "val loss tensor(5.9079, device='cuda:0')\n",
      "val loss tensor(5.6734, device='cuda:0')\n",
      "val loss tensor(5.6401, device='cuda:0')\n",
      "val loss tensor(5.7273, device='cuda:0')\n",
      "val loss tensor(5.7988, device='cuda:0')\n",
      "val loss tensor(5.7229, device='cuda:0')\n",
      "val loss tensor(5.7927, device='cuda:0')\n",
      "val loss tensor(5.7682, device='cuda:0')\n",
      "val loss tensor(5.7423, device='cuda:0')\n",
      "val loss tensor(5.7758, device='cuda:0')\n",
      "val loss tensor(5.6728, device='cuda:0')\n",
      "val loss tensor(5.7342, device='cuda:0')\n",
      "val loss tensor(5.7059, device='cuda:0')\n",
      "val loss tensor(5.8131, device='cuda:0')\n",
      "val loss tensor(5.7379, device='cuda:0')\n",
      "val loss tensor(5.7597, device='cuda:0')\n",
      "val loss tensor(5.7217, device='cuda:0')\n",
      "val loss tensor(5.8850, device='cuda:0')\n",
      "val loss tensor(5.7221, device='cuda:0')\n",
      "val loss tensor(5.6906, device='cuda:0')\n",
      "val loss tensor(5.8681, device='cuda:0')\n",
      "val loss tensor(5.7137, device='cuda:0')\n",
      "val loss tensor(5.6662, device='cuda:0')\n",
      "val loss tensor(5.6166, device='cuda:0')\n",
      "val loss tensor(5.6943, device='cuda:0')\n",
      "val loss tensor(5.7112, device='cuda:0')\n",
      "val loss tensor(5.7163, device='cuda:0')\n",
      "val loss tensor(5.7681, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...\n"
     ]
    }
   ],
   "source": [
    "# Train on `train`, validating against `val` during the run.\n",
    "trainer.fit(model, train_dataloaders=train, val_dataloaders=val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c4deeb2-4d33-4a99-b533-8e8476420be6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
