{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/XXXX-36/manlis/anaconda3/envs/axonn/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import sys\n",
    "import time\n",
    "from pathlib import Path\n",
    "from typing import Optional, Tuple, Union\n",
    "import os\n",
    "\n",
    "import lightning as L\n",
    "import torch\n",
    "from lightning.fabric.loggers import CSVLogger\n",
    "from lightning.fabric.strategies import FSDPStrategy\n",
    "from lightning.fabric.utilities import ThroughputMonitor, measure_flops\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# support running without installing as a package\n",
    "wd = Path(os.path.dirname(os.path.abspath(\"__file__\"))).parent.resolve()\n",
    "sys.path.append(str(wd))\n",
    "\n",
    "from litgpt.model import GPT, Block, Config\n",
    "from litgpt.packed_dataset import CombinedDataset, PackedDataset\n",
    "from litgpt.utils import chunked_cross_entropy, estimate_flops, get_default_supported_precision, num_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"Llama-2-7b-hf\"\n",
    "name = \"redpajama\"\n",
    "out_dir = Path(\"out\") / name\n",
    "save_interval = 1000\n",
    "eval_interval = 1000\n",
    "eval_iters = 100\n",
    "log_interval = 1\n",
    "\n",
    "# Hyperparameters\n",
    "learning_rate = 6e-4\n",
    "batch_size = 125\n",
    "micro_batch_size = 6\n",
    "gradient_accumulation_steps = batch_size // micro_batch_size\n",
    "assert gradient_accumulation_steps > 0\n",
    "max_steps = 600000  # num_epochs * (epoch_size // micro_batch_size) // devices\n",
    "weight_decay = 1e-1\n",
    "beta1 = 0.9\n",
    "beta2 = 0.95\n",
    "grad_clip = 1.0\n",
    "decay_lr = True\n",
    "warmup_iters = 2000\n",
    "lr_decay_iters = max_steps\n",
    "min_lr = 6e-5\n",
    "\n",
    "\n",
    "# Data proportions from https://arxiv.org/pdf/2302.13971.pdf Table 1\n",
    "data_config = [\n",
    "    (\"arxiv\", 2.5),\n",
    "    (\"book\", 4.5),\n",
    "    (\"c4\", 15.0),\n",
    "    (\"cc\", 67.0),\n",
    "    (\"github\", 4.5),\n",
    "    (\"stackexchange\", 2.0),\n",
    "    (\"wikipedia\", 4.5),\n",
    "]\n",
    "\n",
    "hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith(\"_\")}\n",
    "logger = CSVLogger(\"out\", name, flush_logs_every_n_steps=log_interval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataloader(\n",
    "    batch_size: int, block_size: int, data_dir: Path, shuffle: bool = True, seed: int = 12345\n",
    ") -> DataLoader:\n",
    "    datasets = []\n",
    "    for prefix, _ in data_config:\n",
    "        filenames = list(data_dir.glob(f\"{prefix}*\"))\n",
    "        if not filenames:\n",
    "            raise FileNotFoundError(\n",
    "                f\"No files found at {str(data_dir)} with prefix {prefix}. Did you forget to run `prepare_redpajama.py`?\"\n",
    "            )\n",
    "        dataset = PackedDataset(\n",
    "            filenames,\n",
    "            n_chunks=4,\n",
    "            block_size=block_size,\n",
    "            shuffle=shuffle,\n",
    "            seed=seed,\n",
    "            num_processes=1,\n",
    "            process_rank=0,\n",
    "        )\n",
    "        datasets.append(dataset)\n",
    "\n",
    "    if not datasets:\n",
    "        raise RuntimeError(\n",
    "            f\"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset.\"\n",
    "        )\n",
    "\n",
    "    weights = [weight for _, weight in data_config]\n",
    "    sum_weights = sum(weights)\n",
    "    weights = [el / sum_weights for el in weights]\n",
    "\n",
    "    combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights)\n",
    "\n",
    "    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = Config.from_name(model_name)\n",
    "\n",
    "train_dataloader = create_dataloader(\n",
    "        batch_size=batch_size,\n",
    "        block_size=config.block_size+1,\n",
    "        data_dir=Path('/fs/XXXX-38-scratch/manlis/lit-gpt/data/lit-redpajama-sample'),\n",
    "        shuffle=True,\n",
    "        seed=42,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ltor_masks_and_position_ids(data,\n",
    "                                    eod_token,\n",
    "                                    reset_position_ids,\n",
    "                                    reset_attention_mask,\n",
    "                                    eod_mask_loss):\n",
    "    \"\"\"Build masks and position id for left to right model.\"\"\"\n",
    "\n",
    "    # Extract batch size and sequence length.\n",
    "    micro_batch_size, seq_length = data.size()\n",
    "\n",
    "    # Attention mask (lower triangular).\n",
    "    if reset_attention_mask:\n",
    "        att_mask_batch = micro_batch_size\n",
    "    else:\n",
    "        att_mask_batch = 1\n",
    "    attention_mask = torch.tril(torch.ones(\n",
    "        (att_mask_batch, seq_length, seq_length), device=data.device)).view(\n",
    "            att_mask_batch, 1, seq_length, seq_length)\n",
    "\n",
    "    # Loss mask.\n",
    "    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)\n",
    "    if eod_mask_loss:\n",
    "        loss_mask[data == eod_token] = 0.0\n",
    "\n",
    "    # Position ids.\n",
    "    position_ids = torch.arange(seq_length, dtype=torch.long,\n",
    "                                device=data.device)\n",
    "    position_ids = position_ids.unsqueeze(0).expand_as(data)\n",
    "    # We need to clone as the ids will be modifed based on batch index.\n",
    "    if reset_position_ids:\n",
    "        position_ids = position_ids.clone()\n",
    "\n",
    "    if reset_position_ids or reset_attention_mask:\n",
    "        # Loop through the batches:\n",
    "        for b in range(micro_batch_size):\n",
    "\n",
    "            # Find indecies where EOD token is.\n",
    "            eod_index = position_ids[b, data[b] == eod_token]\n",
    "            # Detach indecies from positions if going to modify positions.\n",
    "            if reset_position_ids:\n",
    "                eod_index = eod_index.clone()\n",
    "\n",
    "            # Loop through EOD indecies:\n",
    "            prev_index = 0\n",
    "            for j in range(eod_index.size()[0]):\n",
    "                i = eod_index[j]\n",
    "                # Mask attention loss.\n",
    "                if reset_attention_mask:\n",
    "                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0\n",
    "                # Reset positions.\n",
    "                if reset_position_ids:\n",
    "                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)\n",
    "                    prev_index = i + 1\n",
    "\n",
    "    # Convert attention mask to binary:\n",
    "    attention_mask = (attention_mask > 0.5)\n",
    "\n",
    "    return attention_mask, loss_mask, position_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([125, 4097])\n",
      "torch.Size([125, 1, 4097, 4097])\n",
      "tensor([[[ True, False, False,  ..., False, False, False],\n",
      "         [ True,  True, False,  ..., False, False, False],\n",
      "         [ True,  True,  True,  ..., False, False, False],\n",
      "         ...,\n",
      "         [False, False, False,  ...,  True, False, False],\n",
      "         [False, False, False,  ...,  True,  True, False],\n",
      "         [False, False, False,  ...,  True,  True,  True]]])\n"
     ]
    }
   ],
   "source": [
    "for batch in train_dataloader:\n",
    "    print(batch.shape)\n",
    "    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(\n",
    "        batch,\n",
    "        1,\n",
    "        True,\n",
    "        True,\n",
    "        True)\n",
    "    print(attention_mask.shape)\n",
    "    print(attention_mask[3])\n",
    "    break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "axonn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
