{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "import transformers\n",
    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "\n",
    "import datasets\n",
    "import wandb\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 24\n",
    "MAX_LENGTH = 128\n",
    "\n",
    "data = datasets.load_dataset(\"c4\", \"en\", split=\"train\", streaming=True)\n",
    "data = data.shuffle(seed=42)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\")\n",
    "\n",
    "def preprocess_batched(batch):\n",
    "    batch = tokenizer(\n",
    "        batch[\"text\"],\n",
    "        max_length=MAX_LENGTH,\n",
    "        truncation=True,\n",
    "        padding=\"max_length\",\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "    return batch\n",
    "\n",
    "data_mapped = data.map(preprocess_batched, batched=True, batch_size=1000, remove_columns=[\"text\", \"timestamp\", \"url\"])\n",
    "\n",
    "def collate_fn(batch_list):\n",
    "    batch = {\n",
    "        \"input_ids\": torch.stack([example[\"input_ids\"] for example in batch_list]),\n",
    "        \"attention_mask\": torch.stack([example[\"attention_mask\"] for example in batch_list]),\n",
    "    }\n",
    "    return batch\n",
    "\n",
    "def batch_fn(dataset, batch_size):\n",
    "    batch = []\n",
    "    for example in dataset:\n",
    "        batch.append(example)\n",
    "        if len(batch) == batch_size:\n",
    "            batch = collate_fn(batch)\n",
    "            yield batch\n",
    "            batch = []\n",
    "    if len(batch) > 0:\n",
    "        yield batch\n",
    "\n",
    "data_mapped.batch = lambda batch_size: batch_fn(data_mapped, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "USE_PEFT = True\n",
    "TRAIN_LN = True\n",
    "NUM_TRAINING_STEPS = 10_000\n",
    "\n",
    "device = \"cuda:1\"\n",
    "\n",
    "model_config = AutoConfig.from_pretrained(\"gpt2-large\")\n",
    "model = AutoModelForCausalLM.from_config(model_config)\n",
    "\n",
    "if USE_PEFT:\n",
    "    peft_config = LoraConfig(\n",
    "        task_type=TaskType.CAUSAL_LM,\n",
    "        inference_mode=False,\n",
    "        r=8,\n",
    "        lora_alpha=32,\n",
    "        lora_dropout=0.1,\n",
    "    )\n",
    "\n",
    "    model = get_peft_model(peft_config, model)\n",
    "\n",
    "    for name, param in model.named_parameters():\n",
    "        if TRAIN_LN and \"ln_\" in name:\n",
    "            param.requires_grad = True\n",
    "        if \"lm_head\" in name:\n",
    "            param.requires_grad = True\n",
    "        if \"transformer.wte\" in name:\n",
    "            param.requires_grad = True\n",
    "        if \"transformer.wpe\" in name:\n",
    "            param.requires_grad = True\n",
    "\n",
    "    model.print_trainable_parameters()\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "n_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "p_trainable_params = n_trainable_params / n_total_params\n",
    "\n",
    "trainable_params = (p for p in model.parameters() if p.requires_grad)\n",
    "trainable_params_names = [name for name, p in model.named_parameters() if p.requires_grad]\n",
    "\n",
    "optimizer = torch.optim.Adam(trainable_params, lr=1e-4)\n",
    "scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1_000, num_training_steps=NUM_TRAINING_STEPS)\n",
    "\n",
    "_config = {\n",
    "    \"using_peft\": USE_PEFT,\n",
    "    \"layer_norm_trainable\": TRAIN_LN,\n",
    "    \"peft_config\": peft_config.to_dict(),\n",
    "    \"total_params\": n_total_params,\n",
    "    \"trainable_params\": n_trainable_params,\n",
    "    \"percent_trainable_params\": p_trainable_params,\n",
    "    \"name_trainable_params\": trainable_params_names,\n",
    "    \"dataset\": \"c4\",\n",
    "    \"batch_size\": BATCH_SIZE,\n",
    "    \"max_length\": MAX_LENGTH,\n",
    "    \"model\": model_config.to_dict(),\n",
    "    \"scheduler\": \"linear\",\n",
    "    \"device\": str(device),\n",
    "}\n",
    "\n",
    "wandb.init(project=\"peft_pretraining\", config=_config)\n",
    "pbar = tqdm(total=NUM_TRAINING_STEPS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.base_model.transformer.wte.weight.requires_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(1):\n",
    "    data_mapped.set_epoch(epoch)\n",
    "    for batch in data_mapped.batch(batch_size=BATCH_SIZE):\n",
    "        pbar.update(1)\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        labels = batch[\"input_ids\"].clone()\n",
    "        labels[labels == 0] = -100\n",
    "\n",
    "        loss = model(**batch, labels=labels).loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        lr = scheduler.get_last_lr()[0]\n",
    "        wandb.log({\n",
    "            \"loss\": loss.item(),\n",
    "            \"lr\": lr,\n",
    "        })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
