{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "71fbfca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "from peft import get_peft_config, get_peft_model, PrefixTuningConfig, TaskType, PeftType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"bigscience/bloomz-560m\"\n",
    "tokenizer_name_or_path = \"bigscience/bloomz-560m\"\n",
    "peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=30)\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "checkpoint_name = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "text_column = \"Tweet text\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 64\n",
    "lr = 3e-2\n",
    "num_epochs = 50\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e1a3648b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset raft (/home/sourab/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "56d9908a2c8944b484348cc46b16a261",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-20a7622c86d80cdf.arrow\n",
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-5f1431311da05803.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Unlabeled', 'complaint', 'no complaint']\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['Tweet text', 'ID', 'Label', 'text_label'],\n",
      "        num_rows: 50\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['Tweet text', 'ID', 'Label', 'text_label'],\n",
      "        num_rows: 3399\n",
      "    })\n",
      "})\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Tweet text': '@HMRCcustomers No this is my first job',\n",
       " 'ID': 0,\n",
       " 'Label': 2,\n",
       " 'text_label': 'no complaint'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
    "\n",
    "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
    "print(classes)\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "print(dataset)\n",
    "dataset[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fe12d4d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5a0e3242324842fb941950df38b459fe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "133df817b7b9468cabd5353d4d2b675b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Running tokenizer on dataset:   0%|          | 0/4 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "print(target_max_length)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    targets = [str(x) for x in examples[label_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    labels = tokenizer(targets)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i] + [tokenizer.pad_token_id]\n",
    "        # print(i, sample_input_ids, label_input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
    "        labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "        labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"train\"]\n",
    "\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "641b21fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "test_dataset = dataset[\"test\"].map(\n",
    "    test_preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
    "next(iter(test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "accc5012",
   "metadata": {},
   "outputs": [],
   "source": [
    "next(iter(train_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "218df807",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "425"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(test_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d1fedf",
   "metadata": {},
   "outputs": [],
   "source": [
    "next(iter(test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a773e092",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 1474560 || all params: 560689152 || trainable%: 0.26299064191632515\n"
     ]
    }
   ],
   "source": [
    "# creating model\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bd419634",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 1474560 || all params: 560689152 || trainable%: 0.26299064191632515\n"
     ]
    }
   ],
   "source": [
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22822901",
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "023cb942",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PrefixTuningConfig(peft_type=<PeftType.PREFIX_TUNING: 'PREFIX_TUNING'>, base_model_name_or_path='bigscience/bloomz-560m', task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, num_virtual_tokens=30, token_dim=1024, num_transformer_submodules=1, num_attention_heads=16, num_layers=24, encoder_hidden_size=1024, prefix_projection=False)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.peft_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b2f91568",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model\n",
    "# optimizer and lr scheduler\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e4fb69fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  5.79it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0: train_ppl=tensor(1.8325e+09, device='cuda:0') train_epoch_loss=tensor(21.3289, device='cuda:0') eval_ppl=tensor(2713.4180, device='cuda:0') eval_epoch_loss=tensor(7.9060, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=1: train_ppl=tensor(341.0600, device='cuda:0') train_epoch_loss=tensor(5.8321, device='cuda:0') eval_ppl=tensor(80.8206, device='cuda:0') eval_epoch_loss=tensor(4.3922, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=2: train_ppl=tensor(59.8778, device='cuda:0') train_epoch_loss=tensor(4.0923, device='cuda:0') eval_ppl=tensor(34.4593, device='cuda:0') eval_epoch_loss=tensor(3.5398, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=3: train_ppl=tensor(22.3307, device='cuda:0') train_epoch_loss=tensor(3.1060, device='cuda:0') eval_ppl=tensor(12.5947, device='cuda:0') eval_epoch_loss=tensor(2.5333, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=4: train_ppl=tensor(9.1697, device='cuda:0') train_epoch_loss=tensor(2.2159, device='cuda:0') eval_ppl=tensor(4.5289, device='cuda:0') eval_epoch_loss=tensor(1.5105, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=5: train_ppl=tensor(3.0172, device='cuda:0') train_epoch_loss=tensor(1.1043, device='cuda:0') eval_ppl=tensor(1.8092, device='cuda:0') eval_epoch_loss=tensor(0.5929, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=6: train_ppl=tensor(1.4885, device='cuda:0') train_epoch_loss=tensor(0.3978, device='cuda:0') eval_ppl=tensor(1.4449, device='cuda:0') eval_epoch_loss=tensor(0.3680, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=7: train_ppl=tensor(1.2967, device='cuda:0') train_epoch_loss=tensor(0.2598, device='cuda:0') eval_ppl=tensor(1.1587, device='cuda:0') eval_epoch_loss=tensor(0.1473, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=8: train_ppl=tensor(1.1305, device='cuda:0') train_epoch_loss=tensor(0.1227, device='cuda:0') eval_ppl=tensor(1.0874, device='cuda:0') eval_epoch_loss=tensor(0.0838, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=9: train_ppl=tensor(1.1608, device='cuda:0') train_epoch_loss=tensor(0.1491, device='cuda:0') eval_ppl=tensor(1.1461, device='cuda:0') eval_epoch_loss=tensor(0.1364, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=10: train_ppl=tensor(1.3172, device='cuda:0') train_epoch_loss=tensor(0.2755, device='cuda:0') eval_ppl=tensor(1.1320, device='cuda:0') eval_epoch_loss=tensor(0.1240, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=11: train_ppl=tensor(1.1437, device='cuda:0') train_epoch_loss=tensor(0.1343, device='cuda:0') eval_ppl=tensor(1.0676, device='cuda:0') eval_epoch_loss=tensor(0.0654, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=12: train_ppl=tensor(1.0651, device='cuda:0') train_epoch_loss=tensor(0.0630, device='cuda:0') eval_ppl=tensor(1.0735, device='cuda:0') eval_epoch_loss=tensor(0.0710, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.46it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=13: train_ppl=tensor(1.0607, device='cuda:0') train_epoch_loss=tensor(0.0589, device='cuda:0') eval_ppl=tensor(1.0399, device='cuda:0') eval_epoch_loss=tensor(0.0391, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=14: train_ppl=tensor(1.0351, device='cuda:0') train_epoch_loss=tensor(0.0345, device='cuda:0') eval_ppl=tensor(1.0260, device='cuda:0') eval_epoch_loss=tensor(0.0257, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=15: train_ppl=tensor(1.0217, device='cuda:0') train_epoch_loss=tensor(0.0215, device='cuda:0') eval_ppl=tensor(1.0168, device='cuda:0') eval_epoch_loss=tensor(0.0167, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=16: train_ppl=tensor(1.0152, device='cuda:0') train_epoch_loss=tensor(0.0151, device='cuda:0') eval_ppl=tensor(1.0117, device='cuda:0') eval_epoch_loss=tensor(0.0116, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=17: train_ppl=tensor(1.0102, device='cuda:0') train_epoch_loss=tensor(0.0101, device='cuda:0') eval_ppl=tensor(1.0088, device='cuda:0') eval_epoch_loss=tensor(0.0088, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.29it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=18: train_ppl=tensor(1.0083, device='cuda:0') train_epoch_loss=tensor(0.0083, device='cuda:0') eval_ppl=tensor(1.0073, device='cuda:0') eval_epoch_loss=tensor(0.0073, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=19: train_ppl=tensor(1.0070, device='cuda:0') train_epoch_loss=tensor(0.0070, device='cuda:0') eval_ppl=tensor(1.0064, device='cuda:0') eval_epoch_loss=tensor(0.0063, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=20: train_ppl=tensor(1.0059, device='cuda:0') train_epoch_loss=tensor(0.0059, device='cuda:0') eval_ppl=tensor(1.0057, device='cuda:0') eval_epoch_loss=tensor(0.0057, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=21: train_ppl=tensor(1.0056, device='cuda:0') train_epoch_loss=tensor(0.0056, device='cuda:0') eval_ppl=tensor(1.0052, device='cuda:0') eval_epoch_loss=tensor(0.0052, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.41it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=22: train_ppl=tensor(1.0050, device='cuda:0') train_epoch_loss=tensor(0.0050, device='cuda:0') eval_ppl=tensor(1.0049, device='cuda:0') eval_epoch_loss=tensor(0.0049, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.39it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.44it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=23: train_ppl=tensor(1.0049, device='cuda:0') train_epoch_loss=tensor(0.0049, device='cuda:0') eval_ppl=tensor(1.0045, device='cuda:0') eval_epoch_loss=tensor(0.0045, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.42it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=24: train_ppl=tensor(1.0043, device='cuda:0') train_epoch_loss=tensor(0.0043, device='cuda:0') eval_ppl=tensor(1.0043, device='cuda:0') eval_epoch_loss=tensor(0.0043, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.46it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=25: train_ppl=tensor(1.0042, device='cuda:0') train_epoch_loss=tensor(0.0042, device='cuda:0') eval_ppl=tensor(1.0040, device='cuda:0') eval_epoch_loss=tensor(0.0040, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=26: train_ppl=tensor(1.0039, device='cuda:0') train_epoch_loss=tensor(0.0039, device='cuda:0') eval_ppl=tensor(1.0039, device='cuda:0') eval_epoch_loss=tensor(0.0039, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=27: train_ppl=tensor(1.0038, device='cuda:0') train_epoch_loss=tensor(0.0038, device='cuda:0') eval_ppl=tensor(1.0037, device='cuda:0') eval_epoch_loss=tensor(0.0037, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.46it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=28: train_ppl=tensor(1.0036, device='cuda:0') train_epoch_loss=tensor(0.0036, device='cuda:0') eval_ppl=tensor(1.0035, device='cuda:0') eval_epoch_loss=tensor(0.0035, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.53it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=29: train_ppl=tensor(1.0034, device='cuda:0') train_epoch_loss=tensor(0.0034, device='cuda:0') eval_ppl=tensor(1.0034, device='cuda:0') eval_epoch_loss=tensor(0.0034, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=30: train_ppl=tensor(1.0034, device='cuda:0') train_epoch_loss=tensor(0.0034, device='cuda:0') eval_ppl=tensor(1.0033, device='cuda:0') eval_epoch_loss=tensor(0.0033, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=31: train_ppl=tensor(1.0033, device='cuda:0') train_epoch_loss=tensor(0.0033, device='cuda:0') eval_ppl=tensor(1.0032, device='cuda:0') eval_epoch_loss=tensor(0.0032, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.46it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=32: train_ppl=tensor(1.0031, device='cuda:0') train_epoch_loss=tensor(0.0031, device='cuda:0') eval_ppl=tensor(1.0031, device='cuda:0') eval_epoch_loss=tensor(0.0031, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=33: train_ppl=tensor(1.0030, device='cuda:0') train_epoch_loss=tensor(0.0030, device='cuda:0') eval_ppl=tensor(1.0030, device='cuda:0') eval_epoch_loss=tensor(0.0030, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=34: train_ppl=tensor(1.0029, device='cuda:0') train_epoch_loss=tensor(0.0029, device='cuda:0') eval_ppl=tensor(1.0029, device='cuda:0') eval_epoch_loss=tensor(0.0029, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=35: train_ppl=tensor(1.0028, device='cuda:0') train_epoch_loss=tensor(0.0028, device='cuda:0') eval_ppl=tensor(1.0029, device='cuda:0') eval_epoch_loss=tensor(0.0029, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=36: train_ppl=tensor(1.0027, device='cuda:0') train_epoch_loss=tensor(0.0027, device='cuda:0') eval_ppl=tensor(1.0028, device='cuda:0') eval_epoch_loss=tensor(0.0028, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=37: train_ppl=tensor(1.0027, device='cuda:0') train_epoch_loss=tensor(0.0027, device='cuda:0') eval_ppl=tensor(1.0027, device='cuda:0') eval_epoch_loss=tensor(0.0027, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=38: train_ppl=tensor(1.0027, device='cuda:0') train_epoch_loss=tensor(0.0027, device='cuda:0') eval_ppl=tensor(1.0027, device='cuda:0') eval_epoch_loss=tensor(0.0027, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=39: train_ppl=tensor(1.0025, device='cuda:0') train_epoch_loss=tensor(0.0025, device='cuda:0') eval_ppl=tensor(1.0026, device='cuda:0') eval_epoch_loss=tensor(0.0026, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=40: train_ppl=tensor(1.0026, device='cuda:0') train_epoch_loss=tensor(0.0026, device='cuda:0') eval_ppl=tensor(1.0026, device='cuda:0') eval_epoch_loss=tensor(0.0026, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=41: train_ppl=tensor(1.0025, device='cuda:0') train_epoch_loss=tensor(0.0025, device='cuda:0') eval_ppl=tensor(1.0025, device='cuda:0') eval_epoch_loss=tensor(0.0025, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.42it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=42: train_ppl=tensor(1.0024, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0025, device='cuda:0') eval_epoch_loss=tensor(0.0025, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=43: train_ppl=tensor(1.0024, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0025, device='cuda:0') eval_epoch_loss=tensor(0.0025, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.44it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=44: train_ppl=tensor(1.0025, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=45: train_ppl=tensor(1.0024, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=46: train_ppl=tensor(1.0024, device='cuda:0') train_epoch_loss=tensor(0.0024, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.42it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=47: train_ppl=tensor(1.0023, device='cuda:0') train_epoch_loss=tensor(0.0023, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.40it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.40it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=48: train_ppl=tensor(1.0023, device='cuda:0') train_epoch_loss=tensor(0.0023, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 11.41it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 21.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=49: train_ppl=tensor(1.0023, device='cuda:0') train_epoch_loss=tensor(0.0023, device='cuda:0') eval_ppl=tensor(1.0024, device='cuda:0') eval_epoch_loss=tensor(0.0024, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# training and evaluation\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        #         print(batch)\n",
    "        #         print(batch[\"input_ids\"].shape)\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        eval_loss += loss.detach().float()\n",
    "        eval_preds.extend(\n",
    "            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
    "        )\n",
    "\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "53752a7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hey @nytimes your link to cancel my subscription isn't working and nobody is answering the chat. Please don't play that kind of stupid game.\n",
      "{'input_ids': tensor([[227985,   5484,    915,  54078,   2566,   7782,  24502,   2632,   8989,\n",
      "            427,  36992,   2670, 140711,  21994,  10789,    530,  88399,    632,\n",
      "         183542,    368,  44799,     17,  29901,   5926,   7229,    861,  11596,\n",
      "            461,  78851,  14775,     17,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[227985,   5484,    915,  54078,   2566,   7782,  24502,   2632,   8989,\n",
      "            427,  36992,   2670, 140711,  21994,  10789,    530,  88399,    632,\n",
      "         183542,    368,  44799,     17,  29901,   5926,   7229,    861,  11596,\n",
      "            461,  78851,  14775,     17,  77658,    915,    210,  16449,   5952,\n",
      "              3]], device='cuda:0')\n",
      "[\"Tweet text : Hey @nytimes your link to cancel my subscription isn't working and nobody is answering the chat. Please don't play that kind of stupid game. Label : complaint\"]\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "i = 16\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "24041ee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# saving model\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "model.save_pretrained(peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "527eeaa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt = f\"{peft_model_id}/adapter_model.bin\"\n",
    "!du -h $ckpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b19f5a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "peft_model_id = f\"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a11a3768",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "@greateranglia Ok thanks...\n",
      "{'input_ids': tensor([[227985,   5484,    915,   2566,  14173,   2960,  29906,    387,  20706,\n",
      "          49337,   1369,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[227985,   5484,    915,   2566,  14173,   2960,  29906,    387,  20706,\n",
      "          49337,   1369,  77658,    915,    210,   1936, 106863,      3]],\n",
      "       device='cuda:0')\n",
      "['Tweet text : @greateranglia Ok thanks... Label : no complaint']\n"
     ]
    }
   ],
   "source": [
    "model.to(device)\n",
    "model.eval()\n",
    "i = 4\n",
    "inputs = tokenizer(f'{text_column} : {dataset[\"test\"][i][\"Tweet text\"]} Label : ', return_tensors=\"pt\")\n",
    "print(dataset[\"test\"][i][\"Tweet text\"])\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f890c951",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "463a41a2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c60c7a9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
