{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "from peft import get_peft_config, get_peft_model, LNTuningConfig, TaskType, PeftType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Hyper-parameters\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"bigscience/bloomz-560m\"\n",
    "tokenizer_name_or_path = \"bigscience/bloomz-560m\"\n",
    "peft_config = LNTuningConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    ")\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "checkpoint_name = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "text_column = \"Tweet text\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 64\n",
    "lr = 5e-2\n",
    "num_epochs = 50\n",
    "batch_size = 8"
   ],
   "id": "ed72410870ca5906"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Load and Process Dataset for LM Training",
   "id": "80b0135f8ede6151"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
    "\n",
    "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
    "print(classes)\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "print(dataset)\n",
    "dataset[\"train\"][0]"
   ],
   "id": "2169f0efc0735728"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "print(target_max_length)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    targets = [str(x) for x in examples[label_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    labels = tokenizer(targets, add_special_tokens=False)  # don't add bos token because we concatenate with inputs\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
    "        # print(i, sample_input_ids, label_input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
    "        labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "                max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "        labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"train\"]\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ],
   "id": "43c7ff2c0e528a05"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def test_preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "                max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "test_dataset = dataset[\"test\"].map(\n",
    "    test_preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
    "next(iter(test_dataloader))"
   ],
   "id": "2a44a74ad7cc3fd0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# show dataset size\n",
    "len(test_dataloader)"
   ],
   "id": "6c515848187a1714"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Train the LM with LNTuning\n",
    "1. Create the base LM.\n",
    "2. Only activate the LayerNorm layers in the LM for training.\n",
    "3. Train the LM on the training dataset."
   ],
   "id": "4fc05d29cb43ac0f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# 1. creating the base LM\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
    "# 2. Only activate the LayerNorm layers in the Attention blocks in the LM for training\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()"
   ],
   "id": "8e9b5f7f26236d66"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# setup the optimizer and lr scheduler\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ],
   "id": "d0cb00f178999738"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# 3. train the LM on the training dataset\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        #         print(batch)\n",
    "        #         print(batch[\"input_ids\"].shape)\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        eval_loss += loss.detach().float()\n",
    "        eval_preds.extend(\n",
    "            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
    "        )\n",
    "\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ],
   "id": "b678d298aec4d753"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Test the LM",
   "id": "39209a67fa821fa1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model.eval()\n",
    "i = 33\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ],
   "id": "bdccc801e0dc7c5a"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Save the trainable LM weights (LayerNorm layers)\n",
    "You can push model to hub or save model locally. \n",
    "\n",
    "- Option1: Push the model to Hugging Face Hub:\n",
    "\n",
    "    ```python\n",
    "    model.push_to_hub(\n",
    "        f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\"),\n",
    "        token = \"hf_...\"\n",
    "    )\n",
    "    ```\n",
    "    token (`bool` or `str`, *optional*):\n",
    "        `token` is to be used for HTTP Bearer authorization when accessing remote files. If `True`, will use the token generated\n",
    "        when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`\n",
    "        is not specified.\n",
    "        Or you can get your token from https://huggingface.co/settings/token\n",
    "    ```\n",
    "- Option2: Save model locally:\n",
    "\n",
    "    ```python\n",
    "    peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\")\n",
    "    model.save_pretrained(peft_model_id)\n",
    "    ```"
   ],
   "id": "467eff428b6a6de0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# saving model\n",
    "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "model.save_pretrained(peft_model_id)"
   ],
   "id": "25028b3702db8f7"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Test the LM using LNTuning loaded from saved weights\n",
    "1. load the LNTuning configuration\n",
    "2. load the base LM\n",
    "3. merge the LNTuning weights into the base LM using the PEFT config"
   ],
   "id": "7ec340cc987a8adf"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "# load the LNTuning config\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "# load the base LM\n",
    "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
    "# merge LNTuning weights into the base LM\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ],
   "id": "b33be9fe75d345b2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model.to(device)\n",
    "model.eval()\n",
    "i = 4\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ],
   "id": "b8c226998192780c"
  }
 ],
 "metadata": {},
 "nbformat": 5,
 "nbformat_minor": 9
}
