{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "71fbfca2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "from peft import get_peft_config, get_peft_model, LNTuningConfig, TaskType, PeftType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import default_data_collator, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Hyper-parameters\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"bigscience/bloomz-560m\"\n",
    "tokenizer_name_or_path = \"bigscience/bloomz-560m\"\n",
    "peft_config = LNTuningConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    ")\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "checkpoint_name = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "text_column = \"Tweet text\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 64\n",
    "lr = 5e-2\n",
    "num_epochs = 50\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a617882d",
   "metadata": {},
   "source": [
    "## Load and Process Dataset for LM Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e1a3648b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Unlabeled', 'complaint', 'no complaint']\n",
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['Tweet text', 'ID', 'Label', 'text_label'],\n",
      "        num_rows: 50\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['Tweet text', 'ID', 'Label', 'text_label'],\n",
      "        num_rows: 3399\n",
      "    })\n",
      "})\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Tweet text': '@HMRCcustomers No this is my first job',\n",
       " 'ID': 0,\n",
       " 'Label': 2,\n",
       " 'text_label': 'no complaint'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
    "\n",
    "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
    "print(classes)\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "print(dataset)\n",
    "dataset[\"train\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fe12d4d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running tokenizer on dataset: 100%|██████████| 50/50 [00:00<00:00, 3551.43 examples/s]\n",
      "Running tokenizer on dataset: 100%|██████████| 3399/3399 [00:00<00:00, 8558.01 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# data preprocessing\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "print(target_max_length)\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    targets = [str(x) for x in examples[label_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    labels = tokenizer(targets, add_special_tokens=False)  # don't add bos token because we concatenate with inputs\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i] + [tokenizer.eos_token_id]\n",
    "        # print(i, sample_input_ids, label_input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
    "        labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "        labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"train\"]\n",
    "\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "641b21fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Running tokenizer on dataset: 100%|██████████| 3399/3399 [00:00<00:00, 17380.64 examples/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "          227985,   5484,    915,   2566,  74757,  64626,  12384,  44639,    613,\n",
       "           52282,   2670,  79920,   3344,   1002,    368,  17646,  14472,   8348,\n",
       "             664,    718,      4,  19036,     17,  31849,     17,   6312,     76,\n",
       "              44,  62470,     56,     91,     50,  14839,     21,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3, 227985,   5484,    915,    405, 187059,\n",
       "            2256,    664,   2550,  18833,  18607, 162467,      4,   1387,   6199,\n",
       "            3291,  23405,    613,   4657,  17082,    566,   3432,    368,  78851,\n",
       "            1185,  61273,  23181,   1553,  15596,    212, 116057,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3, 227985,   5484,\n",
       "             915,  39762,   2566,  22253,   6201,  75701,     15,    632,    718,\n",
       "            5840,  10006,   6201,  18881,    427,   3804,  19528,    267, 158974,\n",
       "            1320,    368,  10029,    632,  49666,     92,     34,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3, 227985,   5484,    915,   2566, 104565,   8695,   2089,   6140,\n",
       "          109676,  99579,   1369,    512,    368,   4570,     54,    632,    368,\n",
       "            1503, 241485, 132226,     15,    982,    727,   1152,  18100,    861,\n",
       "           32596,  77597, 168154,   1306, 132226,   4346,  87843,     17, 130462,\n",
       "             364,  32923,     89,     53,   8309,     20,     75,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3, 227985,   5484,    915,   2566,\n",
       "           14173,   2960,  29906,    387,  20706,  49337,   1369,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3, 227985,   5484,    915,   2566, 219553,  45736,\n",
       "           36876,   1713,     72,    707, 187205,  13002, 177324,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3, 227985,   5484,    915,   2566, 233938,  28518,  13716,\n",
       "             427,  28146,   1119,  17918,     17, 236706,    368, 214997,   7555,\n",
       "           48659,   5276,  21600,    343,     17,  51416,  22403,    318,   1531,\n",
       "            1306,   1130,  20934,    567, 101161, 184849,  87843,     17,   1594,\n",
       "           15231,   2052,  16642,     20,   7180,     80,     26,  77658,    915,\n",
       "             210],\n",
       "         [     3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "               3,      3,      3,      3,      3,      3,      3,      3,      3,\n",
       "          227985,   5484,    915,   2566,     80,   2068,    479,   2566,     80,\n",
       "            1376,    878, 147587,   3904,    632,    368,   6084,  65673,  78851,\n",
       "           11736,  15527,  19082,  33151,    461,     17,  45575,  17887,    632,\n",
       "            5219,  14216,  68870,   5967,   1841,   4346,  87843,     17,   1594,\n",
       "           14512,     27,     71,   8184,     19,    290,  63748,  77658,    915,\n",
       "             210]]),\n",
       " 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def test_preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "    return model_inputs\n",
    "\n",
    "\n",
    "test_dataset = dataset[\"test\"].map(\n",
    "    test_preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")\n",
    "\n",
    "test_dataloader = DataLoader(test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)\n",
    "next(iter(test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "218df807",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "425"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# show dataset size\n",
    "len(test_dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa55f803",
   "metadata": {},
   "source": [
    "## Train the LM with LNTuning\n",
    "1. Create the base LM.\n",
    "2. Only activate the LayerNorm layers in the LM for training.\n",
    "3. Train the LM on the training dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a773e092",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 100,352 || all params: 559,314,944 || trainable%: 0.017941948642087417\n"
     ]
    }
   ],
   "source": [
    "# 1. creating the base LM\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
    "# 2. Only activate the LayerNorm layers in the Attention blocks in the LM for training\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b2f91568",
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup the optimizer and lr scheduler\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e4fb69fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00,  7.09it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 23.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0: train_ppl=tensor(8.1918, device='cuda:0') train_epoch_loss=tensor(2.1031, device='cuda:0') eval_ppl=tensor(2.1760, device='cuda:0') eval_epoch_loss=tensor(0.7775, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.88it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 23.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=1: train_ppl=tensor(1.8009, device='cuda:0') train_epoch_loss=tensor(0.5883, device='cuda:0') eval_ppl=tensor(2.1198, device='cuda:0') eval_epoch_loss=tensor(0.7513, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.87it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 23.08it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=2: train_ppl=tensor(2.0387, device='cuda:0') train_epoch_loss=tensor(0.7123, device='cuda:0') eval_ppl=tensor(1.6793, device='cuda:0') eval_epoch_loss=tensor(0.5184, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.92it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 23.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=3: train_ppl=tensor(1.4885, device='cuda:0') train_epoch_loss=tensor(0.3978, device='cuda:0') eval_ppl=tensor(1.2918, device='cuda:0') eval_epoch_loss=tensor(0.2561, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.89it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 23.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=4: train_ppl=tensor(1.3062, device='cuda:0') train_epoch_loss=tensor(0.2671, device='cuda:0') eval_ppl=tensor(1.3259, device='cuda:0') eval_epoch_loss=tensor(0.2821, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.79it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=5: train_ppl=tensor(1.3129, device='cuda:0') train_epoch_loss=tensor(0.2722, device='cuda:0') eval_ppl=tensor(1.2315, device='cuda:0') eval_epoch_loss=tensor(0.2082, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.83it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=6: train_ppl=tensor(1.2605, device='cuda:0') train_epoch_loss=tensor(0.2315, device='cuda:0') eval_ppl=tensor(1.2705, device='cuda:0') eval_epoch_loss=tensor(0.2394, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.87it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=7: train_ppl=tensor(1.2452, device='cuda:0') train_epoch_loss=tensor(0.2193, device='cuda:0') eval_ppl=tensor(1.2103, device='cuda:0') eval_epoch_loss=tensor(0.1909, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.88it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=8: train_ppl=tensor(1.2185, device='cuda:0') train_epoch_loss=tensor(0.1976, device='cuda:0') eval_ppl=tensor(1.2127, device='cuda:0') eval_epoch_loss=tensor(0.1929, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.83it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=9: train_ppl=tensor(1.1868, device='cuda:0') train_epoch_loss=tensor(0.1713, device='cuda:0') eval_ppl=tensor(1.1765, device='cuda:0') eval_epoch_loss=tensor(0.1625, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.83it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=10: train_ppl=tensor(1.1905, device='cuda:0') train_epoch_loss=tensor(0.1744, device='cuda:0') eval_ppl=tensor(1.1539, device='cuda:0') eval_epoch_loss=tensor(0.1431, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.80it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=11: train_ppl=tensor(1.1475, device='cuda:0') train_epoch_loss=tensor(0.1376, device='cuda:0') eval_ppl=tensor(1.1238, device='cuda:0') eval_epoch_loss=tensor(0.1167, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.81it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=12: train_ppl=tensor(1.1099, device='cuda:0') train_epoch_loss=tensor(0.1043, device='cuda:0') eval_ppl=tensor(1.0859, device='cuda:0') eval_epoch_loss=tensor(0.0824, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.77it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=13: train_ppl=tensor(1.0798, device='cuda:0') train_epoch_loss=tensor(0.0768, device='cuda:0') eval_ppl=tensor(1.1151, device='cuda:0') eval_epoch_loss=tensor(0.1089, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.87it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=14: train_ppl=tensor(1.0622, device='cuda:0') train_epoch_loss=tensor(0.0604, device='cuda:0') eval_ppl=tensor(1.0347, device='cuda:0') eval_epoch_loss=tensor(0.0341, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.86it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=15: train_ppl=tensor(1.0188, device='cuda:0') train_epoch_loss=tensor(0.0186, device='cuda:0') eval_ppl=tensor(1.0169, device='cuda:0') eval_epoch_loss=tensor(0.0168, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.87it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=16: train_ppl=tensor(1.0085, device='cuda:0') train_epoch_loss=tensor(0.0085, device='cuda:0') eval_ppl=tensor(1.0047, device='cuda:0') eval_epoch_loss=tensor(0.0047, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.80it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=17: train_ppl=tensor(1.0041, device='cuda:0') train_epoch_loss=tensor(0.0041, device='cuda:0') eval_ppl=tensor(1.0013, device='cuda:0') eval_epoch_loss=tensor(0.0013, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.82it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=18: train_ppl=tensor(1.0010, device='cuda:0') train_epoch_loss=tensor(0.0010, device='cuda:0') eval_ppl=tensor(1.0010, device='cuda:0') eval_epoch_loss=tensor(0.0010, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.77it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=19: train_ppl=tensor(1.0007, device='cuda:0') train_epoch_loss=tensor(0.0007, device='cuda:0') eval_ppl=tensor(1.0005, device='cuda:0') eval_epoch_loss=tensor(0.0005, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.77it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=20: train_ppl=tensor(1.0004, device='cuda:0') train_epoch_loss=tensor(0.0004, device='cuda:0') eval_ppl=tensor(1.0004, device='cuda:0') eval_epoch_loss=tensor(0.0004, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.78it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=21: train_ppl=tensor(1.0003, device='cuda:0') train_epoch_loss=tensor(0.0003, device='cuda:0') eval_ppl=tensor(1.0003, device='cuda:0') eval_epoch_loss=tensor(0.0003, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.73it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=22: train_ppl=tensor(1.0003, device='cuda:0') train_epoch_loss=tensor(0.0003, device='cuda:0') eval_ppl=tensor(1.0003, device='cuda:0') eval_epoch_loss=tensor(0.0003, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.68it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=23: train_ppl=tensor(1.0003, device='cuda:0') train_epoch_loss=tensor(0.0003, device='cuda:0') eval_ppl=tensor(1.0003, device='cuda:0') eval_epoch_loss=tensor(0.0003, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.79it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=24: train_ppl=tensor(1.0003, device='cuda:0') train_epoch_loss=tensor(0.0003, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.69it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=25: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.66it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=26: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.75it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=27: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.76it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=28: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.73it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=29: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.73it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=30: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.72it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=31: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.75it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=32: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.72it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=33: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0002, device='cuda:0') eval_epoch_loss=tensor(0.0002, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.72it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=34: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.70it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=35: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.68it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=36: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.72it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=37: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.64it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=38: train_ppl=tensor(1.0002, device='cuda:0') train_epoch_loss=tensor(0.0002, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.66it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=39: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.71it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=40: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.64it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=41: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.64it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=42: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.68it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=43: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.70it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=44: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.67it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=45: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.69it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=46: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.66it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=47: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.60it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=48: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 7/7 [00:00<00:00, 10.62it/s]\n",
      "100%|██████████| 7/7 [00:00<00:00, 22.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=49: train_ppl=tensor(1.0001, device='cuda:0') train_epoch_loss=tensor(0.0001, device='cuda:0') eval_ppl=tensor(1.0001, device='cuda:0') eval_epoch_loss=tensor(0.0001, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# 3. train the LM on the training dataset, evaluating after every epoch\n",
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Training pass: one optimizer step and one LR-scheduler step per batch.\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        batch = {name: tensor.to(device) for name, tensor in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "        # Accumulate a detached copy so the running sum carries no autograd history.\n",
    "        total_loss += loss.detach().float()\n",
    "\n",
    "    # Evaluation pass: no gradients are needed for the forward pass or the decoding.\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    # Greedy (argmax) predictions decoded to text; rebuilt from scratch every epoch.\n",
    "    eval_preds = []\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(eval_dataloader):\n",
    "            batch = {name: tensor.to(device) for name, tensor in batch.items()}\n",
    "            outputs = model(**batch)\n",
    "            loss = outputs.loss\n",
    "            eval_loss += loss.detach().float()\n",
    "            predicted_ids = torch.argmax(outputs.logits, -1).detach().cpu().numpy()\n",
    "            eval_preds.extend(tokenizer.batch_decode(predicted_ids, skip_special_tokens=True))\n",
    "\n",
    "    # Mean per-batch loss for each split; perplexity is exp(mean loss).\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbf339a2",
   "metadata": {},
   "source": [
    "## Test the LM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "53752a7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "@TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand &gt; different sizing\n",
      "{'input_ids': tensor([[227985,   5484,    915,   2566, 226154, 126015,   5385,    259, 239364,\n",
      "           3396,  70823,   5853,     17,  57247,   1231, 191040,   5025,   7869,\n",
      "            375,   2324, 149349,     12,    415, 122321,    897,    415,  10136,\n",
      "          10021,    897,    415,  10136,   6497,    381,    915,   5025,  51950,\n",
      "          66869,   5955,    272,  20311,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
      "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[227985,   5484,    915,   2566, 226154, 126015,   5385,    259, 239364,\n",
      "           3396,  70823,   5853,     17,  57247,   1231, 191040,   5025,   7869,\n",
      "            375,   2324, 149349,     12,    415, 122321,    897,    415,  10136,\n",
      "          10021,    897,    415,  10136,   6497,    381,    915,   5025,  51950,\n",
      "          66869,   5955,    272,  20311,  77658,    915,    210,   1936, 106863,\n",
      "              2,   1936, 106863,      2,   1936, 106863,      2,   1936]],\n",
      "       device='cuda:0')\n",
      "['Tweet text : @TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand &gt; different sizing Label : no complaintno complaintno complaintno']\n"
     ]
    }
   ],
   "source": [
    "# Generate a label for one test tweet with the fine-tuned model.\n",
    "model.eval()\n",
    "i = 33\n",
    "tweet_text = dataset[\"test\"][i][text_column]\n",
    "# Prompt layout is \"<column name> : <tweet> Label : \"; the model is expected to continue with the label.\n",
    "# NOTE(review): presumably this mirrors the training prompt format - confirm against the preprocessing cell.\n",
    "inputs = tokenizer(f\"{text_column} : {tweet_text} Label : \", return_tensors=\"pt\")\n",
    "print(tweet_text)\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"],\n",
    "        attention_mask=inputs[\"attention_mask\"],\n",
    "        max_new_tokens=10,\n",
    "        # Stop on the tokenizer's own EOS token. The previously hard-coded id 3 never matched the\n",
    "        # token the model emits after the label (id 2 in the recorded run), so generation kept\n",
    "        # repeating the label until max_new_tokens was exhausted.\n",
    "        eos_token_id=tokenizer.eos_token_id,\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8f35152",
   "metadata": {},
   "source": [
    "## Save the trainable LM weights (LayerNorm layers)\n",
    "You can push the model to the Hugging Face Hub or save it locally.\n",
    "\n",
    "- Option1: Push the model to Hugging Face Hub:\n",
    "\n",
    "    ```python\n",
    "    model.push_to_hub(\n",
    "        f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\"),\n",
    "        token = \"hf_...\"\n",
    "    )\n",
    "    ```\n",
    "    `token` (`bool` or `str`, *optional*): used for HTTP Bearer authorization when accessing remote files.\n",
    "    If `True`, the token generated when running `huggingface-cli login` (stored in `~/.huggingface`) is used.\n",
    "    It defaults to `True` if `repo_url` is not specified.\n",
    "    You can also get your token from https://huggingface.co/settings/tokens\n",
    "\n",
    "- Option2: Save model locally:\n",
    "\n",
    "    ```python\n",
    "    peft_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\".replace(\"/\", \"_\")\n",
    "    model.save_pretrained(peft_model_id)\n",
    "    ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d8ba1f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model under a name built from the dataset, base model, PEFT type and task type.\n",
    "raw_model_id = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}\"\n",
    "# The base model name contains \"/\"; replace it so the id is a single path component.\n",
    "peft_model_id = raw_model_id.replace(\"/\", \"_\")\n",
    "model.save_pretrained(peft_model_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4dd7ab9c",
   "metadata": {},
   "source": [
    "## Test the LM using LNTuning loaded from saved weights\n",
    "1. load the LNTuning configuration\n",
    "2. load the base LM\n",
    "3. load the saved LNTuning weights on top of the base LM using the PEFT config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4d9476e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftConfig, PeftModel\n",
    "\n",
    "# Read the saved LNTuning config; it names the base checkpoint to load.\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "# Re-create the base LM from that checkpoint.\n",
    "base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
    "# Load the saved LNTuning weights on top of the base LM.\n",
    "model = PeftModel.from_pretrained(base_model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ebe174a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "@greateranglia Ok thanks...\n",
      "{'input_ids': tensor([[227985,   5484,    915,   2566,  14173,   2960,  29906,    387,  20706,\n",
      "          49337,   1369,  77658,    915,    210]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n",
      "tensor([[227985,   5484,    915,   2566,  14173,   2960,  29906,    387,  20706,\n",
      "          49337,   1369,  77658,    915,    210,   1936, 106863,      2,   1936,\n",
      "         106863,      2,   1936, 106863,      2,   1936]], device='cuda:0')\n",
      "['Tweet text : @greateranglia Ok thanks... Label : no complaintno complaintno complaintno']\n"
     ]
    }
   ],
   "source": [
    "# Generate a label for one test tweet with the model reloaded from the saved weights.\n",
    "model.to(device)\n",
    "model.eval()\n",
    "i = 4\n",
    "tweet_text = dataset[\"test\"][i][text_column]\n",
    "# Prompt layout is \"<column name> : <tweet> Label : \"; the model is expected to continue with the label.\n",
    "# NOTE(review): presumably this mirrors the training prompt format - confirm against the preprocessing cell.\n",
    "inputs = tokenizer(f\"{text_column} : {tweet_text} Label : \", return_tensors=\"pt\")\n",
    "print(tweet_text)\n",
    "print(inputs)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"],\n",
    "        attention_mask=inputs[\"attention_mask\"],\n",
    "        max_new_tokens=10,\n",
    "        # Stop on the tokenizer's own EOS token. The previously hard-coded id 3 never matched the\n",
    "        # token the model emits after the label (id 2 in the recorded run), so generation kept\n",
    "        # repeating the label until max_new_tokens was exhausted.\n",
    "        eos_token_id=tokenizer.eos_token_id,\n",
    "    )\n",
    "    print(outputs)\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
