{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
    "from peft import get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Disable parallelism inside the Hugging Face `tokenizers` library.\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "# Use the GPU when one is visible; otherwise fall back to CPU so the notebook still runs.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model_name_or_path = \"t5-large\"\n",
    "tokenizer_name_or_path = \"t5-large\"\n",
    "\n",
    "checkpoint_name = \"financial_sentiment_analysis_prompt_tuning_v1.pt\"\n",
    "text_column = \"sentence\"  # dataset column holding the input sentences\n",
    "label_column = \"text_label\"  # dataset column holding the label as a string\n",
    "max_length = 128  # inputs are padded/truncated to this many tokens\n",
    "lr = 1\n",
    "num_epochs = 5\n",
    "batch_size = 8"
   ],
   "id": "185c4eb438364c5f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# creating model\n",
    "# Prompt-tuning configuration: 20 virtual tokens, initialised from the prompt text below.\n",
    "peft_config = PromptTuningConfig(\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    prompt_tuning_init=PromptTuningInit.TEXT,\n",
    "    num_virtual_tokens=20,\n",
    "    prompt_tuning_init_text=\"What is the sentiment of this article?\\n\",\n",
    "    inference_mode=False,\n",
    "    # Use the dedicated tokenizer setting instead of re-using the model path.\n",
    "    tokenizer_name_or_path=tokenizer_name_or_path,\n",
    ")\n",
    "\n",
    "# Load the pretrained seq2seq model and wrap it with the prompt-tuning adapter.\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()\n",
    "model"
   ],
   "id": "9c410e426c4883f2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# loading dataset\n",
    "dataset = load_dataset(\"financial_phrasebank\", \"sentences_allagree\")\n",
    "# Seed the 90/10 split so the validation set (and the metrics computed on it) is the same on every run.\n",
    "dataset = dataset[\"train\"].train_test_split(test_size=0.1, seed=42)\n",
    "# Expose the held-out split under the `validation` key instead of `test`.\n",
    "dataset[\"validation\"] = dataset.pop(\"test\")\n",
    "\n",
    "# Add a `text_label` column containing the class name for each integer label.\n",
    "classes = dataset[\"train\"].features[\"label\"].names\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "\n",
    "dataset[\"train\"][0]"
   ],
   "id": "aae1219f62096ee"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)\n",
    "# Targets are padded/truncated to the token length of the longest class name.\n",
    "target_max_length = max(len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch of sentences and their text labels for seq2seq training.\n",
    "\n",
    "    Args:\n",
    "        examples: Batch mapping that provides the `text_column` and `label_column` entries.\n",
    "\n",
    "    Returns:\n",
    "        The tokenized inputs plus a `labels` entry in which every padding token id\n",
    "        is replaced by -100.\n",
    "    \"\"\"\n",
    "    inputs = examples[text_column]\n",
    "    targets = examples[label_column]\n",
    "    model_inputs = tokenizer(inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "    labels = tokenizer(\n",
    "        targets, max_length=target_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n",
    "    )\n",
    "    labels = labels[\"input_ids\"]\n",
    "    # -100 is the default `ignore_index` of PyTorch's cross-entropy loss, so padded target positions are skipped.\n",
    "    labels[labels == tokenizer.pad_token_id] = -100\n",
    "    model_inputs[\"labels\"] = labels\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "# Tokenize every split, keeping only the columns produced by `preprocess_function`.\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"validation\"]\n",
    "\n",
    "# Only the training loader shuffles; the evaluation loader keeps dataset order.\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ],
   "id": "ac6326b50a5bce0f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# optimizer and lr scheduler\n",
    "# The linear schedule covers every batch of every epoch, with no warmup.\n",
    "total_training_steps = len(train_dataloader) * num_epochs\n",
    "\n",
    "optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=total_training_steps,\n",
    ")"
   ],
   "id": "b1717837587fa6a1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# training and evaluation\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Training pass: one optimizer step and one scheduler step per batch.\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        batch = {name: tensor.to(device) for name, tensor in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    # Evaluation pass: gradients are disabled for the forward passes.\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(eval_dataloader):\n",
    "            batch = {name: tensor.to(device) for name, tensor in batch.items()}\n",
    "            outputs = model(**batch)\n",
    "            loss = outputs.loss\n",
    "            eval_loss += loss.detach().float()\n",
    "            # Take the argmax over the last axis of the logits and decode the ids back to text.\n",
    "            predicted_ids = torch.argmax(outputs.logits, -1).detach().cpu().numpy()\n",
    "            eval_preds.extend(tokenizer.batch_decode(predicted_ids, skip_special_tokens=True))\n",
    "\n",
    "    # Mean loss per batch; perplexity is its exponential.\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ],
   "id": "1850aa016561ae0d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# print accuracy\n",
    "# Pair each decoded prediction with its reference label and count exact matches (ignoring surrounding whitespace).\n",
    "label_pairs = list(zip(eval_preds, dataset[\"validation\"][\"text_label\"]))\n",
    "total = len(label_pairs)\n",
    "correct = sum(1 for pred, true in label_pairs if pred.strip() == true.strip())\n",
    "accuracy = correct / total * 100\n",
    "print(f\"{accuracy=} % on the evaluation dataset\")\n",
    "print(f\"{eval_preds[:10]=}\")\n",
    "print(f\"{dataset['validation']['text_label'][:10]=}\")"
   ],
   "id": "d19874781da6b284"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# saving model\n",
    "# The output directory name combines the base model name with the PEFT type and task type.\n",
    "# NOTE(review): `peft_type` / `task_type` are presumably enum members; how they render inside an\n",
    "# f-string may differ between Python versions, so confirm the resulting directory name on the target interpreter.\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "model.save_pretrained(peft_model_id)"
   ],
   "id": "3992610302b6fa84"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Report the on-disk size of the saved adapter weights.\n",
    "# Newer PEFT releases write `adapter_model.safetensors` by default while older ones wrote\n",
    "# `adapter_model.bin`, so use whichever file exists (falling back to the legacy name).\n",
    "ckpt_candidates = [f\"{peft_model_id}/adapter_model.safetensors\", f\"{peft_model_id}/adapter_model.bin\"]\n",
    "ckpt = next((path for path in ckpt_candidates if os.path.exists(path)), ckpt_candidates[-1])\n",
    "!du -h $ckpt"
   ],
   "id": "4d36d9f1750bcb17"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import PeftConfig, PeftModel\n",
    "\n",
    "# Same directory name that was used when the adapter was saved.\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "\n",
    "# Read the saved adapter config to find the base model, load that model, then attach the adapter to it.\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "base_model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(base_model, peft_model_id)"
   ],
   "id": "6d71f7eaedda5697"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model.eval()\n",
    "i = 107\n",
    "# Look the example sentence up once and reuse it for tokenization and display.\n",
    "sentence = dataset[\"validation\"][text_column][i]\n",
    "input_ids = tokenizer(sentence, return_tensors=\"pt\").input_ids\n",
    "print(sentence)\n",
    "print(input_ids)\n",
    "\n",
    "# Generate at most 10 new tokens without tracking gradients, then decode them to text.\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(input_ids=input_ids, max_new_tokens=10)\n",
    "print(outputs)\n",
    "decoded = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)\n",
    "print(decoded)"
   ],
   "id": "c31873b295e18d73"
  }
 ],
 "metadata": {},
 "nbformat": 5,
 "nbformat_minor": 9
}
