{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e27807b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity'],\n",
      "    num_rows: 2470\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filter: 100%|██████████| 2470/2470 [00:00<00:00, 42191.92 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity'],\n",
      "    num_rows: 2466\n",
      "})\n",
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity'],\n",
      "    num_rows: 1972\n",
      "})\n",
      "Test Dataset: Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity'],\n",
      "    num_rows: 494\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "#load train data\n",
    "import pandas as pd\n",
    "cols = ['id', 'text', 'label', 'intensity']\n",
    "path = \"https://raw.githubusercontent.com/vinayakumarr/WASSA-2017/refs/heads/master/wassa/data/training/\"\n",
    "\n",
    "\n",
    "def _load_emotion_frame(filename):\n",
    "    \"\"\"Download one tab-separated ratings file from `path` and parse it into a DataFrame.\n",
    "\n",
    "    Args:\n",
    "        filename (str): File name appended to the module-level `path` URL prefix.\n",
    "\n",
    "    Returns:\n",
    "        pandas.DataFrame: Rows indexed by 'id' with the remaining `cols` as columns.\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: If the server replies with an error status. Without this\n",
    "            check the body of an error page would be parsed as if it were data.\n",
    "    \"\"\"\n",
    "    response = requests.get(path + filename, timeout=30)\n",
    "    response.raise_for_status()\n",
    "    return pd.read_csv(StringIO(response.text), header=None, sep='\\t', names=cols, index_col=0)\n",
    "\n",
    "\n",
    "anger_train = _load_emotion_frame('anger-ratings-0to1.train.txt')\n",
    "# NOTE(review): this is the only file name without a '.txt' suffix -- confirm it matches the repository.\n",
    "fear_train = _load_emotion_frame('fear-ratings-0to1.train')\n",
    "sad_train = _load_emotion_frame('sadness-ratings-0to1.train.txt')\n",
    "joy_train = _load_emotion_frame('joy-ratings-0to1.train.txt')\n",
    "\n",
    "# Stack the four emotion frames and turn the 'id' index back into a regular column.\n",
    "combined_frame = pd.concat([anger_train, fear_train, sad_train, joy_train], axis=0).reset_index()\n",
    "\n",
    "from datasets import Dataset\n",
    "import pandas as pd\n",
    "\n",
    "# Convert to a Hugging Face Dataset and shuffle it with a fixed seed.\n",
    "dataset = Dataset.from_pandas(combined_frame).shuffle(seed=42)\n",
    "\n",
    "# Inspect the dataset\n",
    "print(dataset)\n",
    "\n",
    "def is_valid_intensity(example):\n",
    "    \"\"\"Tell whether example['intensity'] holds a usable numeric score.\n",
    "\n",
    "    Args:\n",
    "        example (dict): A dataset row with an 'intensity' entry.\n",
    "\n",
    "    Returns:\n",
    "        bool: True if the value is not None and converts to a finite float;\n",
    "        False for None, values float() cannot parse, NaN and infinities.\n",
    "    \"\"\"\n",
    "    import math  # function-scope import keeps this helper self-contained\n",
    "\n",
    "    value = example['intensity']\n",
    "    if value is None:\n",
    "        return False\n",
    "    try:\n",
    "        # Only conversion failures mean \"invalid\"; a bare except would also\n",
    "        # swallow KeyboardInterrupt and unrelated programming errors.\n",
    "        return math.isfinite(float(value))\n",
    "    except (TypeError, ValueError, OverflowError):\n",
    "        return False\n",
    "\n",
    "# Keep only the rows that pass is_valid_intensity\n",
    "dataset = dataset.filter(is_valid_intensity)\n",
    "print(dataset)\n",
    "\n",
    "# Hold out 20% of the rows, using a fixed seed\n",
    "train_test_split = dataset.train_test_split(test_size=0.2, seed=42)\n",
    "train_data, val_data = train_test_split['train'], train_test_split['test']\n",
    "\n",
    "# Show both partitions\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Test Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8d0bb7c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`\n",
      "Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at HuggingFaceTB/SmolLM2-135M and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig, AutoModelForSequenceClassification\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "# from modeling import CLMSequenceClassification\n",
    "\n",
    "\n",
    "# Checkpoint name shared by the tokenizer and the model\n",
    "model_name = \"HuggingFaceTB/SmolLM2-135M\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# Reuse the EOS token for padding and pad on the left\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "# Use the GPU when one is available; a hard-coded 'cuda' fails on CPU-only machines.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# num_labels=1 requests a single-output classification head on top of the checkpoint.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1).to(device)\n",
    "# Padding reuses the EOS token, so the model config must point at the same id.\n",
    "model.config.pad_token_id = tokenizer.eos_token_id\n",
    "import RoCoFT\n",
    "\n",
    "# Apply RoCoFT's PEFT wrapper (method='row', rank=3); see the RoCoFT package for its semantics.\n",
    "RoCoFT.PEFT(model, method='row', rank=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "12f358b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 1972/1972 [00:00<00:00, 10853.28 examples/s]\n",
      "Map: 100%|██████████| 494/494 [00:00<00:00, 8876.00 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity', 'labels', 'input'],\n",
      "    num_rows: 1972\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity', 'labels', 'input'],\n",
      "    num_rows: 494\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Builds the text prompt for one emotion-intensity example.\n",
    "    Args:\n",
    "        data_point (dict): A row providing 'text' (the input text) and 'label' (the emotion name).\n",
    "    Returns:\n",
    "        str: The prompt '# Input: <text> # Label: <label> # Output: The intensity is'.\n",
    "    \"\"\"\n",
    "    return f\"\"\"# Input: {data_point['text']} # Label: {data_point['label']} # Output: The intensity is\"\"\"  # noqa: E501\n",
    "\n",
    "\n",
    "# Assuming `dataset` is your DatasetDict\n",
    "def add_label_column(example):\n",
    "    \"\"\"Attach the numeric target and the prompt text to one dataset row.\n",
    "\n",
    "    Stores float(example['intensity']) under 'labels' and the result of\n",
    "    generate_prompt(example) under 'input', then returns the same row.\n",
    "    \"\"\"\n",
    "    target = float(example['intensity'])\n",
    "    example['labels'] = target\n",
    "    example['input'] = generate_prompt(example)\n",
    "    return example\n",
    "\n",
    "# Add the 'labels' and 'input' columns to both splits\n",
    "train_data, val_data = train_data.map(add_label_column), val_data.map(add_label_column)\n",
    "\n",
    "# Show the resulting schemas\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3544cae9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 1972/1972 [00:00<00:00, 22284.76 examples/s]\n",
      "Map: 100%|██████████| 494/494 [00:00<00:00, 21095.14 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "# Pad on the left side of each sequence\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "\n",
    "# Raw columns passed as remove_columns when the splits are tokenized below\n",
    "col_to_delete = ['label', 'intensity','id', 'text']  # Update as per your dataset\n",
    "\n",
    "\n",
    "# Mask token as reported by the tokenizer\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "    \"\"\"Tokenize the 'input' prompts of a batch, truncating to at most 512 tokens.\"\"\"\n",
    "    prompts = examples['input']\n",
    "    encoded = tokenizer(prompts, truncation=True, max_length=512)\n",
    "    return encoded\n",
    "\n",
    "# Tokenize both splits in batches, discarding the raw columns listed in col_to_delete\n",
    "tokenized_train_data = train_data.map(\n",
    "    preprocessing_function,\n",
    "    batched=True,\n",
    "    remove_columns=col_to_delete,\n",
    ")\n",
    "tokenized_val_data = val_data.map(\n",
    "    preprocessing_function,\n",
    "    batched=True,\n",
    "    remove_columns=col_to_delete,\n",
    ")\n",
    "\n",
    "# Switch both tokenized splits to the \"torch\" output format\n",
    "for split in (tokenized_train_data, tokenized_val_data):\n",
    "    split.set_format(\"torch\")\n",
    "\n",
    "# Collator that pads each batch to the longest sequence it contains\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a70b43e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# Input: @TehShockwave turn that grumpy frown upside-down\\\\n\\\\nYou did something next to impossible today # Label: sadness # Output: The intensity is'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Select row 10 first so only that example is read, instead of the whole 'input_ids' column.\n",
    "tokenizer.decode(tokenized_train_data[10]['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f39dcdc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred, tolerance=0.1):\n",
    "    \"\"\"Compute regression metrics from a (predictions, labels) pair.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: Pair unpacked as (predictions, labels); predictions may carry\n",
    "            extra singleton dimensions, e.g. shape (N, 1).\n",
    "        tolerance (float): Absolute-error threshold below which a prediction\n",
    "            counts as correct for the 'Accuracy' entry. Defaults to 0.1.\n",
    "\n",
    "    Returns:\n",
    "        dict: MAE, MSE, RMSE, Accuracy, R2, Pearson and Spearman's Rank.\n",
    "    \"\"\"\n",
    "    predictions, labels = eval_pred\n",
    "\n",
    "    # Flatten both arrays to 1-D. Unlike squeeze(), reshape(-1) keeps a single-example\n",
    "    # batch as shape (1,) rather than a 0-d scalar, and it prevents an (N, 1) array and\n",
    "    # an (N,) array from broadcasting to (N, N) in the subtraction below.\n",
    "    predictions = np.asarray(predictions).reshape(-1)\n",
    "    labels = np.asarray(labels).reshape(-1)\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "\n",
    "    # \"Accuracy\" for regression: share of predictions within `tolerance` of the label.\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    if predictions.size < 2:\n",
    "        # Correlation coefficients are undefined for fewer than two samples.\n",
    "        pearson_corr = spearman_corr = float('nan')\n",
    "    else:\n",
    "        pearson_corr, _ = pearsonr(predictions, labels)\n",
    "        spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b64eb4e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=5e-4,\n",
    "    per_device_train_batch_size=8,\n",
    "    per_device_eval_batch_size=8,\n",
    "    gradient_accumulation_steps=1,\n",
    "    num_train_epochs=10,\n",
    "    weight_decay=0.20,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=100,\n",
    "    save_strategy=\"steps\",\n",
    "    # Checkpoint on every evaluation step. load_best_model_at_end can only restore a\n",
    "    # checkpoint that was actually written; a save interval longer than the whole run\n",
    "    # means nothing is saved at the evaluated steps and no best model can be reloaded.\n",
    "    save_steps=100,\n",
    "    save_total_limit=2,  # bounds how many checkpoints are kept on disk\n",
    "    logging_steps=100,\n",
    "    load_best_model_at_end=True,\n",
    "    # Select the best checkpoint by validation loss (lower is better).\n",
    "    metric_for_best_model=\"eval_loss\",\n",
    "    greater_is_better=False,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5b333893",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='2470' max='2470' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [2470/2470 06:18, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.268900</td>\n",
       "      <td>0.070669</td>\n",
       "      <td>0.215945</td>\n",
       "      <td>0.070669</td>\n",
       "      <td>0.265837</td>\n",
       "      <td>0.271255</td>\n",
       "      <td>-0.984989</td>\n",
       "      <td>0.199836</td>\n",
       "      <td>0.198185</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.057200</td>\n",
       "      <td>0.075093</td>\n",
       "      <td>0.229594</td>\n",
       "      <td>0.075093</td>\n",
       "      <td>0.274032</td>\n",
       "      <td>0.226721</td>\n",
       "      <td>-1.109252</td>\n",
       "      <td>0.288488</td>\n",
       "      <td>0.288871</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.040600</td>\n",
       "      <td>0.059492</td>\n",
       "      <td>0.203394</td>\n",
       "      <td>0.059492</td>\n",
       "      <td>0.243911</td>\n",
       "      <td>0.273279</td>\n",
       "      <td>-0.671048</td>\n",
       "      <td>0.386607</td>\n",
       "      <td>0.388792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.040000</td>\n",
       "      <td>0.030909</td>\n",
       "      <td>0.143327</td>\n",
       "      <td>0.030909</td>\n",
       "      <td>0.175809</td>\n",
       "      <td>0.417004</td>\n",
       "      <td>0.131817</td>\n",
       "      <td>0.399279</td>\n",
       "      <td>0.373741</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.037500</td>\n",
       "      <td>0.045106</td>\n",
       "      <td>0.170517</td>\n",
       "      <td>0.045106</td>\n",
       "      <td>0.212382</td>\n",
       "      <td>0.354251</td>\n",
       "      <td>-0.266958</td>\n",
       "      <td>0.492493</td>\n",
       "      <td>0.477385</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.036200</td>\n",
       "      <td>0.027479</td>\n",
       "      <td>0.132382</td>\n",
       "      <td>0.027479</td>\n",
       "      <td>0.165769</td>\n",
       "      <td>0.451417</td>\n",
       "      <td>0.228149</td>\n",
       "      <td>0.535571</td>\n",
       "      <td>0.507853</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.031200</td>\n",
       "      <td>0.046790</td>\n",
       "      <td>0.180173</td>\n",
       "      <td>0.046790</td>\n",
       "      <td>0.216309</td>\n",
       "      <td>0.307692</td>\n",
       "      <td>-0.314246</td>\n",
       "      <td>0.560398</td>\n",
       "      <td>0.554342</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.035400</td>\n",
       "      <td>0.021108</td>\n",
       "      <td>0.116782</td>\n",
       "      <td>0.021108</td>\n",
       "      <td>0.145287</td>\n",
       "      <td>0.489879</td>\n",
       "      <td>0.407100</td>\n",
       "      <td>0.644532</td>\n",
       "      <td>0.637611</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.019300</td>\n",
       "      <td>0.020517</td>\n",
       "      <td>0.115064</td>\n",
       "      <td>0.020517</td>\n",
       "      <td>0.143239</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.423700</td>\n",
       "      <td>0.654396</td>\n",
       "      <td>0.644597</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.019800</td>\n",
       "      <td>0.021826</td>\n",
       "      <td>0.120500</td>\n",
       "      <td>0.021826</td>\n",
       "      <td>0.147737</td>\n",
       "      <td>0.461538</td>\n",
       "      <td>0.386940</td>\n",
       "      <td>0.671096</td>\n",
       "      <td>0.649337</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.015000</td>\n",
       "      <td>0.021753</td>\n",
       "      <td>0.120849</td>\n",
       "      <td>0.021753</td>\n",
       "      <td>0.147489</td>\n",
       "      <td>0.467611</td>\n",
       "      <td>0.388990</td>\n",
       "      <td>0.679946</td>\n",
       "      <td>0.679030</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.016500</td>\n",
       "      <td>0.035466</td>\n",
       "      <td>0.152992</td>\n",
       "      <td>0.035466</td>\n",
       "      <td>0.188325</td>\n",
       "      <td>0.378543</td>\n",
       "      <td>0.003808</td>\n",
       "      <td>0.680246</td>\n",
       "      <td>0.665702</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.014000</td>\n",
       "      <td>0.020585</td>\n",
       "      <td>0.112627</td>\n",
       "      <td>0.020585</td>\n",
       "      <td>0.143474</td>\n",
       "      <td>0.528340</td>\n",
       "      <td>0.421808</td>\n",
       "      <td>0.698358</td>\n",
       "      <td>0.688715</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.012500</td>\n",
       "      <td>0.019710</td>\n",
       "      <td>0.114166</td>\n",
       "      <td>0.019710</td>\n",
       "      <td>0.140393</td>\n",
       "      <td>0.495951</td>\n",
       "      <td>0.446369</td>\n",
       "      <td>0.699895</td>\n",
       "      <td>0.692316</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.010900</td>\n",
       "      <td>0.017640</td>\n",
       "      <td>0.106111</td>\n",
       "      <td>0.017640</td>\n",
       "      <td>0.132817</td>\n",
       "      <td>0.542510</td>\n",
       "      <td>0.504509</td>\n",
       "      <td>0.713037</td>\n",
       "      <td>0.702127</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.012900</td>\n",
       "      <td>0.019743</td>\n",
       "      <td>0.111515</td>\n",
       "      <td>0.019743</td>\n",
       "      <td>0.140511</td>\n",
       "      <td>0.530364</td>\n",
       "      <td>0.445439</td>\n",
       "      <td>0.700964</td>\n",
       "      <td>0.691137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.009200</td>\n",
       "      <td>0.017539</td>\n",
       "      <td>0.107072</td>\n",
       "      <td>0.017539</td>\n",
       "      <td>0.132437</td>\n",
       "      <td>0.562753</td>\n",
       "      <td>0.507344</td>\n",
       "      <td>0.716425</td>\n",
       "      <td>0.707567</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.008400</td>\n",
       "      <td>0.018534</td>\n",
       "      <td>0.106842</td>\n",
       "      <td>0.018534</td>\n",
       "      <td>0.136141</td>\n",
       "      <td>0.558704</td>\n",
       "      <td>0.479400</td>\n",
       "      <td>0.726834</td>\n",
       "      <td>0.716491</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.007500</td>\n",
       "      <td>0.018897</td>\n",
       "      <td>0.108331</td>\n",
       "      <td>0.018897</td>\n",
       "      <td>0.137465</td>\n",
       "      <td>0.558704</td>\n",
       "      <td>0.469224</td>\n",
       "      <td>0.721268</td>\n",
       "      <td>0.710053</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.007900</td>\n",
       "      <td>0.017485</td>\n",
       "      <td>0.106347</td>\n",
       "      <td>0.017485</td>\n",
       "      <td>0.132229</td>\n",
       "      <td>0.554656</td>\n",
       "      <td>0.508885</td>\n",
       "      <td>0.720067</td>\n",
       "      <td>0.710753</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.005700</td>\n",
       "      <td>0.017425</td>\n",
       "      <td>0.105596</td>\n",
       "      <td>0.017425</td>\n",
       "      <td>0.132003</td>\n",
       "      <td>0.550607</td>\n",
       "      <td>0.510567</td>\n",
       "      <td>0.725617</td>\n",
       "      <td>0.717064</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.006200</td>\n",
       "      <td>0.017336</td>\n",
       "      <td>0.105478</td>\n",
       "      <td>0.017336</td>\n",
       "      <td>0.131665</td>\n",
       "      <td>0.558704</td>\n",
       "      <td>0.513068</td>\n",
       "      <td>0.723671</td>\n",
       "      <td>0.715553</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.005400</td>\n",
       "      <td>0.017635</td>\n",
       "      <td>0.105462</td>\n",
       "      <td>0.017635</td>\n",
       "      <td>0.132796</td>\n",
       "      <td>0.560729</td>\n",
       "      <td>0.504669</td>\n",
       "      <td>0.721176</td>\n",
       "      <td>0.711072</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.005200</td>\n",
       "      <td>0.017485</td>\n",
       "      <td>0.105623</td>\n",
       "      <td>0.017485</td>\n",
       "      <td>0.132232</td>\n",
       "      <td>0.564777</td>\n",
       "      <td>0.508865</td>\n",
       "      <td>0.721856</td>\n",
       "      <td>0.711920</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=2470, training_loss=0.029466515719166652, metrics={'train_runtime': 378.7954, 'train_samples_per_second': 52.06, 'train_steps_per_second': 6.521, 'total_flos': 3155610147840.0, 'train_loss': 0.029466515719166652, 'epoch': 10.0})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run fine-tuning; the value returned by train() is shown as the cell result.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0108f3e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
