{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Directory added to the import path.\n",
    "# NOTE(review): presumably the checkout that provides the `litgpt` package imported in later cells - confirm.\n",
    "wd = \"/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/llm-pretraining-root/XXXX-40\"\n",
    "\n",
    "# Guarded so that re-running this cell does not stack duplicate entries on sys.path.\n",
    "if wd not in sys.path:\n",
    "    sys.path.append(wd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/miniconda_frontier/envs/lit-gpt/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Adapted from the tinyllama.py script in lit-gpt main\n",
    "\n",
    "import os\n",
    "\n",
    "from pathlib import Path\n",
    "from typing import Tuple\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from litgpt import Tokenizer\n",
    "from litgpt.packed_dataset import CombinedDataset, PackedDataset\n",
    "\n",
    "from litgpt.model import Config\n",
    "\n",
    "# System settings\n",
    "# model_name is the key passed to Config.from_name() in a later cell; in the visible cells\n",
    "# `name` is only used to label the output directory.\n",
    "model_name = \"tiny-llama-1.1b\"\n",
    "name = \"lit-tiny-llama-1.1b\"\n",
    "# Artifacts go under $LIGHTNING_ARTIFACTS_DIR when that variable is set, otherwise under ./out.\n",
    "out_dir = Path(os.getenv(\"LIGHTNING_ARTIFACTS_DIR\", \"out\")) / name\n",
    "logger_name = \"tensorboard\"\n",
    "# Falls back to 1 when torch reports zero CUDA devices.\n",
    "devices = torch.cuda.device_count() or 1\n",
    "\n",
    "# Hyperparameters\n",
    "global_batch_size = 512\n",
    "learning_rate = 4e-4\n",
    "micro_batch_size = 4\n",
    "max_tokens = int(3e12)  # 3 trillion\n",
    "warmup_steps = 2000\n",
    "log_step_interval = 1\n",
    "eval_iters = 100\n",
    "save_step_interval = 1000\n",
    "eval_step_interval = 1000\n",
    "\n",
    "weight_decay = 1e-1\n",
    "beta1 = 0.9\n",
    "beta2 = 0.95\n",
    "grad_clip = 1.0\n",
    "decay_lr = True\n",
    "min_lr = 4e-5\n",
    "\n",
    "# Derived quantities.\n",
    "# Per-device batch size: the global batch split across devices with integer division.\n",
    "batch_size = global_batch_size // devices\n",
    "# Number of micro-batches that make up one per-device batch.\n",
    "gradient_accumulation_iters = batch_size // micro_batch_size\n",
    "# Fails when micro_batch_size is larger than the per-device batch size.\n",
    "assert gradient_accumulation_iters > 0\n",
    "# Step-based intervals converted to micro-batch iterations.\n",
    "warmup_iters = warmup_steps * gradient_accumulation_iters\n",
    "log_iter_interval = log_step_interval * gradient_accumulation_iters\n",
    "\n",
    "\n",
    "# Snapshot of every int/float/str name visible here, skipping underscore-prefixed names.\n",
    "# NOTE(review): at notebook top level locals() is presumably the whole kernel namespace, so this\n",
    "# would also capture unrelated strings from earlier cells (e.g. `wd`) and bools such as decay_lr\n",
    "# (bool is a subclass of int) - confirm that is acceptable before logging hparams anywhere.\n",
    "hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith(\"_\")}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataloaders(batch_size: int, block_size: int, num_workers: int = 8) -> Tuple[DataLoader, DataLoader]:\n",
    "    \"\"\"Build the training dataloader over a fixed list of `.bin` shard files.\n",
    "\n",
    "    The shards are read in order (no shuffling) through a single-process\n",
    "    ``PackedDataset`` wrapped in a one-element ``CombinedDataset``.\n",
    "\n",
    "    Args:\n",
    "        batch_size: Batch size handed to the ``DataLoader``.\n",
    "        block_size: Passed through to ``PackedDataset`` as ``block_size``.\n",
    "        num_workers: Currently unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        ``(train_dataloader, None)``. No validation dataloader is built here, so the\n",
    "        second element is always ``None`` despite the annotation.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If the shard list is empty or a listed shard is not on disk.\n",
    "    \"\"\"\n",
    "    data_dir = Path(\"/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/llm-pretraining-root/output/data/test_prepare_openorca_tinyllama\")\n",
    "    # Explicit shard list rather than a glob, so the order of the data is fixed.\n",
    "    filenames = [\n",
    "        data_dir / \"train-meta-cot_000000-of-000128_0000000006.bin\",\n",
    "        data_dir / \"train-meta-cot_000000-of-000128_0000000007.bin\",\n",
    "        data_dir / \"train-meta-cot_000000-of-000128_0000000008.bin\",\n",
    "        data_dir / \"train-meta-cot_000000-of-000128_0000000009.bin\",\n",
    "    ]\n",
    "    print(filenames)\n",
    "    if not filenames:\n",
    "        # The message no longer interpolates `prefix`: that name is not defined in this\n",
    "        # function, so formatting it raised NameError instead of the intended error.\n",
    "        raise FileNotFoundError(\n",
    "            f\"No files found at {str(data_dir)}. Did you forget to run `prepare_redpajama.py`?\"\n",
    "        )\n",
    "    # Fail early with the offending paths instead of erroring later while iterating.\n",
    "    missing = [str(path) for path in filenames if not path.is_file()]\n",
    "    if missing:\n",
    "        raise FileNotFoundError(f\"Data files not found: {missing}\")\n",
    "\n",
    "    dataset = PackedDataset(\n",
    "        filenames,\n",
    "        n_chunks=4,\n",
    "        block_size=block_size,\n",
    "        shuffle=False,\n",
    "        seed=1234,\n",
    "        num_processes=1,\n",
    "        process_rank=0,\n",
    "    )\n",
    "\n",
    "    combined_dataset = CombinedDataset(datasets=[dataset], seed=1234, weights=(1.0,))\n",
    "    train_dataloader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)\n",
    "    return train_dataloader, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[PosixPath('/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/llm-pretraining-root/output/data/test_prepare_openorca_tinyllama/train-meta-cot_000000-of-000128_0000000006.bin'), PosixPath('/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/llm-pretraining-root/output/data/test_prepare_openorca_tinyllama/train-meta-cot_000000-of-000128_0000000007.bin'), PosixPath('/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/llm-pretraining-root/output/data/test_prepare_openorca_tinyllama/train-meta-cot_000000-of-000128_0000000008.bin'), PosixPath('/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/llm-pretraining-root/output/data/test_prepare_openorca_tinyllama/train-meta-cot_000000-of-000128_0000000009.bin')]\n"
     ]
    }
   ],
   "source": [
    "# Build the model config from model_name; only config.block_size is used in this cell.\n",
    "config = Config.from_name(model_name)\n",
    "\n",
    "# Batches are built with micro_batch_size; create_dataloaders returns None as its second\n",
    "# value, so val_dataloader is None here.\n",
    "train_dataloader, val_dataloader = create_dataloaders(batch_size=micro_batch_size, block_size=config.block_size)\n",
    "# Debugging alternative with a very small block size:\n",
    "# train_dataloader, val_dataloader = create_dataloaders(batch_size=micro_batch_size, block_size=16)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  382,  1218,   385,  ...,   372,   338,   451],\n",
       "        [ 1950,   304,  2649,  ...,   278, 20051,   322],\n",
       "        [ 2274,   278,  3030,  ...,  4628,   868,   384],\n",
       "        [29892,   727,   526,  ...,  2367,   366,   263]])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: build a fresh iterator over the dataloader and display its first batch.\n",
    "next(iter(train_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  382,  1218,   385,  ...,   372,   338,   451],\n",
      "        [ 1950,   304,  2649,  ...,   278, 20051,   322],\n",
      "        [ 2274,   278,  3030,  ...,  4628,   868,   384],\n",
      "        [29892,   727,   526,  ...,  2367,   366,   263]])\n",
      "tensor([[ 1139, 29889,  3575,  ...,    13, 29896, 29889],\n",
      "        [ 6936,   750,  2211,  ..., 14981,  5232,  9401],\n",
      "        [  304,   916,   564,  ...,    13,  1576,  2441],\n",
      "        [10541,  4083, 29892,  ..., 29941, 29955, 29906]])\n",
      "tensor([[  313, 12692, 29897,  ..., 10090,   470,   825],\n",
      "        [  670,  7609,  1080,  ...,   278,  5995, 14661],\n",
      "        [29889,    13, 29941,  ...,  1950,   393,  1183],\n",
      "        [ 1033,   367,   385,  ...,  1284,  2472, 29889]])\n",
      "tensor([[ 9133,   680,   263,  ..., 29899,   450,  2183],\n",
      "        [ 3229,  8128,   263,  ...,  1213,    13, 29941],\n",
      "        [29889,  3831,   598,  ..., 20255,   393,  6911],\n",
      "        [ 2305,  1284,  2472,  ...,  1348,   937, 29889]])\n",
      "tensor([[ 9133,   680,   263,  ..., 29889,    13,    13],\n",
      "        [29941,   260, 29899,  ...,  4459,   766, 29887],\n",
      "        [  504,   975,  1554,  ...,   694,   901,  1213],\n",
      "        [   13,  3624,   278,  ...,   937,   856,    13]])\n",
      "tensor([[ 6816, 29901,    13,  ...,   338, 29871, 29896],\n",
      "        [29906,  2440,  2030,  ..., 29892, 15226, 29892],\n",
      "        [  322, 13706, 29897,  ...,   714,  1867,   943],\n",
      "        [29889,    13,    13,  ..., 12080,  1549, 26602]])\n",
      "tensor([[  896,  1016, 29915,  ..., 10008,   408,   263],\n",
      "        [ 1513,  1121,   310,  ...,  4441,   552,  1372],\n",
      "        [29892, 25391, 27091,  ...,   911,   278, 20051],\n",
      "        [  297,  8220,   304,  ...,  3271, 29892,  5517]])\n",
      "tensor([[  408,   385,  4218,  ..., 29973,    13,    13],\n",
      "        [14448, 29871, 29896,  ..., 29899,  4874,    13],\n",
      "        [29899,   372,   338,  ...,   920,  1784,  5837],\n",
      "        [  508,   366, 12949,  ...,  2998,   408,   350]])\n",
      "tensor([[  896,  4953, 19710,  ..., 29889,     1,   887],\n",
      "        [  526,   385,   319,  ...,   920,  1784,  6199],\n",
      "        [  372,   674,  2125,  ...,   313, 29875, 29889],\n",
      "        [29872,  1696, 29871,  ...,  1234, 29892,   591]])\n",
      "tensor([[  788,   278,  2323,  ...,   278,  1855, 29885],\n",
      "        [  310,  1855,  4695,  ..., 13173,  1234,   577],\n",
      "        [ 1404,  1016, 30010,  ...,   526,   376, 11316],\n",
      "        [ 8277,  1213,    13,  ...,   338,   304,  1234]])\n",
      "tensor([[  408, 10847,  3730,  ...,  4821,   318,   561],\n",
      "        [  453, 29892,   540,  ...,   277, 29892,   372],\n",
      "        [29915, 29879,  4100,  ..., 29946, 29889,  3831],\n",
      "        [  598,  1716,  9506,  ...,   313, 29933, 29897]])\n",
      "tensor([[ 1781, 13386, 29888,  ...,   263,  9004,   362],\n",
      "        [29892,   470,  3763,  ...,   319, 29902, 20255],\n",
      "        [  393,  6911,  2305,  ...,   338,  2869,   385],\n",
      "        [ 8034, 21083,   322,  ..., 13589,   362,   338]])\n",
      "tensor([[18541,  3814,  4225,  ..., 29899,   974, 29899],\n",
      "        [ 3200,  8802,  2264,  ..., 29906, 29889, 13355],\n",
      "        [ 1598,   278, 20051,  ...,   310,   263,   304],\n",
      "        [29891,   393,   540,  ...,   450,  7601,  8569]])\n",
      "tensor([[  310,   289,   638,  ..., 26616,  6025,  3277],\n",
      "        [  338,  7424,   304,  ..., 29933,  1463,   373],\n",
      "        [  278,  2183, 10541,  ...,  3646,  4086,   338],\n",
      "        [19182,   470,  3033,  ...,  1139,   338,   278]])\n",
      "tensor([[15332,   310,  2305,  ...,     2,     2,     2],\n",
      "        [    2,     2,     2,  ...,     2,     2,     2],\n",
      "        [    2,     2,     2,  ...,     2,     2,     2],\n",
      "        [    2,     2,     2,  ...,     2,     2,     2]])\n"
     ]
    }
   ],
   "source": [
    "tokenizer = Tokenizer(\"/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/llm-pretraining-root/input/models/tinyllama_tokenizer\")\n",
    "# Other tokenizer directories tried while debugging:\n",
    "# tokenizer = Tokenizer(\"/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/llm-pretraining-root/input/models/meta-llama/Llama-2-7b-hf\")\n",
    "# tokenizer = Tokenizer(\"/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-1/llm-pretraining-root/input/models/llama2_tokenizer_no_sp_model\")\n",
    "\n",
    "# Print the first 15 batches of token ids. Pairing the dataloader with range(15) ends the\n",
    "# loop right after the 15th batch without requesting a 16th, and leaves `i` and `batch`\n",
    "# bound to the last batch printed for use in later cells.\n",
    "# To look at text instead of ids, decode each row inside the loop, e.g.:\n",
    "#     for sequence in batch:\n",
    "#         print(tokenizer.decode(sequence))\n",
    "#         print(tokenizer.processor.decode(sequence.tolist(), skip_special_tokens=False))\n",
    "for i, batch in zip(range(15), train_dataloader):\n",
    "    print(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1484"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Count the tokens in `batch` (the last batch printed by the previous cell) whose id is not 2,\n",
    "# the id that fills the tail of that batch in the output above.\n",
    "#\n",
    "# Summing the boolean mask replaces `torch.argwhere(mask).squeeze().shape[0]`, which is wrong\n",
    "# when exactly one element matches: squeeze() collapses the (1, ndim) index tensor to (ndim,),\n",
    "# so shape[0] reports the number of dimensions instead of the number of matches.\n",
    "int((batch != 2).sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'and other ventures . Author Walter Tevis denied for the rest of his life that Wanderone had played any role in the creation of the character . Other players would claim , with greater or lesser degrees of credibility , to have served as models for Fast Eddie , including Ronnie Allen , Ed Taylor , Ed Parker , and Eddie Pelkey . \\n\\n\\n\\n</s></s></s></s>'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "# Hand-built token-id sequence whose tail is four 13s followed by four 2s; it is decoded\n",
    "# below to see how those ids are rendered as text.\n",
    "arr = np.array([  322,   916,  9712,  1973,   869, 13361, 10705,  1920,  1730,\n",
    "       17935,   363,   278,  1791,   310,   670,  2834,   393,   399,\n",
    "        3825,   650,   750,  5318,   738,  6297,   297,   278, 11265,\n",
    "         310,   278,  2931,   869,  5901, 10769,   723,  5995,  1919,\n",
    "         411,  7621,   470,  3109,   261, 14496,   310,  6625,  4127,\n",
    "        1919,   304,   505,  6766,   408,  4733,   363, 23786, 20861,\n",
    "         347,  1919,  3704, 11546,  2786, 16092,  1919,  2155, 12537,\n",
    "        1919,  2155, 24239,  1919,   322, 20861,   347, 15549,  1989,\n",
    "         869, 29871, 13,13,13,13, 2,2,2,2,])\n",
    "# NOTE(review): the commented-out tail below looks like an earlier ending of the sequence\n",
    "# with a single 13 and a single 2 - confirm.\n",
    "        #  869, 29871,    13,     2])\n",
    "\n",
    "# Decode through tokenizer.processor with skip_special_tokens=False so that special tokens\n",
    "# remain visible in the resulting string (the wrapper call below is the alternative).\n",
    "# NOTE(review): tokenizer.processor is presumably the backend tokenizer wrapped by litgpt's\n",
    "# Tokenizer - confirm.\n",
    "# tokenizer.decode(arr)\n",
    "tokenizer.processor.decode(arr.tolist(), skip_special_tokens=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
