{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "a4cb74a6-1b1a-43eb-b921-692e6fc94dd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('/workspace/FutureGPT2/src/')\n",
    "os.environ['HF_HOME'] = '/workspace/cache/huggingface/'\n",
    "\n",
    "import numpy as np\n",
    "from torch import optim, nn, Tensor\n",
    "from torch.nn import functional as F\n",
    "import torch\n",
    "import wandb\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "import transformers\n",
    "import lightning as L\n",
    "from inspect import signature, _ParameterKind\n",
    "import copy\n",
    "import gc\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "from datasets import load_dataset\n",
    "from itertools import islice\n",
    "\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor\n",
    "from lightning.pytorch.callbacks.early_stopping import EarlyStopping\n",
    "from lightning.pytorch.loggers import WandbLogger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9d9868a2-654d-44f4-a2e4-e8771a8bd27c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The training split stays disabled; only the validation split is loaded, as torch tensors on the GPU.\n",
    "#train = datasets.load_from_disk('/workspace/corpus/msmarco/msmarco_GPT2_64tokens_full/train').with_format('torch')\n",
    "val = datasets.load_from_disk(\n",
    "    '/workspace/corpus/msmarco/msmarco_GPT2_64tokens_full/val'\n",
    ").with_format(\n",
    "    'torch',\n",
    "    device='cuda',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "efc084fa-77a7-419d-9dae-6156464f020c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The training loader stays disabled; validation batches hold 128 rows each.\n",
    "#train_loader = DataLoader(train, batch_size=128, num_workers=50)\n",
    "val_loader = DataLoader(dataset=val, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "fb7b47e8-460f-471b-a6c4-6d3247784036",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPT-2 tokenizer, plus the inverse of its vocabulary mapping (token id to token string).\n",
    "gpt_tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
    "GPTToken = {\n",
    "    token_id: token\n",
    "    for token, token_id in gpt_tokenizer.get_vocab().items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "9ff7279e-ccae-47ed-a720-1cbe6900545c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stock 'gpt2' checkpoint and move it to the GPU.\n",
    "gpt = AutoModelForCausalLM.from_pretrained('gpt2')\n",
    "gpt = gpt.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "b2a9f348-e55c-47d5-9bf5-e7982331435d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce5c5fac1c2e4ff5a09b7cce71372c39",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8f4c9d516ee47f3986da30c361fcf8e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/911M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the 'EleutherAI/pythia-410m-deduped' checkpoint and move it to the GPU.\n",
    "pythia = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-410m-deduped')\n",
    "pythia = pythia.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "03501a39-72b3-4e5d-822a-941d06de66d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "# Pythia tokenizer, plus the inverse of its vocabulary mapping (token id to token string).\n",
    "pythia_tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-410m-deduped')\n",
    "PythiaToken = {\n",
    "    token_id: token\n",
    "    for token, token_id in pythia_tokenizer.get_vocab().items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "080f5c0c-6dc2-4f3e-966e-e54cbf5f0733",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7fa4f7742b90>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Switch autograd off globally; the cells below only run forward passes and generation.\n",
    "torch.set_grad_enabled(mode=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "70b3499f-ee98-49d4-8653-917836de78fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "99e3f5e688164d9a85da27ac6728961b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/405 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Accumulate the per-batch loss of `gpt` over the validation loader.\n",
    "# `loss` and `total` stay in the namespace so the mean can be read afterwards.\n",
    "loss = 0\n",
    "total = 0\n",
    "for batch in tqdm(iter(val_loader)):\n",
    "    # Remove the 'text' and 'id' entries so the rest of the batch can be\n",
    "    # unpacked straight into the model call.\n",
    "    batch.pop('text')\n",
    "    batch.pop('id')\n",
    "    # The inputs double as labels, so the model returns its language-modelling loss.\n",
    "    out = gpt(labels=batch['input_ids'], **batch)\n",
    "    loss += out.loss.item()\n",
    "    total += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5ed248ab-14ee-439c-bb71-44f9cfbadb31",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.691645735870173"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the per-batch losses: accumulated loss divided by the number of batches seen.\n",
    "loss / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0e6c0635-8cc1-48b6-879e-e780d656809d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fresh GPT-2 on the GPU that receives the weights stored in the MSMARCO checkpoint.\n",
    "own = AutoModelForCausalLM.from_pretrained('gpt2').to('cuda')\n",
    "state = torch.load('/workspace/checkpoints/GPT2-MSMARCO-COSINE_global_step=9099.0_val_loss=3.28.ckpt')['state_dict']\n",
    "# Drop everything up to and including the first '.' of each key before loading,\n",
    "# i.e. the leading name component the checkpoint puts in front of the model's own keys.\n",
    "own.load_state_dict({\n",
    "    key.partition('.')[2]: weight\n",
    "    for key, weight in state.items()\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b5d3804e-ca13-40d5-9e46-03993be518a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fresh GPT-2 on the GPU that receives the weights stored in the myopic checkpoint.\n",
    "myopic = AutoModelForCausalLM.from_pretrained('gpt2').to('cuda')\n",
    "state = torch.load('/workspace/checkpoints/GPT2-MYOPIC-CUTGRAD_val_myopic_loss=3.73.ckpt')['state_dict']\n",
    "# Drop everything up to and including the first '.' of each key before loading,\n",
    "# i.e. the leading name component the checkpoint puts in front of the model's own keys.\n",
    "myopic.load_state_dict({\n",
    "    key.partition('.')[2]: weight\n",
    "    for key, weight in state.items()\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "482efe3e-b77f-48e5-bfed-107288d539af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2cfd9149244406c82228134237ac1bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/405 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Accumulate the per-batch loss of `own` over the validation loader.\n",
    "# `loss` and `total` stay in the namespace so the mean can be read afterwards.\n",
    "loss = 0\n",
    "total = 0\n",
    "for batch in tqdm(iter(val_loader)):\n",
    "    # Remove the 'text' and 'id' entries so the rest of the batch can be\n",
    "    # unpacked straight into the model call.\n",
    "    batch.pop('text')\n",
    "    batch.pop('id')\n",
    "    # The inputs double as labels, so the model returns its language-modelling loss.\n",
    "    out = own(labels=batch['input_ids'], **batch)\n",
    "    loss += out.loss.item()\n",
    "    total += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f1dfdf4f-d480-4f0b-a9de-87fa46c2ca74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.2765459608148646"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the per-batch losses: accumulated loss divided by the number of batches seen.\n",
    "loss / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "805f937c-dce9-4b69-a514-f4e7b05502bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompts used for the side-by-side generation comparison.\n",
    "# Entries that are commented out are currently disabled.\n",
    "evals = [\n",
    "    'It was a dark and stormy',\n",
    "    'Anna was born in Hungary, so her first language is',\n",
    "    #'The store sells only red pens and blue books, so I bought a',\n",
    "    #'The color of the sky is',\n",
    "    #'One plus one is',\n",
    "    'The director of the Manhattan project was J. Robert Opp',\n",
    "    'These are my two friends Tom Smith and Bob Jones. Of the two, my favorite is',\n",
    "    #'It was a dark and stormy',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "0413a144-c388-4181-aac2-1a11d392431b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "THEIRS\n",
      "47 tokens\n",
      "It was a dark and stormy|| night. The wind was blowing, and the clouds were falling. The wind was blowing, and the clouds were falling. The wind was blowing, and the clouds were falling. The wind was blowing,\n",
      "MSMARCO\n",
      "47 tokens\n",
      "It was a dark and stormy|| day, and the sun was shining. The sun was shining, and the sun was shining. The sun was shining, and the sun was shining. The sun was shining, and the sun was shining\n",
      "MYOPIC\n",
      "47 tokens\n",
      "It was a dark and stormy|| day in the middle of the night. The first day of the month was the first day of the month, and the second day of the month. The first day of the month is the day of\n",
      "PYTHIA\n",
      "47 tokens\n",
      "It was a dark and stormy|| night, and the wind was howling.  \"I'm sorry,\" he said. \"I'm sorry.\"  \"I'm sorry,\" she said. \"I'm sorry.\" \n",
      "\n",
      "THEIRS\n",
      "51 tokens\n",
      "Anna was born in Hungary, so her first language is|| Hungarian. She is a very good student, and she is very good at math. She is very good at reading. She is very good at reading. She is very good at reading. She is\n",
      "MSMARCO\n",
      "51 tokens\n",
      "Anna was born in Hungary, so her first language is|| English. She is the only child of a Jewish family. She is the only child of a Jewish family. She is the only child of a Jewish family. She is the only child of a Jewish\n",
      "MYOPIC\n",
      "51 tokens\n",
      "Anna was born in Hungary, so her first language is|| English. She is the mother of the Germanic languages. She is the mother of the Germanic languages. She is the mother of the Germanic family. She is the mother of the Germanic\n",
      "PYTHIA\n",
      "51 tokens\n",
      "Anna was born in Hungary, so her first language is|| Hungarian. She is a very bright child, and she is very curious. She loves to play with dolls and dolls love to play with dolls. She loves to play with dolls and\n",
      "\n",
      "THEIRS\n",
      "51 tokens\n",
      "The director of the Manhattan project was J. Robert Opp||enheimer, who was also a friend of the president.  The project was to be a joint venture between the two companies, which would be led by the former president. Oppenheimer was\n",
      "MSMARCO\n",
      "51 tokens\n",
      "The director of the Manhattan project was J. Robert Opp||enheimer, who was the first to be elected president of the United States. The director of the Manhattan Project was the first to be elected president of the United States. The director of the Manhattan Project\n",
      "MYOPIC\n",
      "51 tokens\n",
      "The director of the Manhattan project was J. Robert Opp||enheimer, who was the first to have a new project. The project was completed in the early 1990s and was the first project to be completed in the early 1990s. The project was completed\n",
      "PYTHIA\n",
      "51 tokens\n",
      "The director of the Manhattan project was J. Robert Opp||enheimer, who was also the director of the Manhattan Project.  The Manhattan Project was a massive project to build the first atomic bomb. The Manhattan Project was a massive project to build the first\n",
      "\n",
      "THEIRS\n",
      "58 tokens\n",
      "These are my two friends Tom Smith and Bob Jones. Of the two, my favorite is|| Bob Jones. He's a great guy. He's a great guy. He's a great guy. He's a great guy. He's a great guy. He's a great guy. He\n",
      "MSMARCO\n",
      "58 tokens\n",
      "These are my two friends Tom Smith and Bob Jones. Of the two, my favorite is|| the one that I love. I love the name, but I love the name. I love the name, but I love the name. I love the name, but I love the name. I\n",
      "MYOPIC\n",
      "58 tokens\n",
      "These are my two friends Tom Smith and Bob Jones. Of the two, my favorite is|| the one who is the one who is the one who is the one who is the one who is the one who is the one who is the one who is the one who is the one who is\n",
      "PYTHIA\n",
      "58 tokens\n",
      "These are my two friends Tom Smith and Bob Jones. Of the two, my favorite is|| Bob Jones. He is a great guy. He is a great guy. He is a great guy. He is a great guy. He is a great guy. He is a great guy. He\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Greedy-decode 40 new tokens from every model for every prompt in `evals`,\n",
    "# printing the prompt and the continuation separated by '||'.\n",
    "for s in evals:\n",
    "    for model, name in zip([gpt, own, myopic, pythia], ['THEIRS', 'MSMARCO', 'MYOPIC', 'PYTHIA']):\n",
    "        # Only the Pythia model uses the Pythia tokenizer; the other three share GPT-2's.\n",
    "        tokenizer = pythia_tokenizer if name == 'PYTHIA' else gpt_tokenizer\n",
    "        # No longer needed for decoding here, but still assigned because later cells read `Token`.\n",
    "        Token = PythiaToken if name == 'PYTHIA' else GPTToken\n",
    "        input = tokenizer(\n",
    "            s,\n",
    "            return_tensors='pt'\n",
    "        ).to('cuda')\n",
    "        out = model.generate(\n",
    "            **input,\n",
    "            max_new_tokens=40,\n",
    "            pad_token_id=tokenizer.eos_token_id,\n",
    "            # Request greedy decoding explicitly rather than through temperature=0.\n",
    "            do_sample=False,\n",
    "        )\n",
    "        print(name)\n",
    "        print(out.shape[1], 'tokens')\n",
    "        # Split at the prompt's token length so only the true boundary is marked,\n",
    "        # even if the model repeats the prompt text inside its continuation.\n",
    "        prompt_len = input['input_ids'].shape[1]\n",
    "        # Let the tokenizer turn ids back into text; a hand-made token map with\n",
    "        # character substitutions garbles non-ASCII byte-level tokens.\n",
    "        continuation = tokenizer.decode(out[0, prompt_len:], clean_up_tokenization_spaces=False)\n",
    "        # Newlines become spaces so each generation prints on a single line.\n",
    "        print(s + '||' + continuation.replace('\\n', ' '))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "835af7ce-89dd-4767-b412-576d7548f5e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Alice lives in France, so she speaks to her. She is a very good-looking woman, and she is a very good-looking woman. She is a very good-looking woman'"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): this prompt is reconstructed from the cell's recorded output - confirm it.\n",
    "prompt = 'Alice lives in France, so she speaks'\n",
    "# Tokenize here with the GPT-2 tokenizer instead of reusing the leaked `input`/`Token`:\n",
    "# the generation loop finishes on the Pythia entry, so those names end up holding\n",
    "# Pythia ids and the Pythia vocabulary, which do not belong to the GPT-2 based `own`.\n",
    "prompt_inputs = gpt_tokenizer(prompt, return_tensors='pt').to('cuda')\n",
    "out = own.generate(\n",
    "    **prompt_inputs,\n",
    "    max_new_tokens=30,\n",
    "    # Passed explicitly, matching the generation loop, so generate() does not warn about it.\n",
    "    pad_token_id=gpt_tokenizer.eos_token_id,\n",
    ")\n",
    "# Decode prompt plus continuation with the tokenizer rather than a hand-made token map.\n",
    "gpt_tokenizer.decode(out[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "865f7add-1d56-44ed-be11-6e43e01d4b17",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Alice lives in France, so she speaks to the world of the world. He is the son of a man who is the son of a man who is the son of a man. He'"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): this prompt is reconstructed from the cell's recorded output - confirm it.\n",
    "prompt = 'Alice lives in France, so she speaks'\n",
    "# Tokenize here with the GPT-2 tokenizer instead of reusing the leaked `input`/`Token`:\n",
    "# the generation loop finishes on the Pythia entry, so those names end up holding\n",
    "# Pythia ids and the Pythia vocabulary, which do not belong to the GPT-2 based `myopic`.\n",
    "prompt_inputs = gpt_tokenizer(prompt, return_tensors='pt').to('cuda')\n",
    "out = myopic.generate(\n",
    "    **prompt_inputs,\n",
    "    max_new_tokens=30,\n",
    "    # Passed explicitly, matching the generation loop, so generate() does not warn about it.\n",
    "    pad_token_id=gpt_tokenizer.eos_token_id,\n",
    ")\n",
    "# Decode prompt plus continuation with the tokenizer rather than a hand-made token map.\n",
    "gpt_tokenizer.decode(out[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd85604d-eb86-4e44-96f5-5998d64241f9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f25f130-f0f1-4116-8461-fe78e10bdfd5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddcb11a6-ae73-4cd0-8410-a83e3a0b0ba5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c572e6-0901-4bcc-a47f-71ebc4363b1f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4bc1eed9-df64-44ff-94df-b8b710a2d173",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXXXX/.local/lib/python3.10/site-packages/datasets/load.py:1429: FutureWarning: The repository for openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/openwebtext\n",
      "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
      "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "21de903c907f4c378f06e43e5ae5a9de",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading builder script:   0%|          | 0.00/2.73k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3756f309c1e340268b6216e58615556f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/7.33k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f18e64df84004dd78b1eeecbe5f9ff06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/633M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c2e1258664374d269edbd4a66118c8d4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/629M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3959a8fe632843deb2c134daa75ba522",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/629M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0007efd08985476487674a75b9325e8e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/628M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0a7de955ce794d78ac37b5fc7e46eb63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/627M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da4a38dc74a74b7392e7e0da88b6b5d0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/630M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "13338723478a41e883b6043ce0006cb0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/626M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c6814eb556c49f3be3ccfe76578f28a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/625M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81321e8d51ac4bc4a8de5d6d32bc89f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/625M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a75a4ddfa5b47fbab3d30b6d5bb1fcd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/626M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89d5e4ed193641c2b3b9daf13ae2cd28",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/625M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ddde352da47402d8c6caf33ccaba17b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/625M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cbb859c2b62140d48c7517b815004119",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/624M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83c69305fbaa4460b1fca28d7966eb20",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/629M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5f3fe58fc83a464e8fe085ddceef2d30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/627M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8662126a3d7840c482ddbb9959c6fd7f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/621M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9f162bfca214ca68b67182d45437484",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/619M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7efbb1124fbb482097af2bc52714d220",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/619M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "17fddee381a64da988a59089f8b8cbb6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/618M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a938ee31f654747962b4c66c949c925",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/619M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "25e35e745b5641bb8df936c87d55444f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/377M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be7200ace6344bc3aaa9b3031b2f273b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/8013769 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# First 100,000 rows of the OpenWebText train split.\n",
    "owt = load_dataset('openwebtext', split='train[:100000]')\n",
    "# The loss cells below iterate over `owt_loader`, which no cell defined, so they\n",
    "# failed with a NameError on a fresh kernel.\n",
    "# NOTE(review): the original batch size is unknown. Those cells tokenize without\n",
    "# padding, so a batch of one is the only size that always yields a rectangular\n",
    "# tensor; raise it if every document is known to reach the 64-token truncation length.\n",
    "owt_loader = DataLoader(owt, batch_size=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "23d9d76e-72b8-4775-b78a-ebdb0a22aced",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0559016ffaa64b659dd37af6e6d00ab3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PYTHIA 3.1627233052253723\n"
     ]
    }
   ],
   "source": [
    "# Mean loss over the first 100 `owt_loader` batches, with each text truncated to 64 tokens.\n",
    "for model, name in [(pythia, 'PYTHIA')]:#zip([gpt, own, myopic], ['THEIRS', 'OURS', 'MYOPIC']):\n",
    "    # The Pythia entry gets the Pythia tokenizer; any other name falls back to GPT-2's.\n",
    "    if name == 'PYTHIA':\n",
    "        tokenizer = pythia_tokenizer\n",
    "    else:\n",
    "        tokenizer = gpt_tokenizer\n",
    "    loss = 0\n",
    "    total = 0\n",
    "    for batch in tqdm(islice(iter(owt_loader), 100), total=100):\n",
    "        input = tokenizer(\n",
    "            batch['text'],\n",
    "            max_length=64,\n",
    "            truncation=True,\n",
    "            return_tensors='pt',\n",
    "        ).to('cuda')\n",
    "        # The inputs double as labels, so the model returns its language-modelling loss.\n",
    "        out = model(labels=input['input_ids'], **input)\n",
    "        loss += out.loss.item()\n",
    "        total += 1\n",
    "    print(name, loss / total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be06dc04-9c12-4555-b3a8-08d23aad6139",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean loss of `gpt` over every `owt_loader` batch, with each text truncated to 64 tokens.\n",
    "# Tokenize with `gpt_tokenizer` explicitly: the global `tokenizer` holds whatever an\n",
    "# earlier loop assigned last (the Pythia tokenizer), whose ids do not match GPT-2's vocabulary.\n",
    "total = 0\n",
    "loss = 0\n",
    "for batch in tqdm(iter(owt_loader)):\n",
    "    inputs = gpt_tokenizer(batch['text'], return_tensors='pt', truncation=True, max_length=64).to('cuda')\n",
    "    # The inputs double as labels, so the model returns its language-modelling loss.\n",
    "    out = gpt(**inputs, labels=inputs['input_ids'])\n",
    "    loss += out.loss.item()\n",
    "    total += 1\n",
    "loss / total"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
