{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e27807b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Downloading data: 100%|██████████| 218k/218k [00:00<00:00, 237kB/s]  \n",
      "Generating train split: 100%|██████████| 4439/4439 [00:00<00:00, 23806.99 examples/s]\n",
      "Generating validation split: 100%|██████████| 495/495 [00:00<00:00, 13273.62 examples/s]\n",
      "Generating test split: 100%|██████████| 4906/4906 [00:00<00:00, 24127.27 examples/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "#load train data\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "raw_datasets  = load_dataset('RobZamp/sick')\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c0d83c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Log in using your Hugging Face token\n",
    "login(\"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\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8d0bb7c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at openai-community/gpt2-medium and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig, AutoModelForSequenceClassification\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "# from modeling import CLMSequenceClassification\n",
    "\n",
    "\n",
    "\n",
    "#model_name = \"openai-community/gpt2-medium\"\n",
    "\n",
    "model_name = \"openai-community/gpt2-medium\"\n",
    "#config.num_labels=2\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1).to('cuda')\n",
    "model.config.pad_token_id = tokenizer.eos_token_id\n",
    "import RoCoFT\n",
    "\n",
    "RoCoFT.PEFT(model, method='row', rank=3) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "12f358b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 4439/4439 [00:00<00:00, 7659.11 examples/s] \n",
      "Map: 100%|██████████| 495/495 [00:00<00:00, 10046.16 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'labels', 'input'],\n",
      "    num_rows: 4439\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'labels', 'input'],\n",
      "    num_rows: 495\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Generates a prompt for evaluating the humor intensity of an edited headline.\n",
    "    Args:\n",
    "        data_point (dict): A dictionary containing 'original', 'edit', and 'meanGrade'.\n",
    "    Returns:\n",
    "        str: The formatted prompt as a string.\n",
    "    \"\"\"\n",
    "    return f\"\"\"# Sentence-1:: {data_point['sentence_A']}. # Sentence-2: {data_point['sentence_B']} # Output: The similarity is\"\"\"  # noqa: E501\n",
    "\n",
    "\n",
    "# Assuming `dataset` is your DatasetDict\n",
    "def add_label_column(example):\n",
    "\n",
    "    example['labels'] = float(example['relatedness_score'])\n",
    "  \n",
    "    example['input'] = generate_prompt(example)\n",
    "\n",
    "    \n",
    "    return example\n",
    "\n",
    "# Map the function over train and validation datasets\n",
    "\n",
    "train_data = raw_datasets['train'].map(add_label_column)\n",
    "val_data = raw_datasets['validation'].map(add_label_column)\n",
    "\n",
    "# Remove unnecessary columns\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3544cae9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 4439/4439 [00:00<00:00, 29501.88 examples/s]\n",
      "Map: 100%|██████████| 495/495 [00:00<00:00, 25900.13 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "\n",
    "# col_to_delete = ['idx']\n",
    "col_to_delete =  ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'input']\n",
    "\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "   \n",
    "    return tokenizer(examples['input'], truncation=True, max_length=512)\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "# llama_tokenized_datasets = llama_tokenized_datasets.rename_column(\"target\", \"label\")\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a70b43e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# Sentence-1:: A person on a black motorbike is doing tricks with a jacket. # Sentence-2: A person is riding the bicycle on one wheel # Output: The similarity is'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_train_data['input_ids'][10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f39dcdc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    # If predictions are logits or have extra dimensions, squeeze\n",
    "    if predictions.ndim > 1:\n",
    "        predictions = predictions.squeeze()\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "    \n",
    "    # Define an \"accuracy\" for regression:\n",
    "    # Example: within some threshold tolerance\n",
    "    tolerance = 0.1  # you can change this\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b64eb4e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=1e-3,\n",
    "    per_device_train_batch_size=14,\n",
    "    per_device_eval_batch_size=14,\n",
    "    gradient_accumulation_steps= 1,\n",
    "    num_train_epochs=20,\n",
    "    weight_decay=0.20,\n",
    "    eval_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_total_limit=2,\n",
    "    save_steps=10000000,\n",
    "    logging_steps=100,\n",
    "   \n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5b333893",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6360' max='6360' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6360/6360 20:55, Epoch 20/20]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>8.256000</td>\n",
       "      <td>1.018912</td>\n",
       "      <td>0.798324</td>\n",
       "      <td>1.018912</td>\n",
       "      <td>1.009412</td>\n",
       "      <td>0.066667</td>\n",
       "      <td>-0.005257</td>\n",
       "      <td>0.074711</td>\n",
       "      <td>0.041039</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>1.254700</td>\n",
       "      <td>1.176675</td>\n",
       "      <td>0.886709</td>\n",
       "      <td>1.176675</td>\n",
       "      <td>1.084746</td>\n",
       "      <td>0.064646</td>\n",
       "      <td>-0.160905</td>\n",
       "      <td>-0.064240</td>\n",
       "      <td>-0.047145</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>1.216900</td>\n",
       "      <td>1.054942</td>\n",
       "      <td>0.799963</td>\n",
       "      <td>1.054942</td>\n",
       "      <td>1.027104</td>\n",
       "      <td>0.088889</td>\n",
       "      <td>-0.040804</td>\n",
       "      <td>0.087913</td>\n",
       "      <td>0.124668</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>1.247700</td>\n",
       "      <td>1.013406</td>\n",
       "      <td>0.803672</td>\n",
       "      <td>1.013406</td>\n",
       "      <td>1.006680</td>\n",
       "      <td>0.066667</td>\n",
       "      <td>0.000176</td>\n",
       "      <td>0.132851</td>\n",
       "      <td>0.171994</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>1.218700</td>\n",
       "      <td>1.107566</td>\n",
       "      <td>0.859773</td>\n",
       "      <td>1.107565</td>\n",
       "      <td>1.052409</td>\n",
       "      <td>0.082828</td>\n",
       "      <td>-0.092722</td>\n",
       "      <td>0.150441</td>\n",
       "      <td>0.157975</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>1.108500</td>\n",
       "      <td>1.094041</td>\n",
       "      <td>0.808960</td>\n",
       "      <td>1.094041</td>\n",
       "      <td>1.045964</td>\n",
       "      <td>0.070707</td>\n",
       "      <td>-0.079379</td>\n",
       "      <td>0.164030</td>\n",
       "      <td>0.163338</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>1.195800</td>\n",
       "      <td>1.154880</td>\n",
       "      <td>0.887710</td>\n",
       "      <td>1.154880</td>\n",
       "      <td>1.074653</td>\n",
       "      <td>0.084848</td>\n",
       "      <td>-0.139402</td>\n",
       "      <td>0.162937</td>\n",
       "      <td>0.185053</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>1.099600</td>\n",
       "      <td>1.046936</td>\n",
       "      <td>0.833525</td>\n",
       "      <td>1.046936</td>\n",
       "      <td>1.023199</td>\n",
       "      <td>0.070707</td>\n",
       "      <td>-0.032905</td>\n",
       "      <td>0.058076</td>\n",
       "      <td>0.199153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>1.059600</td>\n",
       "      <td>1.057182</td>\n",
       "      <td>0.835496</td>\n",
       "      <td>1.057182</td>\n",
       "      <td>1.028193</td>\n",
       "      <td>0.082828</td>\n",
       "      <td>-0.043013</td>\n",
       "      <td>0.152284</td>\n",
       "      <td>0.181459</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>1.076600</td>\n",
       "      <td>1.072828</td>\n",
       "      <td>0.796368</td>\n",
       "      <td>1.072827</td>\n",
       "      <td>1.035774</td>\n",
       "      <td>0.082828</td>\n",
       "      <td>-0.058449</td>\n",
       "      <td>0.130355</td>\n",
       "      <td>0.196777</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>1.127700</td>\n",
       "      <td>1.095142</td>\n",
       "      <td>0.858003</td>\n",
       "      <td>1.095142</td>\n",
       "      <td>1.046490</td>\n",
       "      <td>0.088889</td>\n",
       "      <td>-0.080465</td>\n",
       "      <td>0.124953</td>\n",
       "      <td>0.178322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>1.138900</td>\n",
       "      <td>1.033697</td>\n",
       "      <td>0.804360</td>\n",
       "      <td>1.033696</td>\n",
       "      <td>1.016709</td>\n",
       "      <td>0.113131</td>\n",
       "      <td>-0.019843</td>\n",
       "      <td>0.145370</td>\n",
       "      <td>0.143787</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>1.102600</td>\n",
       "      <td>1.212904</td>\n",
       "      <td>0.920220</td>\n",
       "      <td>1.212904</td>\n",
       "      <td>1.101319</td>\n",
       "      <td>0.088889</td>\n",
       "      <td>-0.196649</td>\n",
       "      <td>0.169092</td>\n",
       "      <td>0.197974</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>1.077000</td>\n",
       "      <td>1.057742</td>\n",
       "      <td>0.786387</td>\n",
       "      <td>1.057742</td>\n",
       "      <td>1.028466</td>\n",
       "      <td>0.078788</td>\n",
       "      <td>-0.043566</td>\n",
       "      <td>0.148057</td>\n",
       "      <td>0.203166</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>1.203800</td>\n",
       "      <td>1.000870</td>\n",
       "      <td>0.796359</td>\n",
       "      <td>1.000870</td>\n",
       "      <td>1.000435</td>\n",
       "      <td>0.058586</td>\n",
       "      <td>0.012544</td>\n",
       "      <td>0.119730</td>\n",
       "      <td>0.192606</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>1.083400</td>\n",
       "      <td>1.585936</td>\n",
       "      <td>0.931523</td>\n",
       "      <td>1.585936</td>\n",
       "      <td>1.259339</td>\n",
       "      <td>0.078788</td>\n",
       "      <td>-0.564681</td>\n",
       "      <td>0.181462</td>\n",
       "      <td>0.223538</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>2.637400</td>\n",
       "      <td>1.047042</td>\n",
       "      <td>0.839272</td>\n",
       "      <td>1.047042</td>\n",
       "      <td>1.023251</td>\n",
       "      <td>0.090909</td>\n",
       "      <td>-0.033009</td>\n",
       "      <td>0.210457</td>\n",
       "      <td>0.233595</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>1.039700</td>\n",
       "      <td>1.026409</td>\n",
       "      <td>0.791695</td>\n",
       "      <td>1.026409</td>\n",
       "      <td>1.013119</td>\n",
       "      <td>0.074747</td>\n",
       "      <td>-0.012653</td>\n",
       "      <td>0.180180</td>\n",
       "      <td>0.274676</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>1.061700</td>\n",
       "      <td>0.958416</td>\n",
       "      <td>0.774150</td>\n",
       "      <td>0.958416</td>\n",
       "      <td>0.978987</td>\n",
       "      <td>0.076768</td>\n",
       "      <td>0.054429</td>\n",
       "      <td>0.245204</td>\n",
       "      <td>0.252894</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>1.063800</td>\n",
       "      <td>0.962673</td>\n",
       "      <td>0.782853</td>\n",
       "      <td>0.962673</td>\n",
       "      <td>0.981159</td>\n",
       "      <td>0.084848</td>\n",
       "      <td>0.050229</td>\n",
       "      <td>0.270764</td>\n",
       "      <td>0.316711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>1.010300</td>\n",
       "      <td>1.080391</td>\n",
       "      <td>0.781745</td>\n",
       "      <td>1.080391</td>\n",
       "      <td>1.039418</td>\n",
       "      <td>0.076768</td>\n",
       "      <td>-0.065911</td>\n",
       "      <td>0.224155</td>\n",
       "      <td>0.320912</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>1.070400</td>\n",
       "      <td>1.031383</td>\n",
       "      <td>0.768074</td>\n",
       "      <td>1.031382</td>\n",
       "      <td>1.015570</td>\n",
       "      <td>0.072727</td>\n",
       "      <td>-0.017560</td>\n",
       "      <td>0.252834</td>\n",
       "      <td>0.289858</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.959300</td>\n",
       "      <td>0.992157</td>\n",
       "      <td>0.805596</td>\n",
       "      <td>0.992157</td>\n",
       "      <td>0.996071</td>\n",
       "      <td>0.088889</td>\n",
       "      <td>0.021140</td>\n",
       "      <td>0.254132</td>\n",
       "      <td>0.313293</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>1.022800</td>\n",
       "      <td>0.937261</td>\n",
       "      <td>0.760590</td>\n",
       "      <td>0.937261</td>\n",
       "      <td>0.968123</td>\n",
       "      <td>0.060606</td>\n",
       "      <td>0.075300</td>\n",
       "      <td>0.279627</td>\n",
       "      <td>0.334475</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.990400</td>\n",
       "      <td>0.914661</td>\n",
       "      <td>0.759638</td>\n",
       "      <td>0.914661</td>\n",
       "      <td>0.956379</td>\n",
       "      <td>0.080808</td>\n",
       "      <td>0.097597</td>\n",
       "      <td>0.345370</td>\n",
       "      <td>0.374416</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.945600</td>\n",
       "      <td>1.019007</td>\n",
       "      <td>0.751635</td>\n",
       "      <td>1.019007</td>\n",
       "      <td>1.009459</td>\n",
       "      <td>0.074747</td>\n",
       "      <td>-0.005351</td>\n",
       "      <td>0.355943</td>\n",
       "      <td>0.384212</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>1.082000</td>\n",
       "      <td>0.876809</td>\n",
       "      <td>0.732746</td>\n",
       "      <td>0.876809</td>\n",
       "      <td>0.936381</td>\n",
       "      <td>0.074747</td>\n",
       "      <td>0.134942</td>\n",
       "      <td>0.368317</td>\n",
       "      <td>0.387265</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.901400</td>\n",
       "      <td>0.849300</td>\n",
       "      <td>0.716431</td>\n",
       "      <td>0.849300</td>\n",
       "      <td>0.921575</td>\n",
       "      <td>0.105051</td>\n",
       "      <td>0.162082</td>\n",
       "      <td>0.421326</td>\n",
       "      <td>0.406692</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.896600</td>\n",
       "      <td>0.944457</td>\n",
       "      <td>0.733090</td>\n",
       "      <td>0.944457</td>\n",
       "      <td>0.971832</td>\n",
       "      <td>0.096970</td>\n",
       "      <td>0.068201</td>\n",
       "      <td>0.407505</td>\n",
       "      <td>0.412997</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.948600</td>\n",
       "      <td>0.831796</td>\n",
       "      <td>0.702052</td>\n",
       "      <td>0.831796</td>\n",
       "      <td>0.912028</td>\n",
       "      <td>0.096970</td>\n",
       "      <td>0.179352</td>\n",
       "      <td>0.428510</td>\n",
       "      <td>0.451336</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.852000</td>\n",
       "      <td>0.845330</td>\n",
       "      <td>0.698162</td>\n",
       "      <td>0.845330</td>\n",
       "      <td>0.919418</td>\n",
       "      <td>0.101010</td>\n",
       "      <td>0.166000</td>\n",
       "      <td>0.436905</td>\n",
       "      <td>0.449715</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.860300</td>\n",
       "      <td>0.882085</td>\n",
       "      <td>0.700576</td>\n",
       "      <td>0.882085</td>\n",
       "      <td>0.939194</td>\n",
       "      <td>0.105051</td>\n",
       "      <td>0.129737</td>\n",
       "      <td>0.450442</td>\n",
       "      <td>0.455684</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.809100</td>\n",
       "      <td>0.819708</td>\n",
       "      <td>0.709688</td>\n",
       "      <td>0.819708</td>\n",
       "      <td>0.905377</td>\n",
       "      <td>0.086869</td>\n",
       "      <td>0.191277</td>\n",
       "      <td>0.474116</td>\n",
       "      <td>0.460701</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.776700</td>\n",
       "      <td>0.809152</td>\n",
       "      <td>0.725802</td>\n",
       "      <td>0.809152</td>\n",
       "      <td>0.899529</td>\n",
       "      <td>0.082828</td>\n",
       "      <td>0.201692</td>\n",
       "      <td>0.478443</td>\n",
       "      <td>0.456227</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.747300</td>\n",
       "      <td>0.803932</td>\n",
       "      <td>0.723117</td>\n",
       "      <td>0.803932</td>\n",
       "      <td>0.896623</td>\n",
       "      <td>0.076768</td>\n",
       "      <td>0.206842</td>\n",
       "      <td>0.498969</td>\n",
       "      <td>0.470895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.701000</td>\n",
       "      <td>0.864580</td>\n",
       "      <td>0.701229</td>\n",
       "      <td>0.864580</td>\n",
       "      <td>0.929828</td>\n",
       "      <td>0.092929</td>\n",
       "      <td>0.147007</td>\n",
       "      <td>0.473879</td>\n",
       "      <td>0.465427</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.681100</td>\n",
       "      <td>0.846035</td>\n",
       "      <td>0.712661</td>\n",
       "      <td>0.846035</td>\n",
       "      <td>0.919802</td>\n",
       "      <td>0.101010</td>\n",
       "      <td>0.165303</td>\n",
       "      <td>0.511250</td>\n",
       "      <td>0.493092</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.665500</td>\n",
       "      <td>0.738454</td>\n",
       "      <td>0.655466</td>\n",
       "      <td>0.738454</td>\n",
       "      <td>0.859334</td>\n",
       "      <td>0.074747</td>\n",
       "      <td>0.271442</td>\n",
       "      <td>0.546157</td>\n",
       "      <td>0.515368</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.662000</td>\n",
       "      <td>0.734131</td>\n",
       "      <td>0.654790</td>\n",
       "      <td>0.734131</td>\n",
       "      <td>0.856814</td>\n",
       "      <td>0.119192</td>\n",
       "      <td>0.275708</td>\n",
       "      <td>0.531321</td>\n",
       "      <td>0.498739</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.613400</td>\n",
       "      <td>0.729969</td>\n",
       "      <td>0.641150</td>\n",
       "      <td>0.729969</td>\n",
       "      <td>0.854382</td>\n",
       "      <td>0.096970</td>\n",
       "      <td>0.279814</td>\n",
       "      <td>0.541580</td>\n",
       "      <td>0.506790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.623600</td>\n",
       "      <td>0.772686</td>\n",
       "      <td>0.662827</td>\n",
       "      <td>0.772686</td>\n",
       "      <td>0.879026</td>\n",
       "      <td>0.105051</td>\n",
       "      <td>0.237669</td>\n",
       "      <td>0.538107</td>\n",
       "      <td>0.506382</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.616300</td>\n",
       "      <td>0.758543</td>\n",
       "      <td>0.653704</td>\n",
       "      <td>0.758543</td>\n",
       "      <td>0.870944</td>\n",
       "      <td>0.109091</td>\n",
       "      <td>0.251622</td>\n",
       "      <td>0.551117</td>\n",
       "      <td>0.504970</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.566900</td>\n",
       "      <td>0.729422</td>\n",
       "      <td>0.637432</td>\n",
       "      <td>0.729422</td>\n",
       "      <td>0.854062</td>\n",
       "      <td>0.115152</td>\n",
       "      <td>0.280353</td>\n",
       "      <td>0.575441</td>\n",
       "      <td>0.534004</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.618700</td>\n",
       "      <td>0.721346</td>\n",
       "      <td>0.637192</td>\n",
       "      <td>0.721346</td>\n",
       "      <td>0.849321</td>\n",
       "      <td>0.111111</td>\n",
       "      <td>0.288321</td>\n",
       "      <td>0.567992</td>\n",
       "      <td>0.538577</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.530500</td>\n",
       "      <td>0.761122</td>\n",
       "      <td>0.648208</td>\n",
       "      <td>0.761122</td>\n",
       "      <td>0.872423</td>\n",
       "      <td>0.111111</td>\n",
       "      <td>0.249079</td>\n",
       "      <td>0.556485</td>\n",
       "      <td>0.509667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.545100</td>\n",
       "      <td>0.710441</td>\n",
       "      <td>0.637389</td>\n",
       "      <td>0.710441</td>\n",
       "      <td>0.842877</td>\n",
       "      <td>0.105051</td>\n",
       "      <td>0.299080</td>\n",
       "      <td>0.564970</td>\n",
       "      <td>0.518566</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.529700</td>\n",
       "      <td>0.718376</td>\n",
       "      <td>0.642403</td>\n",
       "      <td>0.718376</td>\n",
       "      <td>0.847570</td>\n",
       "      <td>0.098990</td>\n",
       "      <td>0.291252</td>\n",
       "      <td>0.564968</td>\n",
       "      <td>0.520691</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.519000</td>\n",
       "      <td>0.749544</td>\n",
       "      <td>0.647484</td>\n",
       "      <td>0.749544</td>\n",
       "      <td>0.865762</td>\n",
       "      <td>0.098990</td>\n",
       "      <td>0.260501</td>\n",
       "      <td>0.576513</td>\n",
       "      <td>0.533361</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.501400</td>\n",
       "      <td>0.669383</td>\n",
       "      <td>0.624203</td>\n",
       "      <td>0.669383</td>\n",
       "      <td>0.818158</td>\n",
       "      <td>0.098990</td>\n",
       "      <td>0.339588</td>\n",
       "      <td>0.584103</td>\n",
       "      <td>0.540852</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.535300</td>\n",
       "      <td>0.696210</td>\n",
       "      <td>0.631282</td>\n",
       "      <td>0.696210</td>\n",
       "      <td>0.834392</td>\n",
       "      <td>0.107071</td>\n",
       "      <td>0.313120</td>\n",
       "      <td>0.573926</td>\n",
       "      <td>0.528437</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.475700</td>\n",
       "      <td>0.677368</td>\n",
       "      <td>0.637380</td>\n",
       "      <td>0.677368</td>\n",
       "      <td>0.823024</td>\n",
       "      <td>0.092929</td>\n",
       "      <td>0.331710</td>\n",
       "      <td>0.579759</td>\n",
       "      <td>0.545087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.510600</td>\n",
       "      <td>0.662792</td>\n",
       "      <td>0.628231</td>\n",
       "      <td>0.662792</td>\n",
       "      <td>0.814120</td>\n",
       "      <td>0.096970</td>\n",
       "      <td>0.346091</td>\n",
       "      <td>0.588827</td>\n",
       "      <td>0.537628</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.508600</td>\n",
       "      <td>0.668452</td>\n",
       "      <td>0.621163</td>\n",
       "      <td>0.668452</td>\n",
       "      <td>0.817589</td>\n",
       "      <td>0.101010</td>\n",
       "      <td>0.340507</td>\n",
       "      <td>0.594203</td>\n",
       "      <td>0.538698</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.460100</td>\n",
       "      <td>0.675575</td>\n",
       "      <td>0.624443</td>\n",
       "      <td>0.675575</td>\n",
       "      <td>0.821934</td>\n",
       "      <td>0.105051</td>\n",
       "      <td>0.333479</td>\n",
       "      <td>0.589999</td>\n",
       "      <td>0.537421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.486800</td>\n",
       "      <td>0.677409</td>\n",
       "      <td>0.621528</td>\n",
       "      <td>0.677409</td>\n",
       "      <td>0.823048</td>\n",
       "      <td>0.109091</td>\n",
       "      <td>0.331670</td>\n",
       "      <td>0.588511</td>\n",
       "      <td>0.538836</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.440700</td>\n",
       "      <td>0.672253</td>\n",
       "      <td>0.621203</td>\n",
       "      <td>0.672253</td>\n",
       "      <td>0.819910</td>\n",
       "      <td>0.111111</td>\n",
       "      <td>0.336757</td>\n",
       "      <td>0.591303</td>\n",
       "      <td>0.539053</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.474800</td>\n",
       "      <td>0.686742</td>\n",
       "      <td>0.622936</td>\n",
       "      <td>0.686742</td>\n",
       "      <td>0.828699</td>\n",
       "      <td>0.107071</td>\n",
       "      <td>0.322461</td>\n",
       "      <td>0.591759</td>\n",
       "      <td>0.537778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.469800</td>\n",
       "      <td>0.671171</td>\n",
       "      <td>0.620345</td>\n",
       "      <td>0.671171</td>\n",
       "      <td>0.819250</td>\n",
       "      <td>0.111111</td>\n",
       "      <td>0.337824</td>\n",
       "      <td>0.593760</td>\n",
       "      <td>0.535539</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.450000</td>\n",
       "      <td>0.669205</td>\n",
       "      <td>0.619685</td>\n",
       "      <td>0.669205</td>\n",
       "      <td>0.818050</td>\n",
       "      <td>0.101010</td>\n",
       "      <td>0.339763</td>\n",
       "      <td>0.594280</td>\n",
       "      <td>0.538263</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.448000</td>\n",
       "      <td>0.668664</td>\n",
       "      <td>0.619633</td>\n",
       "      <td>0.668664</td>\n",
       "      <td>0.817718</td>\n",
       "      <td>0.103030</td>\n",
       "      <td>0.340298</td>\n",
       "      <td>0.592834</td>\n",
       "      <td>0.540721</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6100</td>\n",
       "      <td>0.406400</td>\n",
       "      <td>0.670567</td>\n",
       "      <td>0.619900</td>\n",
       "      <td>0.670567</td>\n",
       "      <td>0.818881</td>\n",
       "      <td>0.098990</td>\n",
       "      <td>0.338420</td>\n",
       "      <td>0.593474</td>\n",
       "      <td>0.539572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6200</td>\n",
       "      <td>0.494200</td>\n",
       "      <td>0.670147</td>\n",
       "      <td>0.619857</td>\n",
       "      <td>0.670147</td>\n",
       "      <td>0.818625</td>\n",
       "      <td>0.096970</td>\n",
       "      <td>0.338835</td>\n",
       "      <td>0.593874</td>\n",
       "      <td>0.539574</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6300</td>\n",
       "      <td>0.440500</td>\n",
       "      <td>0.669554</td>\n",
       "      <td>0.619812</td>\n",
       "      <td>0.669554</td>\n",
       "      <td>0.818263</td>\n",
       "      <td>0.103030</td>\n",
       "      <td>0.339419</td>\n",
       "      <td>0.593947</td>\n",
       "      <td>0.539533</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=6360, training_loss=0.9494811025055699, metrics={'train_runtime': 1256.1459, 'train_samples_per_second': 70.677, 'train_steps_per_second': 5.063, 'total_flos': 8556453211226112.0, 'train_loss': 0.9494811025055699, 'epoch': 20.0})"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d3bbea",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
