{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119e8ac7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "# Expose only GPU 0 to this process; this is set before torch is imported below.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load the SST-2 task from the GLUE benchmark.\n",
    "raw_datasets = load_dataset(\n",
    "    \"glue\",\n",
    "    \"sst2\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1742088",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "# Checkpoint identifier shared by the tokenizer here and the model created in a later cell.\n",
    "model_name = \"FacebookAI/roberta-base\"\n",
    "\n",
    "# Tokenizer that matches the checkpoint above.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "\n",
    "\n",
    "# Columns removed from the dataset once the text has been tokenized.\n",
    "col_to_delete = ['idx', 'sentence']\n",
    "\n",
    "def preprocessing_function(examples):\n",
    "    \"\"\"Tokenize the ``sentence`` field of ``examples``, truncating to 128 tokens.\"\"\"\n",
    "    sentences = examples['sentence']\n",
    "    return tokenizer(sentences, max_length=128, truncation=True)\n",
    "\n",
    "\n",
    "# Tokenize in batches and drop the raw columns listed in col_to_delete.\n",
    "tokenized_datasets = raw_datasets.map(\n",
    "    preprocessing_function,\n",
    "    batched=True,\n",
    "    remove_columns=col_to_delete,\n",
    ")\n",
    "# Have the datasets return torch tensors.\n",
    "tokenized_datasets.set_format(\"torch\")\n",
    "\n",
    "# Collator that pads each batch to the longest sequence it contains.\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "44dc7ed6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import RobertaForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "# RoBERTa encoder with a newly initialised 2-label classification head.\n",
    "model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=2)\n",
    "# Use the GPU when one is available; fall back to CPU instead of raising on .to('cuda').\n",
    "model = model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "03e57b24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute macro-averaged precision, recall, F1 and plain accuracy.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: Pair ``(logits, labels)``. The predicted class for each example\n",
    "            is the argmax of ``logits`` over the last axis.\n",
    "\n",
    "    Returns:\n",
    "        dict with the keys ``precision``, ``recall``, ``f1-score`` and ``accuracy``.\n",
    "    \"\"\"\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "\n",
    "    # Early in training the model can predict a single class only. zero_division=0\n",
    "    # scores the missing class as 0 explicitly (the value sklearn already used)\n",
    "    # without emitting an UndefinedMetricWarning on every evaluation.\n",
    "    precision = metrics.precision_score(labels, predictions, average=\"macro\", zero_division=0)\n",
    "    recall = metrics.recall_score(labels, predictions, average=\"macro\", zero_division=0)\n",
    "    f1 = metrics.f1_score(labels, predictions, average=\"macro\", zero_division=0)\n",
    "    accuracy = metrics.accuracy_score(labels, predictions)\n",
    "\n",
    "    return {\"precision\": precision, \"recall\": recall, \"f1-score\": f1, 'accuracy': accuracy}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "58705a4d",
   "metadata": {},
   "outputs": [
   
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[Slice 1 | Mode=row | Pos=0] | Learning Rate: 0.000300 | Rank: 5\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 420,096\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1000' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1000/8420 00:47 < 05:51, 21.12 it/s, Epoch 0/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.693200</td>\n",
       "      <td>0.693334</td>\n",
       "      <td>0.254587</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.337386</td>\n",
       "      <td>0.509174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.589600</td>\n",
       "      <td>0.287706</td>\n",
       "      <td>0.892619</td>\n",
       "      <td>0.892544</td>\n",
       "      <td>0.892201</td>\n",
       "      <td>0.892202</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.402100</td>\n",
       "      <td>0.250897</td>\n",
       "      <td>0.903636</td>\n",
       "      <td>0.903764</td>\n",
       "      <td>0.903657</td>\n",
       "      <td>0.903670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.320900</td>\n",
       "      <td>0.232684</td>\n",
       "      <td>0.918523</td>\n",
       "      <td>0.917951</td>\n",
       "      <td>0.917420</td>\n",
       "      <td>0.917431</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.309100</td>\n",
       "      <td>0.250495</td>\n",
       "      <td>0.904942</td>\n",
       "      <td>0.901838</td>\n",
       "      <td>0.902239</td>\n",
       "      <td>0.902523</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.302300</td>\n",
       "      <td>0.218723</td>\n",
       "      <td>0.915146</td>\n",
       "      <td>0.915067</td>\n",
       "      <td>0.915101</td>\n",
       "      <td>0.915138</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.286900</td>\n",
       "      <td>0.285301</td>\n",
       "      <td>0.896034</td>\n",
       "      <td>0.885061</td>\n",
       "      <td>0.885480</td>\n",
       "      <td>0.886468</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.275700</td>\n",
       "      <td>0.236103</td>\n",
       "      <td>0.912243</td>\n",
       "      <td>0.907384</td>\n",
       "      <td>0.907877</td>\n",
       "      <td>0.908257</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.273900</td>\n",
       "      <td>0.213088</td>\n",
       "      <td>0.922018</td>\n",
       "      <td>0.922160</td>\n",
       "      <td>0.922012</td>\n",
       "      <td>0.922018</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.273700</td>\n",
       "      <td>0.314489</td>\n",
       "      <td>0.896808</td>\n",
       "      <td>0.882557</td>\n",
       "      <td>0.882899</td>\n",
       "      <td>0.884174</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[Slice 2 | Mode=column | Pos=0] | Learning Rate: 0.000300 | Rank: 5\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 418,570\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='2000' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [2000/8420 00:51 < 05:31, 19.39 it/s, Epoch 0/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.281800</td>\n",
       "      <td>0.273941</td>\n",
       "      <td>0.906495</td>\n",
       "      <td>0.897870</td>\n",
       "      <td>0.898390</td>\n",
       "      <td>0.899083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.248400</td>\n",
       "      <td>0.222406</td>\n",
       "      <td>0.924033</td>\n",
       "      <td>0.922781</td>\n",
       "      <td>0.923055</td>\n",
       "      <td>0.923165</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.248800</td>\n",
       "      <td>0.209030</td>\n",
       "      <td>0.918601</td>\n",
       "      <td>0.918740</td>\n",
       "      <td>0.918573</td>\n",
       "      <td>0.918578</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.250300</td>\n",
       "      <td>0.212010</td>\n",
       "      <td>0.921972</td>\n",
       "      <td>0.922034</td>\n",
       "      <td>0.921998</td>\n",
       "      <td>0.922018</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.252900</td>\n",
       "      <td>0.196947</td>\n",
       "      <td>0.928875</td>\n",
       "      <td>0.928875</td>\n",
       "      <td>0.928875</td>\n",
       "      <td>0.928899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.239600</td>\n",
       "      <td>0.203127</td>\n",
       "      <td>0.926572</td>\n",
       "      <td>0.926707</td>\n",
       "      <td>0.926596</td>\n",
       "      <td>0.926606</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.236500</td>\n",
       "      <td>0.207072</td>\n",
       "      <td>0.925751</td>\n",
       "      <td>0.922487</td>\n",
       "      <td>0.922941</td>\n",
       "      <td>0.923165</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.235700</td>\n",
       "      <td>0.199067</td>\n",
       "      <td>0.927793</td>\n",
       "      <td>0.927665</td>\n",
       "      <td>0.927718</td>\n",
       "      <td>0.927752</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.239000</td>\n",
       "      <td>0.199155</td>\n",
       "      <td>0.921985</td>\n",
       "      <td>0.922118</td>\n",
       "      <td>0.922008</td>\n",
       "      <td>0.922018</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.203500</td>\n",
       "      <td>0.209142</td>\n",
       "      <td>0.930071</td>\n",
       "      <td>0.927118</td>\n",
       "      <td>0.927559</td>\n",
       "      <td>0.927752</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training PEFT Blocks:   1%|          | 10/1536 [01:43<4:24:22, 10.39s/block]Adjusted position to 0 to fit rank 5 in 2 rows\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[Slice 3 | Mode=row | Pos=5] | Learning Rate: 0.000300 | Rank: 5\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 420,096\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3000' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [3000/8420 00:46 < 04:14, 21.30 it/s, Epoch 1/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.217400</td>\n",
       "      <td>0.201314</td>\n",
       "      <td>0.935736</td>\n",
       "      <td>0.934211</td>\n",
       "      <td>0.934528</td>\n",
       "      <td>0.934633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.234900</td>\n",
       "      <td>0.182783</td>\n",
       "      <td>0.939425</td>\n",
       "      <td>0.939052</td>\n",
       "      <td>0.939178</td>\n",
       "      <td>0.939220</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.206700</td>\n",
       "      <td>0.208166</td>\n",
       "      <td>0.929785</td>\n",
       "      <td>0.927160</td>\n",
       "      <td>0.927576</td>\n",
       "      <td>0.927752</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.198900</td>\n",
       "      <td>0.189291</td>\n",
       "      <td>0.941027</td>\n",
       "      <td>0.940052</td>\n",
       "      <td>0.940296</td>\n",
       "      <td>0.940367</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.210400</td>\n",
       "      <td>0.197863</td>\n",
       "      <td>0.934132</td>\n",
       "      <td>0.931790</td>\n",
       "      <td>0.932190</td>\n",
       "      <td>0.932339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.214500</td>\n",
       "      <td>0.200077</td>\n",
       "      <td>0.939296</td>\n",
       "      <td>0.937632</td>\n",
       "      <td>0.937968</td>\n",
       "      <td>0.938073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.218700</td>\n",
       "      <td>0.196386</td>\n",
       "      <td>0.939088</td>\n",
       "      <td>0.937674</td>\n",
       "      <td>0.937979</td>\n",
       "      <td>0.938073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.206100</td>\n",
       "      <td>0.185473</td>\n",
       "      <td>0.934717</td>\n",
       "      <td>0.934843</td>\n",
       "      <td>0.934631</td>\n",
       "      <td>0.934633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.204600</td>\n",
       "      <td>0.198573</td>\n",
       "      <td>0.940141</td>\n",
       "      <td>0.938842</td>\n",
       "      <td>0.939133</td>\n",
       "      <td>0.939220</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.197100</td>\n",
       "      <td>0.188428</td>\n",
       "      <td>0.937042</td>\n",
       "      <td>0.936800</td>\n",
       "      <td>0.936890</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training PEFT Blocks:   1%|          | 15/1536 [02:32<4:15:54, 10.09s/block]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[Slice 4 | Mode=column | Pos=5] | Learning Rate: 0.000300 | Rank: 5\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 418,570\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4000' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4000/8420 00:53 < 03:54, 18.82 it/s, Epoch 1/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.212700</td>\n",
       "      <td>0.198118</td>\n",
       "      <td>0.936428</td>\n",
       "      <td>0.935464</td>\n",
       "      <td>0.935704</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.196900</td>\n",
       "      <td>0.183277</td>\n",
       "      <td>0.935780</td>\n",
       "      <td>0.935927</td>\n",
       "      <td>0.935774</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.199600</td>\n",
       "      <td>0.186693</td>\n",
       "      <td>0.934614</td>\n",
       "      <td>0.934758</td>\n",
       "      <td>0.934626</td>\n",
       "      <td>0.934633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.204900</td>\n",
       "      <td>0.191869</td>\n",
       "      <td>0.940141</td>\n",
       "      <td>0.938842</td>\n",
       "      <td>0.939133</td>\n",
       "      <td>0.939220</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.196600</td>\n",
       "      <td>0.191153</td>\n",
       "      <td>0.932423</td>\n",
       "      <td>0.932548</td>\n",
       "      <td>0.932337</td>\n",
       "      <td>0.932339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.186500</td>\n",
       "      <td>0.200575</td>\n",
       "      <td>0.929065</td>\n",
       "      <td>0.925949</td>\n",
       "      <td>0.926401</td>\n",
       "      <td>0.926606</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.193600</td>\n",
       "      <td>0.194650</td>\n",
       "      <td>0.938898</td>\n",
       "      <td>0.937716</td>\n",
       "      <td>0.937990</td>\n",
       "      <td>0.938073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.192000</td>\n",
       "      <td>0.187540</td>\n",
       "      <td>0.937357</td>\n",
       "      <td>0.936674</td>\n",
       "      <td>0.936866</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.206300</td>\n",
       "      <td>0.186016</td>\n",
       "      <td>0.937234</td>\n",
       "      <td>0.936716</td>\n",
       "      <td>0.936875</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.195600</td>\n",
       "      <td>0.186207</td>\n",
       "      <td>0.933486</td>\n",
       "      <td>0.933632</td>\n",
       "      <td>0.933481</td>\n",
       "      <td>0.933486</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training PEFT Blocks:   1%|▏         | 20/1536 [03:26<4:23:46, 10.44s/block]Adjusted position to 0 to fit rank 5 in 2 rows\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[Slice 5 | Mode=row | Pos=10] | Learning Rate: 0.000300 | Rank: 5\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 420,096\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='5000' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [5000/8420 00:50 < 02:51, 19.95 it/s, Epoch 2/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.206500</td>\n",
       "      <td>0.197552</td>\n",
       "      <td>0.938898</td>\n",
       "      <td>0.937716</td>\n",
       "      <td>0.937990</td>\n",
       "      <td>0.938073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.194900</td>\n",
       "      <td>0.208730</td>\n",
       "      <td>0.934914</td>\n",
       "      <td>0.933001</td>\n",
       "      <td>0.933360</td>\n",
       "      <td>0.933486</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.197200</td>\n",
       "      <td>0.189146</td>\n",
       "      <td>0.935736</td>\n",
       "      <td>0.935800</td>\n",
       "      <td>0.935763</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.194600</td>\n",
       "      <td>0.184504</td>\n",
       "      <td>0.939961</td>\n",
       "      <td>0.938884</td>\n",
       "      <td>0.939143</td>\n",
       "      <td>0.939220</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.172300</td>\n",
       "      <td>0.182516</td>\n",
       "      <td>0.941513</td>\n",
       "      <td>0.941473</td>\n",
       "      <td>0.941492</td>\n",
       "      <td>0.941514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.197200</td>\n",
       "      <td>0.184202</td>\n",
       "      <td>0.937042</td>\n",
       "      <td>0.936800</td>\n",
       "      <td>0.936890</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.188700</td>\n",
       "      <td>0.194509</td>\n",
       "      <td>0.938728</td>\n",
       "      <td>0.937758</td>\n",
       "      <td>0.938000</td>\n",
       "      <td>0.938073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.183600</td>\n",
       "      <td>0.201765</td>\n",
       "      <td>0.936428</td>\n",
       "      <td>0.935464</td>\n",
       "      <td>0.935704</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.187500</td>\n",
       "      <td>0.194409</td>\n",
       "      <td>0.938442</td>\n",
       "      <td>0.937842</td>\n",
       "      <td>0.938018</td>\n",
       "      <td>0.938073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.179700</td>\n",
       "      <td>0.197374</td>\n",
       "      <td>0.934690</td>\n",
       "      <td>0.933043</td>\n",
       "      <td>0.933373</td>\n",
       "      <td>0.933486</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training PEFT Blocks:   2%|▏         | 25/1536 [04:18<4:22:09, 10.41s/block]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  [Rank Reduced] Mode=row: Loss 0.1974 >= 0.1862 → Rank 5 → 4\n",
      "\n",
      "[Slice 6 | Mode=column | Pos=10] | Learning Rate: 0.000200 | Rank: 4\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 334,856\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6000' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6000/8420 00:54 < 02:12, 18.31 it/s, Epoch 2/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.188900</td>\n",
       "      <td>0.188448</td>\n",
       "      <td>0.933433</td>\n",
       "      <td>0.931917</td>\n",
       "      <td>0.932230</td>\n",
       "      <td>0.932339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.179000</td>\n",
       "      <td>0.201320</td>\n",
       "      <td>0.932387</td>\n",
       "      <td>0.930749</td>\n",
       "      <td>0.931075</td>\n",
       "      <td>0.931193</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.173200</td>\n",
       "      <td>0.195008</td>\n",
       "      <td>0.931829</td>\n",
       "      <td>0.930875</td>\n",
       "      <td>0.931111</td>\n",
       "      <td>0.931193</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.191100</td>\n",
       "      <td>0.187633</td>\n",
       "      <td>0.936892</td>\n",
       "      <td>0.936926</td>\n",
       "      <td>0.936908</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.188500</td>\n",
       "      <td>0.188878</td>\n",
       "      <td>0.932536</td>\n",
       "      <td>0.932169</td>\n",
       "      <td>0.932292</td>\n",
       "      <td>0.932339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.196200</td>\n",
       "      <td>0.183860</td>\n",
       "      <td>0.937234</td>\n",
       "      <td>0.936716</td>\n",
       "      <td>0.936875</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.178800</td>\n",
       "      <td>0.193852</td>\n",
       "      <td>0.934129</td>\n",
       "      <td>0.933169</td>\n",
       "      <td>0.933407</td>\n",
       "      <td>0.933486</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.169900</td>\n",
       "      <td>0.194995</td>\n",
       "      <td>0.933847</td>\n",
       "      <td>0.933253</td>\n",
       "      <td>0.933427</td>\n",
       "      <td>0.933486</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.171000</td>\n",
       "      <td>0.198809</td>\n",
       "      <td>0.934598</td>\n",
       "      <td>0.934632</td>\n",
       "      <td>0.934614</td>\n",
       "      <td>0.934633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.168400</td>\n",
       "      <td>0.198596</td>\n",
       "      <td>0.930937</td>\n",
       "      <td>0.929664</td>\n",
       "      <td>0.929946</td>\n",
       "      <td>0.930046</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training PEFT Blocks:   2%|▏         | 29/1536 [05:14<4:48:20, 11.48s/block]Adjusted position to 0 to fit rank 3 in 2 rows\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  [Rank Reduced] Mode=column: Loss 0.1986 >= 0.1974 → Rank 4 → 3\n",
      "\n",
      "[Slice 7 | Mode=row | Pos=13] | Learning Rate: 0.000100 | Rank: 3\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 252,672\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='7000' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [7000/8420 00:47 < 01:07, 21.10 it/s, Epoch 3/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>6100</td>\n",
       "      <td>0.177400</td>\n",
       "      <td>0.201287</td>\n",
       "      <td>0.932182</td>\n",
       "      <td>0.930791</td>\n",
       "      <td>0.931088</td>\n",
       "      <td>0.931193</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6200</td>\n",
       "      <td>0.177200</td>\n",
       "      <td>0.192906</td>\n",
       "      <td>0.937128</td>\n",
       "      <td>0.936758</td>\n",
       "      <td>0.936883</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6300</td>\n",
       "      <td>0.161800</td>\n",
       "      <td>0.195279</td>\n",
       "      <td>0.935935</td>\n",
       "      <td>0.935632</td>\n",
       "      <td>0.935739</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6400</td>\n",
       "      <td>0.163500</td>\n",
       "      <td>0.200465</td>\n",
       "      <td>0.931437</td>\n",
       "      <td>0.931001</td>\n",
       "      <td>0.931140</td>\n",
       "      <td>0.931193</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6500</td>\n",
       "      <td>0.146400</td>\n",
       "      <td>0.202960</td>\n",
       "      <td>0.939173</td>\n",
       "      <td>0.939263</td>\n",
       "      <td>0.939207</td>\n",
       "      <td>0.939220</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6600</td>\n",
       "      <td>0.159800</td>\n",
       "      <td>0.200541</td>\n",
       "      <td>0.936030</td>\n",
       "      <td>0.935590</td>\n",
       "      <td>0.935731</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6700</td>\n",
       "      <td>0.160100</td>\n",
       "      <td>0.198809</td>\n",
       "      <td>0.935857</td>\n",
       "      <td>0.935674</td>\n",
       "      <td>0.935746</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6800</td>\n",
       "      <td>0.174600</td>\n",
       "      <td>0.198519</td>\n",
       "      <td>0.934832</td>\n",
       "      <td>0.934464</td>\n",
       "      <td>0.934588</td>\n",
       "      <td>0.934633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6900</td>\n",
       "      <td>0.187300</td>\n",
       "      <td>0.189704</td>\n",
       "      <td>0.935736</td>\n",
       "      <td>0.935800</td>\n",
       "      <td>0.935763</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7000</td>\n",
       "      <td>0.176200</td>\n",
       "      <td>0.194961</td>\n",
       "      <td>0.936030</td>\n",
       "      <td>0.935590</td>\n",
       "      <td>0.935731</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training PEFT Blocks:   2%|▏         | 32/1536 [06:04<5:16:57, 12.64s/block]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[Slice 8 | Mode=column | Pos=13] | Learning Rate: 0.000100 | Rank: 3\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 251,142\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='8000' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [8000/8420 00:52 < 00:22, 18.96 it/s, Epoch 3/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>7100</td>\n",
       "      <td>0.164500</td>\n",
       "      <td>0.197828</td>\n",
       "      <td>0.935201</td>\n",
       "      <td>0.934337</td>\n",
       "      <td>0.934561</td>\n",
       "      <td>0.934633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7200</td>\n",
       "      <td>0.163100</td>\n",
       "      <td>0.194187</td>\n",
       "      <td>0.935201</td>\n",
       "      <td>0.934337</td>\n",
       "      <td>0.934561</td>\n",
       "      <td>0.934633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7300</td>\n",
       "      <td>0.165800</td>\n",
       "      <td>0.192290</td>\n",
       "      <td>0.936144</td>\n",
       "      <td>0.935548</td>\n",
       "      <td>0.935723</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7400</td>\n",
       "      <td>0.166600</td>\n",
       "      <td>0.194287</td>\n",
       "      <td>0.938153</td>\n",
       "      <td>0.937968</td>\n",
       "      <td>0.938041</td>\n",
       "      <td>0.938073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7500</td>\n",
       "      <td>0.176600</td>\n",
       "      <td>0.194670</td>\n",
       "      <td>0.934937</td>\n",
       "      <td>0.934422</td>\n",
       "      <td>0.934579</td>\n",
       "      <td>0.934633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7600</td>\n",
       "      <td>0.159200</td>\n",
       "      <td>0.194748</td>\n",
       "      <td>0.937042</td>\n",
       "      <td>0.936800</td>\n",
       "      <td>0.936890</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7700</td>\n",
       "      <td>0.160900</td>\n",
       "      <td>0.194277</td>\n",
       "      <td>0.935857</td>\n",
       "      <td>0.935674</td>\n",
       "      <td>0.935746</td>\n",
       "      <td>0.935780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7800</td>\n",
       "      <td>0.177500</td>\n",
       "      <td>0.190127</td>\n",
       "      <td>0.937234</td>\n",
       "      <td>0.936716</td>\n",
       "      <td>0.936875</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7900</td>\n",
       "      <td>0.165300</td>\n",
       "      <td>0.195602</td>\n",
       "      <td>0.930937</td>\n",
       "      <td>0.929664</td>\n",
       "      <td>0.929946</td>\n",
       "      <td>0.930046</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8000</td>\n",
       "      <td>0.172200</td>\n",
       "      <td>0.187718</td>\n",
       "      <td>0.939218</td>\n",
       "      <td>0.939179</td>\n",
       "      <td>0.939197</td>\n",
       "      <td>0.939220</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training PEFT Blocks:   2%|▏         | 35/1536 [06:58<5:51:17, 14.04s/block]Adjusted position to 0 to fit rank 3 in 2 rows\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[Slice 9 | Mode=row | Pos=16] | Learning Rate: 0.000100 | Rank: 3\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 252,672\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='8420' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [8420/8420 00:21, Epoch 4/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>8100</td>\n",
       "      <td>0.175100</td>\n",
       "      <td>0.192720</td>\n",
       "      <td>0.934129</td>\n",
       "      <td>0.933169</td>\n",
       "      <td>0.933407</td>\n",
       "      <td>0.933486</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8200</td>\n",
       "      <td>0.170500</td>\n",
       "      <td>0.193987</td>\n",
       "      <td>0.937128</td>\n",
       "      <td>0.936758</td>\n",
       "      <td>0.936883</td>\n",
       "      <td>0.936927</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8300</td>\n",
       "      <td>0.160400</td>\n",
       "      <td>0.199654</td>\n",
       "      <td>0.938231</td>\n",
       "      <td>0.937926</td>\n",
       "      <td>0.938034</td>\n",
       "      <td>0.938073</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8400</td>\n",
       "      <td>0.165100</td>\n",
       "      <td>0.197724</td>\n",
       "      <td>0.934937</td>\n",
       "      <td>0.934422</td>\n",
       "      <td>0.934579</td>\n",
       "      <td>0.934633</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training PEFT Blocks:   2%|▏         | 38/1536 [07:21<5:08:04, 12.34s/block]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  [Rank Reduced] Mode=row: Loss 0.1977 >= 0.1877 → Rank 3 → 2\n",
      "\n",
      "[Slice 10 | Mode=column | Pos=16] | Learning Rate: 0.000010 | Rank: 2\n",
      "  Total Parameters: 124,647,170\n",
      "  Trainable Parameters: 167,428\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='8420' max='8420' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [8420/8420 : < :, Epoch 4/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training PEFT Blocks:   2%|▏         | 38/1536 [07:24<4:51:52, 11.69s/block]\n"
     ]
    }
   ],
   "source": [
    "from SliceTrainer import SliceTrainer\n",
    "# Define your HuggingFace TrainingArguments as usual\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=3e-4,\n",
    "    remove_unused_columns=False,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    num_train_epochs=4,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    save_strategy=\"no\",\n",
    "    logging_steps=100,\n",
    "    save_steps=100,\n",
    "    eval_steps=100,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    warmup_steps=10,\n",
    "    report_to=[],\n",
    "    fp16=True,\n",
    ")\n",
    "\n",
    "\n",
    "# Example instantiation:\n",
    "trainer = SliceTrainer(\n",
    "    model=model,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    training_args=training_args,\n",
    "    move_steps=1000,\n",
    "    rank=5,\n",
    "    learnig_rate_decay=0.0001,\n",
    "    min_learning_rate=0.00001,\n",
    "    position=0,\n",
    "    max_position=768,\n",
    "    bias=False,\n",
    "    peft_modes=(\"row\", \"column\"),        # or sjust (\"row\",), (\"column\",)\n",
    "    targets=None,                        # or pass target layer names\n",
    "    verbose=True,\n",
    "    tollerance=1,                        # Tolerance for loss rank chnages\n",
    "    rank_decay=1,                   # How much to reduce rank when loss does not improve\n",
    "    min_rank=1,                     # Minimum rank to reduce to\n",
    ")\n",
    "\n",
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cb776bd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MD",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
