{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f93b7d1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:37:58.711225Z",
     "start_time": "2023-05-30T08:37:56.881307Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please run\n",
      "\n",
      "python -m bitsandbytes\n",
      "\n",
      " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "================================================================================\n",
      "bin /udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so\n",
      "CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...\n",
      "CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so.11.0\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 8.0\n",
      "CUDA SETUP: Detected CUDA version 117\n",
      "CUDA SETUP: Loading binary /udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: /udir/tschilla/anaconda3 did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! Searching further paths...\n",
      "  warn(msg)\n",
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('Europe/Paris')}\n",
      "  warn(msg)\n",
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/udir/tschilla/.cache/dotnet_bundle_extract')}\n",
      "  warn(msg)\n",
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('5002'), PosixPath('http'), PosixPath('//127.0.0.1')}\n",
      "  warn(msg)\n",
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('() {  ( alias;\\n eval ${which_declare} ) | /usr/bin/which --tty-only --read-alias --read-functions --show-tilde --show-dot $@\\n}')}\n",
      "  warn(msg)\n",
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('module'), PosixPath('//matplotlib_inline.backend_inline')}\n",
      "  warn(msg)\n",
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: Found duplicate ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] files: {PosixPath('/usr/local/cuda/lib64/libcudart.so.11.0'), PosixPath('/usr/local/cuda/lib64/libcudart.so')}.. We'll flip a coin and try one of these, in order to fail forward.\n",
      "Either way, this might cause trouble in the future:\n",
      "If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.\n",
      "  warn(msg)\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
    "from peft import  get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"t5-large\"\n",
    "tokenizer_name_or_path = \"t5-large\"\n",
    "\n",
    "checkpoint_name = \"financial_sentiment_analysis_prompt_tuning_v1.pt\"\n",
    "text_column = \"sentence\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 128\n",
    "lr = 1\n",
    "num_epochs = 5\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8d0850ac",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:38:12.413984Z",
     "start_time": "2023-05-30T08:38:04.601042Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 40960 || all params: 737709056 || trainable%: 0.005552324411210698\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/udir/tschilla/anaconda3/envs/peft/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
      "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
      "- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.\n",
      "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
      "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "PeftModelForSeq2SeqLM(\n",
       "  (base_model): T5ForConditionalGeneration(\n",
       "    (shared): Embedding(32128, 1024)\n",
       "    (encoder): T5Stack(\n",
       "      (embed_tokens): Embedding(32128, 1024)\n",
       "      (block): ModuleList(\n",
       "        (0): T5Block(\n",
       "          (layer): ModuleList(\n",
       "            (0): T5LayerSelfAttention(\n",
       "              (SelfAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (relative_attention_bias): Embedding(32, 16)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (1): T5LayerFF(\n",
       "              (DenseReluDense): T5DenseActDense(\n",
       "                (wi): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "                (wo): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "                (act): ReLU()\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (1-23): 23 x T5Block(\n",
       "          (layer): ModuleList(\n",
       "            (0): T5LayerSelfAttention(\n",
       "              (SelfAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (1): T5LayerFF(\n",
       "              (DenseReluDense): T5DenseActDense(\n",
       "                (wi): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "                (wo): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "                (act): ReLU()\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (final_layer_norm): T5LayerNorm()\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (decoder): T5Stack(\n",
       "      (embed_tokens): Embedding(32128, 1024)\n",
       "      (block): ModuleList(\n",
       "        (0): T5Block(\n",
       "          (layer): ModuleList(\n",
       "            (0): T5LayerSelfAttention(\n",
       "              (SelfAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (relative_attention_bias): Embedding(32, 16)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (1): T5LayerCrossAttention(\n",
       "              (EncDecAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (2): T5LayerFF(\n",
       "              (DenseReluDense): T5DenseActDense(\n",
       "                (wi): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "                (wo): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "                (act): ReLU()\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (1-23): 23 x T5Block(\n",
       "          (layer): ModuleList(\n",
       "            (0): T5LayerSelfAttention(\n",
       "              (SelfAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (1): T5LayerCrossAttention(\n",
       "              (EncDecAttention): T5Attention(\n",
       "                (q): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (k): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (v): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "                (o): Linear(in_features=1024, out_features=1024, bias=False)\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (2): T5LayerFF(\n",
       "              (DenseReluDense): T5DenseActDense(\n",
       "                (wi): Linear(in_features=1024, out_features=4096, bias=False)\n",
       "                (wo): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "                (act): ReLU()\n",
       "              )\n",
       "              (layer_norm): T5LayerNorm()\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (final_layer_norm): T5LayerNorm()\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (lm_head): Linear(in_features=1024, out_features=32128, bias=False)\n",
       "  )\n",
       "  (prompt_encoder): ModuleDict(\n",
       "    (default): PromptEmbedding(\n",
       "      (embedding): Embedding(40, 1024)\n",
       "    )\n",
       "  )\n",
       "  (word_embeddings): Embedding(32128, 1024)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# creating model\n",
    "peft_config = PromptTuningConfig(\n",
    "    task_type=TaskType.SEQ_2_SEQ_LM,\n",
    "    prompt_tuning_init=PromptTuningInit.TEXT,\n",
    "    num_virtual_tokens=20,\n",
    "    prompt_tuning_init_text=\"What is the sentiment of this article?\\n\",\n",
    "    inference_mode=False,\n",
    "    tokenizer_name_or_path=model_name_or_path\n",
    ")\n",
    "\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4ee2babf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:38:18.759143Z",
     "start_time": "2023-05-30T08:38:17.881621Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset financial_phrasebank (/data/proxem/huggingface/datasets/financial_phrasebank/sentences_allagree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fb63f50cb7cb4f5aae10648ba74d6c4e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/2037 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/227 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'sentence': '`` Lining stone sales were also good in the early autumn , and order books are strong to the end of the year .',\n",
       " 'label': 2,\n",
       " 'text_label': 'positive'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# loading dataset\n",
    "dataset = load_dataset(\"financial_phrasebank\", \"sentences_allagree\")\n",
    "dataset = dataset[\"train\"].train_test_split(test_size=0.1)\n",
    "dataset[\"validation\"] = dataset[\"test\"]\n",
    "del dataset[\"test\"]\n",
    "\n",
    "classes = dataset[\"train\"].features[\"label\"].names\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "\n",
    "dataset[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "adf9608c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:38:21.132266Z",
     "start_time": "2023-05-30T08:38:20.340722Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/2037 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/227 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    inputs = examples[text_column]\n",
    "    targets = examples[label_column]\n",
    "    model_inputs = tokenizer(inputs, max_length=max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "    labels = tokenizer(targets, max_length=target_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "    labels = labels[\"input_ids\"]\n",
    "    labels[labels == tokenizer.pad_token_id] = -100\n",
    "    model_inputs[\"labels\"] = labels\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"validation\"]\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f733a3c6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:38:22.907922Z",
     "start_time": "2023-05-30T08:38:22.901057Z"
    }
   },
   "outputs": [],
   "source": [
    "# optimizer and lr scheduler\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6b3a4090",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:42:29.409070Z",
     "start_time": "2023-05-30T08:38:50.102263Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:42<00:00,  6.05it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:02<00:00, 14.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0: train_ppl=tensor(8.0846, device='cuda:0') train_epoch_loss=tensor(2.0900, device='cuda:0') eval_ppl=tensor(1.3542, device='cuda:0') eval_epoch_loss=tensor(0.3032, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:41<00:00,  6.15it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:02<00:00, 14.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=1: train_ppl=tensor(1.5088, device='cuda:0') train_epoch_loss=tensor(0.4113, device='cuda:0') eval_ppl=tensor(1.2692, device='cuda:0') eval_epoch_loss=tensor(0.2384, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:41<00:00,  6.18it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:02<00:00, 14.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=2: train_ppl=tensor(1.5322, device='cuda:0') train_epoch_loss=tensor(0.4267, device='cuda:0') eval_ppl=tensor(1.2065, device='cuda:0') eval_epoch_loss=tensor(0.1877, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:41<00:00,  6.17it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:02<00:00, 14.38it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=3: train_ppl=tensor(1.4475, device='cuda:0') train_epoch_loss=tensor(0.3699, device='cuda:0') eval_ppl=tensor(1.2346, device='cuda:0') eval_epoch_loss=tensor(0.2107, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 255/255 [00:42<00:00,  5.94it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:02<00:00, 14.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=4: train_ppl=tensor(1.3428, device='cuda:0') train_epoch_loss=tensor(0.2948, device='cuda:0') eval_ppl=tensor(1.2041, device='cuda:0') eval_epoch_loss=tensor(0.1857, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# training and evaluation\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        eval_loss += loss.detach().float()\n",
    "        eval_preds.extend(\n",
    "            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
    "        )\n",
    "\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6cafa67b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:42:42.844671Z",
     "start_time": "2023-05-30T08:42:42.840447Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy=85.46255506607929 % on the evaluation dataset\n",
      "eval_preds[:10]=['neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'positive', 'neutral', 'negative', 'neutral', 'positive']\n",
      "dataset['validation']['text_label'][:10]=['neutral', 'neutral', 'neutral', 'neutral', 'neutral', 'positive', 'neutral', 'negative', 'positive', 'neutral']\n"
     ]
    }
   ],
   "source": [
    "# print accuracy\n",
    "correct = 0\n",
    "total = 0\n",
    "for pred, true in zip(eval_preds, dataset[\"validation\"][\"text_label\"]):\n",
    "    if pred.strip() == true.strip():\n",
    "        correct += 1\n",
    "    total += 1\n",
    "accuracy = correct / total * 100\n",
    "print(f\"{accuracy=} % on the evaluation dataset\")\n",
    "print(f\"{eval_preds[:10]=}\")\n",
    "print(f\"{dataset['validation']['text_label'][:10]=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a8de6005",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:42:45.752765Z",
     "start_time": "2023-05-30T08:42:45.742397Z"
    }
   },
   "outputs": [],
   "source": [
    "# saving model\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "model.save_pretrained(peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bd20cd4c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:42:47.660873Z",
     "start_time": "2023-05-30T08:42:47.488293Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "164K\tt5-large_PROMPT_TUNING_SEQ_2_SEQ_LM/adapter_model.bin\r\n"
     ]
    }
   ],
   "source": [
    "ckpt = f\"{peft_model_id}/adapter_model.bin\"\n",
    "!du -h $ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "76c2fc29",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:42:56.721990Z",
     "start_time": "2023-05-30T08:42:49.060700Z"
    }
   },
   "outputs": [],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d997f1cc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-30T08:42:59.600916Z",
     "start_time": "2023-05-30T08:42:58.961468Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Danske Bank is Denmark 's largest bank with 3.5 million customers .\n",
      "tensor([[ 3039,  1050,  1925,    19, 18001,     3,    31,     7,  2015,  2137,\n",
      "            28,     3,  9285,   770,   722,     3,     5,     1]])\n",
      "tensor([[   0, 7163,    1]])\n",
      "['neutral']\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "i = 107\n",
    "input_ids = tokenizer(dataset[\"validation\"][text_column][i], return_tensors=\"pt\").input_ids\n",
    "print(dataset[\"validation\"][text_column][i])\n",
    "print(input_ids)\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(input_ids=input_ids, max_new_tokens=10)\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "peft",
   "language": "python",
   "name": "peft"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
