{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f93b7d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
      "================================================================================\n",
      "CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
      "CUDA SETUP: Detected CUDA version 117\n",
      "CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM\n",
    "from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "from transformers import AutoTokenizer\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"bigscience/mt0-large\"\n",
    "tokenizer_name_or_path = \"bigscience/mt0-large\"\n",
    "\n",
    "checkpoint_name = \"financial_sentiment_analysis_lora_v1.pt\"\n",
    "text_column = \"sentence\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 128\n",
    "lr = 1e-3\n",
    "num_epochs = 3\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d0850ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# creating model\n",
    "peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)\n",
    "\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4ee2babf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset financial_phrasebank (/home/sourab/.cache/huggingface/datasets/financial_phrasebank/sentences_allagree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3403bf3d718042018b0531848cc30209",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3d5c45e3776469f9560b6eaa9346f8f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e9736f26e9aa450b8d65f95c0b9c81cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'sentence': \"The 10,000-odd square metre plot that Stockmann has bought for the Nevsky Center shopping center is located on Nevsky Prospect , St Petersburg 's high street , next to the Vosstaniya Square underground station , in the immediate vicinity of Moscow Station .\",\n",
       " 'label': 1,\n",
       " 'text_label': 'neutral'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# loading dataset\n",
    "dataset = load_dataset(\"financial_phrasebank\", \"sentences_allagree\")\n",
    "dataset = dataset[\"train\"].train_test_split(test_size=0.1)\n",
    "dataset[\"validation\"] = dataset[\"test\"]\n",
    "del dataset[\"test\"]\n",
    "\n",
    "classes = dataset[\"train\"].features[\"label\"].names\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "\n",
    "dataset[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "adf9608c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c460989d4ab24e3f97d81ef040b1d1b4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/3 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1acc389b08b94f8a87900b9fbdbccce4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    inputs = examples[text_column]\n",
    "    targets = examples[label_column]\n",
    "    model_inputs = tokenizer(inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "    labels = tokenizer(targets, max_length=3, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "    labels = labels[\"input_ids\"]\n",
    "    labels[labels == tokenizer.pad_token_id] = -100\n",
    "    model_inputs[\"labels\"] = labels\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"validation\"]\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f733a3c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optimizer and lr scheduler\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6b3a4090",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 255/255 [02:21<00:00,  1.81it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:07<00:00,  4.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0: train_ppl=tensor(14.6341, device='cuda:0') train_epoch_loss=tensor(2.6834, device='cuda:0') eval_ppl=tensor(1.0057, device='cuda:0') eval_epoch_loss=tensor(0.0057, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 255/255 [02:00<00:00,  2.11it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:05<00:00,  5.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=1: train_ppl=tensor(1.7576, device='cuda:0') train_epoch_loss=tensor(0.5640, device='cuda:0') eval_ppl=tensor(1.0052, device='cuda:0') eval_epoch_loss=tensor(0.0052, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 255/255 [01:33<00:00,  2.74it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:04<00:00,  6.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=2: train_ppl=tensor(1.3830, device='cuda:0') train_epoch_loss=tensor(0.3243, device='cuda:0') eval_ppl=tensor(1.0035, device='cuda:0') eval_epoch_loss=tensor(0.0035, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# training and evaluation\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        eval_loss += loss.detach().float()\n",
    "        eval_preds.extend(\n",
    "            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
    "        )\n",
    "\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6cafa67b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy=97.3568281938326 % on the evaluation dataset\n",
      "eval_preds[:10]=['neutral', 'neutral', 'neutral', 'positive', 'neutral', 'positive', 'positive', 'neutral', 'neutral', 'neutral']\n",
      "dataset['validation']['text_label'][:10]=['neutral', 'neutral', 'neutral', 'positive', 'neutral', 'positive', 'positive', 'neutral', 'neutral', 'neutral']\n"
     ]
    }
   ],
   "source": [
    "# print accuracy\n",
    "correct = 0\n",
    "total = 0\n",
    "for pred, true in zip(eval_preds, dataset[\"validation\"][\"text_label\"]):\n",
    "    if pred.strip() == true.strip():\n",
    "        correct += 1\n",
    "    total += 1\n",
    "accuracy = correct / total * 100\n",
    "print(f\"{accuracy=} % on the evaluation dataset\")\n",
    "print(f\"{eval_preds[:10]=}\")\n",
    "print(f\"{dataset['validation']['text_label'][:10]=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a8de6005",
   "metadata": {},
   "outputs": [],
   "source": [
    "# saving model\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "model.save_pretrained(peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bd20cd4c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9,2M\tbigscience/mt0-large_LORA_SEQ_2_SEQ_LM/adapter_model.bin\r\n"
     ]
    }
   ],
   "source": [
    "ckpt = f\"{peft_model_id}/adapter_model.bin\"\n",
    "!du -h $ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "76c2fc29",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "37d712ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "- Demand for fireplace products was lower than expected , especially in Germany .\n",
      "{'input_ids': tensor([[  259,   264,   259, 82903,   332,  1090, 10040, 10371,   639,   259,\n",
      "         19540,  2421,   259, 25505,   259,   261,   259, 21230,   281, 17052,\n",
      "           259,   260,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[    0,   259, 32588,     1]])\n",
      "['negative']\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "i = 13\n",
    "inputs = tokenizer(dataset[\"validation\"][text_column][i], return_tensors=\"pt\")\n",
    "print(dataset[\"validation\"][text_column][i])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(input_ids=inputs[\"input_ids\"], max_new_tokens=10)\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66c65ea4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e71f78",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
