{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "from peft import get_peft_config, get_peft_model, PrefixTuningConfig, TaskType, PeftType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"bigscience/bloomz-560m\"\n",
    "tokenizer_name_or_path = \"bigscience/bloomz-560m\"\n",
    "peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=30)\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "checkpoint_name = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "text_column = \"Tweet text\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 64\n",
    "lr = 3e-2\n",
    "num_epochs = 50\n",
    "batch_size = 8"
   ],
   "id": "23e06ed5224da939"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
    "\n",
    "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
    "print(classes)\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "print(dataset)\n",
    "dataset[\"train\"][0]"
   ],
   "id": "ed3b6676a4fd67e1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "print(target_max_length)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    targets = [str(x) for x in examples[label_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    labels = tokenizer(targets, add_special_tokens=False)  # don't add bos token because we concatenate with inputs\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
    "        # print(i, sample_input_ids, label_input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
    "        labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "                max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "        labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"train\"]\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ],
   "id": "8365324dee8258f8"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def test_preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "                max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "test_dataset = dataset[\"test\"].map(\n",
    "    test_preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
    "next(iter(test_dataloader))"
   ],
   "id": "4ccd1f0b08d121dd"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "next(iter(train_dataloader))",
   "id": "43ea984dd312e019"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "len(test_dataloader)",
   "id": "b6602b60afaeae0e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "next(iter(test_dataloader))",
   "id": "9145725d546c6b2c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# creating model\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()"
   ],
   "id": "ea5f754cb3288340"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model.print_trainable_parameters()",
   "id": "c79b478206d46975"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model",
   "id": "eac2c366153fe308"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "model.peft_config",
   "id": "12f36913a8a22d9c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# model\n",
    "# optimizer and lr scheduler\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ],
   "id": "b127ce8020c950ff"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# training and evaluation\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        #         print(batch)\n",
    "        #         print(batch[\"input_ids\"].shape)\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        eval_loss += loss.detach().float()\n",
    "        eval_preds.extend(\n",
    "            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
    "        )\n",
    "\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ],
   "id": "73718718f99c2959"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model.eval()\n",
    "i = 16\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ],
   "id": "826da56ef2654117"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "You can push model to hub or save model locally. \n",
    "\n",
    "- Option1: Pushing the model to Hugging Face Hub\n",
    "```python\n",
    "model.push_to_hub(\n",
    "    f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\"),\n",
    "    token = \"hf_...\"\n",
    ")\n",
    "```\n",
    "token (`bool` or `str`, *optional*):\n",
    "    `token` is to be used for HTTP Bearer authorization when accessing remote files. If `True`, will use the token generated\n",
    "    when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`\n",
    "    is not specified.\n",
    "    Or you can get your token from https://huggingface.co/settings/token\n",
    "```\n",
    "- Or save model locally\n",
    "```python\n",
    "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\")\n",
    "model.save_pretrained(peft_model_id)\n",
    "```"
   ],
   "id": "62d8481d25ed2131"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# saving model\n",
    "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "model.save_pretrained(peft_model_id)"
   ],
   "id": "c4f465de23806fe2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "ckpt = f\"{peft_model_id}/adapter_model.bin\"\n",
    "!du -h $ckpt"
   ],
   "id": "83273162deabaec"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ],
   "id": "4e59e8df65af7cb9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model.to(device)\n",
    "model.eval()\n",
    "i = 4\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ],
   "id": "d815f4bcf26f0a9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "442be70fa2a06de0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "af0079678d153a68"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "a8e3404fbc0058d1"
  }
 ],
 "metadata": {},
 "nbformat": 5,
 "nbformat_minor": 9
}
