{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90b47c6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "#load train data\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load the 'subtask-1' configuration of the SemEvalWorkshop/humicroedit dataset.\n",
    "raw_datasets  = load_dataset('SemEvalWorkshop/humicroedit', 'subtask-1')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "de228bb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "# Checkpoint identifier shared by the tokenizer and, later, the config and model.\n",
    "model_name = \"answerdotai/ModernBERT-base\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# NOTE(review): left padding is unusual for an encoder-style model; confirm it is\n",
    "# intentional and does not interfere with how the classification head pools tokens.\n",
    "tokenizer.padding_side = 'left'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed721fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
      "    num_rows: 9652\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
      "    num_rows: 2419\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Generates a prompt for evaluating the humor intensity of an edited headline.\n",
    "\n",
    "    The original headline marks the replaced word as ``<word/>``. The edited\n",
    "    headline is rebuilt by substituting ``data_point['edit']`` for that marker,\n",
    "    so the \"Edited Headline\" slot holds a full headline rather than one word.\n",
    "    If the original contains no such marker, the bare edit word is used instead.\n",
    "\n",
    "    Args:\n",
    "        data_point (dict): A dictionary containing 'original' and 'edit'.\n",
    "    Returns:\n",
    "        str: The formatted prompt as a string.\n",
    "    \"\"\"\n",
    "    import re\n",
    "\n",
    "    original = data_point['original']\n",
    "    edit = str(data_point['edit'])\n",
    "    # A callable replacement keeps any backslashes in the edit word literal.\n",
    "    edited, n_subs = re.subn(r'<[^<>]*/>', lambda _match: edit, original, count=1)\n",
    "    if n_subs == 0:\n",
    "        # No marker to substitute: fall back to showing the edit word on its own.\n",
    "        edited = edit\n",
    "    return f\"# Original Headline: {original}. # Edited Headline: {edited} # Output: The mean funniness score is\"  # noqa: E501\n",
    "\n",
    "\n",
    "def add_label_column(example):\n",
    "    \"\"\"Attach the regression target and the model prompt to one dataset row.\n",
    "\n",
    "    Stores `meanGrade` as a float under 'labels' and the text produced by\n",
    "    `generate_prompt` under 'input', then returns the same (mutated) row.\n",
    "    \"\"\"\n",
    "    target = float(example['meanGrade'])\n",
    "    example['labels'] = target\n",
    "    prompt = generate_prompt(example)\n",
    "    example['input'] = prompt\n",
    "    return example\n",
    "\n",
    "# Add the 'labels' and 'input' columns to the train and validation splits.\n",
    "\n",
    "train_data = raw_datasets['train'].map(add_label_column)\n",
    "val_data = raw_datasets['validation'].map(add_label_column)\n",
    "\n",
    "# The raw columns are left in place here; they are removed during tokenization.\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9e33204c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
       "    num_rows: 9652\n",
       "})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a9fde6d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  83%|████████▎ | 8000/9652 [00:00<00:00, 27226.06 examples/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 9652/9652 [00:00<00:00, 26821.81 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "# Re-applies the padding side that was already set when the tokenizer was created.\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "\n",
    "# Columns dropped after tokenization; 'labels' is not listed and is therefore kept.\n",
    "col_to_delete =  ['id', 'original', 'edit', 'grades', 'meanGrade', 'input']\n",
    "\n",
    "# NOTE(review): mask_token does not appear to be used in the visible cells.\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "    \"\"\"Tokenize the 'input' prompts, truncating each to at most 512 tokens.\"\"\"\n",
    "    prompts = examples['input']\n",
    "    encoded = tokenizer(prompts, max_length=512, truncation=True)\n",
    "    return encoded\n",
    "\n",
    "# Tokenize both splits in batches and drop the raw columns; 'labels' is kept.\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "# Switch both datasets to the \"torch\" output format.\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1931ed6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS]# Original Headline: Dozens dead in possible gas <attack/> in Syria ; regime denies allegation. # Edited Headline: bloating # Output: The mean funniness score is[SEP]'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_train_data['input_ids'][10])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abd6b985",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "25900f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
       "    num_rows: 2419\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1fdaa612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "78"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token count of every tokenized training prompt, and the longest one.\n",
    "all_lengths = [len(ids) for ids in tokenized_train_data['input_ids']]\n",
    "mx = max(all_lengths)\n",
    "mx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d6618d0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "# Count training prompts longer than 512 tokens.\n",
    "# NOTE(review): tokenization already truncates with max_length=512, so this is\n",
    "# expected to be 0 by construction; it cannot reveal prompts that were cut off.\n",
    "count = sum(len(ids) > 512 for ids in tokenized_train_data['input_ids'])\n",
    "print(count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7a46cd19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import RobertaForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers import set_seed\n",
    "\n",
    "# Seed before loading: the classifier head is not part of the checkpoint and is\n",
    "# freshly initialised, so its starting weights would otherwise differ between runs.\n",
    "set_seed(42)\n",
    "\n",
    "config = AutoConfig.from_pretrained(model_name)\n",
    "config.mask_token_id = tokenizer.mask_token_id\n",
    "# A single output label gives the head one regression output.\n",
    "config.num_labels = 1\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "159b238b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModernBertForSequenceClassification(\n",
       "  (model): ModernBertModel(\n",
       "    (embeddings): ModernBertEmbeddings(\n",
       "      (tok_embeddings): Embedding(50368, 768, padding_idx=50283)\n",
       "      (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      (drop): Dropout(p=0.0, inplace=False)\n",
       "    )\n",
       "    (layers): ModuleList(\n",
       "      (0): ModernBertEncoderLayer(\n",
       "        (attn_norm): Identity()\n",
       "        (attn): ModernBertAttention(\n",
       "          (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "          (rotary_emb): ModernBertRotaryEmbedding()\n",
       "          (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "          (out_drop): Identity()\n",
       "        )\n",
       "        (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): ModernBertMLP(\n",
       "          (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "          (act): GELUActivation()\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "          (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "        )\n",
       "      )\n",
       "      (1-21): 21 x ModernBertEncoderLayer(\n",
       "        (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): ModernBertAttention(\n",
       "          (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "          (rotary_emb): ModernBertRotaryEmbedding()\n",
       "          (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "          (out_drop): Identity()\n",
       "        )\n",
       "        (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): ModernBertMLP(\n",
       "          (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "          (act): GELUActivation()\n",
       "          (drop): Dropout(p=0.0, inplace=False)\n",
       "          (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (final_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (head): ModernBertPredictionHead(\n",
       "    (dense): Linear(in_features=768, out_features=768, bias=False)\n",
       "    (act): GELUActivation()\n",
       "    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (drop): Dropout(p=0.0, inplace=False)\n",
       "  (classifier): Linear(in_features=768, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "864ccb2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import RoCoFT\n",
    "\n",
    "# NOTE(review): presumably modifies `model` in place for column-wise\n",
    "# parameter-efficient fine-tuning with rank 3 (the return value is discarded);\n",
    "# confirm against the RoCoFT documentation.\n",
    "RoCoFT.PEFT(model, method='column', rank=3) \n",
    "#targets=['key', 'value', 'dense', 'query'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bef34afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute regression metrics for one evaluation pass.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` may be an\n",
    "            array of shape (n,) or (n, 1), or a tuple whose first entry is one.\n",
    "    Returns:\n",
    "        dict: MAE, MSE, RMSE, a tolerance-based \"Accuracy\", R2, and the\n",
    "        Pearson and Spearman correlations between predictions and labels.\n",
    "    \"\"\"\n",
    "    predictions, labels = eval_pred\n",
    "    # When several model outputs are passed as a tuple, use the first entry.\n",
    "    if isinstance(predictions, tuple):\n",
    "        predictions = predictions[0]\n",
    "\n",
    "    # Flatten both arrays to shape (n,). A bare squeeze() collapses a one-row\n",
    "    # batch to a 0-d array, and a leftover (n, 1) vs (n,) mismatch would\n",
    "    # silently broadcast to (n, n) in the tolerance check below.\n",
    "    predictions = np.asarray(predictions).reshape(-1)\n",
    "    labels = np.asarray(labels).reshape(-1)\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "\n",
    "    # \"Accuracy\" for regression: share of predictions within `tolerance` of the label.\n",
    "    tolerance = 0.1  # you can change this\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7dbcf96a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "# Evaluation and checkpointing share one cadence so that every evaluated step\n",
    "# has a checkpoint that load_best_model_at_end can restore.\n",
    "EVAL_SAVE_STEPS = 100\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=6e-4,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=10,\n",
    "    weight_decay=0.20,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=EVAL_SAVE_STEPS,\n",
    "    save_strategy=\"steps\",\n",
    "    # The earlier save_steps=10000000 exceeded the total number of training steps,\n",
    "    # so no checkpoint was written and there was no best model to load at the end.\n",
    "    save_steps=EVAL_SAVE_STEPS,\n",
    "    save_total_limit=2,\n",
    "    logging_steps=100,\n",
    "    load_best_model_at_end=True,\n",
    "    # Select the best checkpoint by validation loss (lower is better).\n",
    "    metric_for_best_model=\"eval_loss\",\n",
    "    greater_is_better=False,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "557cdbf4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/torch/_inductor/compile_fx.py:194: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6040' max='6040' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6040/6040 10:17, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.374300</td>\n",
       "      <td>0.361563</td>\n",
       "      <td>0.482212</td>\n",
       "      <td>0.361563</td>\n",
       "      <td>0.601301</td>\n",
       "      <td>0.127739</td>\n",
       "      <td>-0.080757</td>\n",
       "      <td>0.002940</td>\n",
       "      <td>-0.009014</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.353000</td>\n",
       "      <td>0.343072</td>\n",
       "      <td>0.477026</td>\n",
       "      <td>0.343072</td>\n",
       "      <td>0.585723</td>\n",
       "      <td>0.119884</td>\n",
       "      <td>-0.025486</td>\n",
       "      <td>0.046815</td>\n",
       "      <td>0.028528</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.352200</td>\n",
       "      <td>0.342425</td>\n",
       "      <td>0.480641</td>\n",
       "      <td>0.342425</td>\n",
       "      <td>0.585171</td>\n",
       "      <td>0.128152</td>\n",
       "      <td>-0.023551</td>\n",
       "      <td>0.089501</td>\n",
       "      <td>0.068624</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.337800</td>\n",
       "      <td>0.332164</td>\n",
       "      <td>0.468751</td>\n",
       "      <td>0.332164</td>\n",
       "      <td>0.576337</td>\n",
       "      <td>0.127325</td>\n",
       "      <td>0.007118</td>\n",
       "      <td>0.133032</td>\n",
       "      <td>0.117517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.352500</td>\n",
       "      <td>0.332027</td>\n",
       "      <td>0.473047</td>\n",
       "      <td>0.332027</td>\n",
       "      <td>0.576218</td>\n",
       "      <td>0.133940</td>\n",
       "      <td>0.007529</td>\n",
       "      <td>0.157923</td>\n",
       "      <td>0.140181</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.334900</td>\n",
       "      <td>0.327888</td>\n",
       "      <td>0.469955</td>\n",
       "      <td>0.327888</td>\n",
       "      <td>0.572615</td>\n",
       "      <td>0.127739</td>\n",
       "      <td>0.019902</td>\n",
       "      <td>0.175846</td>\n",
       "      <td>0.160333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.325700</td>\n",
       "      <td>0.334373</td>\n",
       "      <td>0.477708</td>\n",
       "      <td>0.334373</td>\n",
       "      <td>0.578250</td>\n",
       "      <td>0.119057</td>\n",
       "      <td>0.000518</td>\n",
       "      <td>0.186526</td>\n",
       "      <td>0.168336</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.327200</td>\n",
       "      <td>0.326753</td>\n",
       "      <td>0.469679</td>\n",
       "      <td>0.326753</td>\n",
       "      <td>0.571623</td>\n",
       "      <td>0.129392</td>\n",
       "      <td>0.023293</td>\n",
       "      <td>0.200468</td>\n",
       "      <td>0.178536</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.325300</td>\n",
       "      <td>0.325185</td>\n",
       "      <td>0.469654</td>\n",
       "      <td>0.325185</td>\n",
       "      <td>0.570250</td>\n",
       "      <td>0.123191</td>\n",
       "      <td>0.027982</td>\n",
       "      <td>0.214870</td>\n",
       "      <td>0.188880</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.333800</td>\n",
       "      <td>0.318494</td>\n",
       "      <td>0.463074</td>\n",
       "      <td>0.318494</td>\n",
       "      <td>0.564353</td>\n",
       "      <td>0.128979</td>\n",
       "      <td>0.047981</td>\n",
       "      <td>0.235456</td>\n",
       "      <td>0.207757</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.326300</td>\n",
       "      <td>0.317014</td>\n",
       "      <td>0.461470</td>\n",
       "      <td>0.317014</td>\n",
       "      <td>0.563040</td>\n",
       "      <td>0.132286</td>\n",
       "      <td>0.052403</td>\n",
       "      <td>0.242203</td>\n",
       "      <td>0.217855</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.309300</td>\n",
       "      <td>0.318468</td>\n",
       "      <td>0.456574</td>\n",
       "      <td>0.318468</td>\n",
       "      <td>0.564329</td>\n",
       "      <td>0.129806</td>\n",
       "      <td>0.048060</td>\n",
       "      <td>0.235748</td>\n",
       "      <td>0.212399</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.318400</td>\n",
       "      <td>0.312351</td>\n",
       "      <td>0.455981</td>\n",
       "      <td>0.312351</td>\n",
       "      <td>0.558883</td>\n",
       "      <td>0.129806</td>\n",
       "      <td>0.066344</td>\n",
       "      <td>0.259068</td>\n",
       "      <td>0.232926</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.307400</td>\n",
       "      <td>0.318878</td>\n",
       "      <td>0.465241</td>\n",
       "      <td>0.318878</td>\n",
       "      <td>0.564693</td>\n",
       "      <td>0.121951</td>\n",
       "      <td>0.046832</td>\n",
       "      <td>0.254590</td>\n",
       "      <td>0.228604</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.321300</td>\n",
       "      <td>0.321590</td>\n",
       "      <td>0.456389</td>\n",
       "      <td>0.321590</td>\n",
       "      <td>0.567089</td>\n",
       "      <td>0.131046</td>\n",
       "      <td>0.038726</td>\n",
       "      <td>0.245931</td>\n",
       "      <td>0.218839</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.324900</td>\n",
       "      <td>0.317715</td>\n",
       "      <td>0.455096</td>\n",
       "      <td>0.317715</td>\n",
       "      <td>0.563662</td>\n",
       "      <td>0.129806</td>\n",
       "      <td>0.050309</td>\n",
       "      <td>0.254773</td>\n",
       "      <td>0.228562</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.284600</td>\n",
       "      <td>0.312300</td>\n",
       "      <td>0.453668</td>\n",
       "      <td>0.312300</td>\n",
       "      <td>0.558838</td>\n",
       "      <td>0.123605</td>\n",
       "      <td>0.066495</td>\n",
       "      <td>0.266530</td>\n",
       "      <td>0.241715</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.302200</td>\n",
       "      <td>0.309087</td>\n",
       "      <td>0.453556</td>\n",
       "      <td>0.309087</td>\n",
       "      <td>0.555956</td>\n",
       "      <td>0.118644</td>\n",
       "      <td>0.076098</td>\n",
       "      <td>0.279547</td>\n",
       "      <td>0.253384</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.294100</td>\n",
       "      <td>0.322809</td>\n",
       "      <td>0.468058</td>\n",
       "      <td>0.322809</td>\n",
       "      <td>0.568163</td>\n",
       "      <td>0.121538</td>\n",
       "      <td>0.035083</td>\n",
       "      <td>0.268165</td>\n",
       "      <td>0.240084</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.309800</td>\n",
       "      <td>0.315096</td>\n",
       "      <td>0.461157</td>\n",
       "      <td>0.315096</td>\n",
       "      <td>0.561334</td>\n",
       "      <td>0.125258</td>\n",
       "      <td>0.058137</td>\n",
       "      <td>0.280077</td>\n",
       "      <td>0.253905</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.292700</td>\n",
       "      <td>0.310318</td>\n",
       "      <td>0.455390</td>\n",
       "      <td>0.310318</td>\n",
       "      <td>0.557062</td>\n",
       "      <td>0.123191</td>\n",
       "      <td>0.072421</td>\n",
       "      <td>0.276405</td>\n",
       "      <td>0.246940</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.298000</td>\n",
       "      <td>0.311556</td>\n",
       "      <td>0.450951</td>\n",
       "      <td>0.311556</td>\n",
       "      <td>0.558172</td>\n",
       "      <td>0.123191</td>\n",
       "      <td>0.068720</td>\n",
       "      <td>0.288299</td>\n",
       "      <td>0.256469</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.293600</td>\n",
       "      <td>0.310252</td>\n",
       "      <td>0.458090</td>\n",
       "      <td>0.310252</td>\n",
       "      <td>0.557003</td>\n",
       "      <td>0.128566</td>\n",
       "      <td>0.072618</td>\n",
       "      <td>0.291920</td>\n",
       "      <td>0.261096</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.300500</td>\n",
       "      <td>0.307552</td>\n",
       "      <td>0.450396</td>\n",
       "      <td>0.307552</td>\n",
       "      <td>0.554574</td>\n",
       "      <td>0.121951</td>\n",
       "      <td>0.080688</td>\n",
       "      <td>0.291396</td>\n",
       "      <td>0.259007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.284700</td>\n",
       "      <td>0.306966</td>\n",
       "      <td>0.450105</td>\n",
       "      <td>0.306966</td>\n",
       "      <td>0.554045</td>\n",
       "      <td>0.118644</td>\n",
       "      <td>0.082441</td>\n",
       "      <td>0.295662</td>\n",
       "      <td>0.261951</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.299100</td>\n",
       "      <td>0.306472</td>\n",
       "      <td>0.452917</td>\n",
       "      <td>0.306472</td>\n",
       "      <td>0.553599</td>\n",
       "      <td>0.125672</td>\n",
       "      <td>0.083917</td>\n",
       "      <td>0.300538</td>\n",
       "      <td>0.266789</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.279000</td>\n",
       "      <td>0.313354</td>\n",
       "      <td>0.449534</td>\n",
       "      <td>0.313354</td>\n",
       "      <td>0.559780</td>\n",
       "      <td>0.128566</td>\n",
       "      <td>0.063345</td>\n",
       "      <td>0.307883</td>\n",
       "      <td>0.273710</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.284000</td>\n",
       "      <td>0.328752</td>\n",
       "      <td>0.473545</td>\n",
       "      <td>0.328752</td>\n",
       "      <td>0.573369</td>\n",
       "      <td>0.131873</td>\n",
       "      <td>0.017319</td>\n",
       "      <td>0.309201</td>\n",
       "      <td>0.276718</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.284500</td>\n",
       "      <td>0.302980</td>\n",
       "      <td>0.447920</td>\n",
       "      <td>0.302980</td>\n",
       "      <td>0.550436</td>\n",
       "      <td>0.135180</td>\n",
       "      <td>0.094353</td>\n",
       "      <td>0.316863</td>\n",
       "      <td>0.284345</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.286300</td>\n",
       "      <td>0.303538</td>\n",
       "      <td>0.450063</td>\n",
       "      <td>0.303538</td>\n",
       "      <td>0.550943</td>\n",
       "      <td>0.128566</td>\n",
       "      <td>0.092687</td>\n",
       "      <td>0.311490</td>\n",
       "      <td>0.281872</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.283600</td>\n",
       "      <td>0.304455</td>\n",
       "      <td>0.449295</td>\n",
       "      <td>0.304455</td>\n",
       "      <td>0.551775</td>\n",
       "      <td>0.130632</td>\n",
       "      <td>0.089944</td>\n",
       "      <td>0.309540</td>\n",
       "      <td>0.276449</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.269700</td>\n",
       "      <td>0.310250</td>\n",
       "      <td>0.449695</td>\n",
       "      <td>0.310250</td>\n",
       "      <td>0.557001</td>\n",
       "      <td>0.120298</td>\n",
       "      <td>0.072622</td>\n",
       "      <td>0.304724</td>\n",
       "      <td>0.270047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.277300</td>\n",
       "      <td>0.310017</td>\n",
       "      <td>0.449706</td>\n",
       "      <td>0.310017</td>\n",
       "      <td>0.556792</td>\n",
       "      <td>0.127325</td>\n",
       "      <td>0.073318</td>\n",
       "      <td>0.303665</td>\n",
       "      <td>0.271018</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.269400</td>\n",
       "      <td>0.311659</td>\n",
       "      <td>0.457660</td>\n",
       "      <td>0.311659</td>\n",
       "      <td>0.558264</td>\n",
       "      <td>0.127325</td>\n",
       "      <td>0.068413</td>\n",
       "      <td>0.317004</td>\n",
       "      <td>0.285516</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.270300</td>\n",
       "      <td>0.309528</td>\n",
       "      <td>0.449292</td>\n",
       "      <td>0.309528</td>\n",
       "      <td>0.556353</td>\n",
       "      <td>0.129392</td>\n",
       "      <td>0.074780</td>\n",
       "      <td>0.312715</td>\n",
       "      <td>0.280529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.269900</td>\n",
       "      <td>0.305016</td>\n",
       "      <td>0.450974</td>\n",
       "      <td>0.305016</td>\n",
       "      <td>0.552282</td>\n",
       "      <td>0.128152</td>\n",
       "      <td>0.088269</td>\n",
       "      <td>0.314092</td>\n",
       "      <td>0.280403</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.274000</td>\n",
       "      <td>0.313133</td>\n",
       "      <td>0.459186</td>\n",
       "      <td>0.313133</td>\n",
       "      <td>0.559583</td>\n",
       "      <td>0.123191</td>\n",
       "      <td>0.064006</td>\n",
       "      <td>0.311957</td>\n",
       "      <td>0.280711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.242700</td>\n",
       "      <td>0.306435</td>\n",
       "      <td>0.451075</td>\n",
       "      <td>0.306435</td>\n",
       "      <td>0.553566</td>\n",
       "      <td>0.128152</td>\n",
       "      <td>0.084026</td>\n",
       "      <td>0.321394</td>\n",
       "      <td>0.289349</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.273600</td>\n",
       "      <td>0.307379</td>\n",
       "      <td>0.452775</td>\n",
       "      <td>0.307379</td>\n",
       "      <td>0.554418</td>\n",
       "      <td>0.119884</td>\n",
       "      <td>0.081205</td>\n",
       "      <td>0.312528</td>\n",
       "      <td>0.279818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.256800</td>\n",
       "      <td>0.312230</td>\n",
       "      <td>0.457377</td>\n",
       "      <td>0.312230</td>\n",
       "      <td>0.558776</td>\n",
       "      <td>0.123191</td>\n",
       "      <td>0.066704</td>\n",
       "      <td>0.308715</td>\n",
       "      <td>0.273607</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.248400</td>\n",
       "      <td>0.313011</td>\n",
       "      <td>0.450521</td>\n",
       "      <td>0.313011</td>\n",
       "      <td>0.559473</td>\n",
       "      <td>0.131046</td>\n",
       "      <td>0.064371</td>\n",
       "      <td>0.311960</td>\n",
       "      <td>0.279034</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.271200</td>\n",
       "      <td>0.312086</td>\n",
       "      <td>0.457399</td>\n",
       "      <td>0.312086</td>\n",
       "      <td>0.558647</td>\n",
       "      <td>0.122365</td>\n",
       "      <td>0.067135</td>\n",
       "      <td>0.307537</td>\n",
       "      <td>0.273222</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.270300</td>\n",
       "      <td>0.309390</td>\n",
       "      <td>0.453560</td>\n",
       "      <td>0.309390</td>\n",
       "      <td>0.556228</td>\n",
       "      <td>0.128979</td>\n",
       "      <td>0.075194</td>\n",
       "      <td>0.312018</td>\n",
       "      <td>0.277303</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.230500</td>\n",
       "      <td>0.312022</td>\n",
       "      <td>0.451458</td>\n",
       "      <td>0.312022</td>\n",
       "      <td>0.558589</td>\n",
       "      <td>0.133113</td>\n",
       "      <td>0.067327</td>\n",
       "      <td>0.312669</td>\n",
       "      <td>0.279107</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.252000</td>\n",
       "      <td>0.312212</td>\n",
       "      <td>0.451382</td>\n",
       "      <td>0.312212</td>\n",
       "      <td>0.558759</td>\n",
       "      <td>0.128566</td>\n",
       "      <td>0.066759</td>\n",
       "      <td>0.310058</td>\n",
       "      <td>0.276469</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.260100</td>\n",
       "      <td>0.308938</td>\n",
       "      <td>0.452728</td>\n",
       "      <td>0.308938</td>\n",
       "      <td>0.555822</td>\n",
       "      <td>0.129392</td>\n",
       "      <td>0.076545</td>\n",
       "      <td>0.311899</td>\n",
       "      <td>0.278532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.251500</td>\n",
       "      <td>0.310360</td>\n",
       "      <td>0.450798</td>\n",
       "      <td>0.310360</td>\n",
       "      <td>0.557100</td>\n",
       "      <td>0.131873</td>\n",
       "      <td>0.072293</td>\n",
       "      <td>0.314066</td>\n",
       "      <td>0.280734</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.248400</td>\n",
       "      <td>0.311085</td>\n",
       "      <td>0.454337</td>\n",
       "      <td>0.311085</td>\n",
       "      <td>0.557750</td>\n",
       "      <td>0.124432</td>\n",
       "      <td>0.070127</td>\n",
       "      <td>0.313181</td>\n",
       "      <td>0.279375</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.245200</td>\n",
       "      <td>0.311686</td>\n",
       "      <td>0.452200</td>\n",
       "      <td>0.311686</td>\n",
       "      <td>0.558288</td>\n",
       "      <td>0.135593</td>\n",
       "      <td>0.068331</td>\n",
       "      <td>0.312783</td>\n",
       "      <td>0.278177</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.238400</td>\n",
       "      <td>0.316816</td>\n",
       "      <td>0.460601</td>\n",
       "      <td>0.316816</td>\n",
       "      <td>0.562864</td>\n",
       "      <td>0.119471</td>\n",
       "      <td>0.052995</td>\n",
       "      <td>0.311460</td>\n",
       "      <td>0.277586</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.253200</td>\n",
       "      <td>0.312655</td>\n",
       "      <td>0.455304</td>\n",
       "      <td>0.312655</td>\n",
       "      <td>0.559156</td>\n",
       "      <td>0.128979</td>\n",
       "      <td>0.065433</td>\n",
       "      <td>0.308520</td>\n",
       "      <td>0.273362</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.230000</td>\n",
       "      <td>0.312868</td>\n",
       "      <td>0.455692</td>\n",
       "      <td>0.312868</td>\n",
       "      <td>0.559346</td>\n",
       "      <td>0.125258</td>\n",
       "      <td>0.064798</td>\n",
       "      <td>0.307940</td>\n",
       "      <td>0.273539</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.241600</td>\n",
       "      <td>0.313350</td>\n",
       "      <td>0.453147</td>\n",
       "      <td>0.313350</td>\n",
       "      <td>0.559777</td>\n",
       "      <td>0.132699</td>\n",
       "      <td>0.063357</td>\n",
       "      <td>0.308809</td>\n",
       "      <td>0.274483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.249600</td>\n",
       "      <td>0.312775</td>\n",
       "      <td>0.454806</td>\n",
       "      <td>0.312775</td>\n",
       "      <td>0.559263</td>\n",
       "      <td>0.125258</td>\n",
       "      <td>0.065074</td>\n",
       "      <td>0.311287</td>\n",
       "      <td>0.277507</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.239800</td>\n",
       "      <td>0.312903</td>\n",
       "      <td>0.453738</td>\n",
       "      <td>0.312903</td>\n",
       "      <td>0.559377</td>\n",
       "      <td>0.130219</td>\n",
       "      <td>0.064693</td>\n",
       "      <td>0.311177</td>\n",
       "      <td>0.277412</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.247900</td>\n",
       "      <td>0.313444</td>\n",
       "      <td>0.455699</td>\n",
       "      <td>0.313444</td>\n",
       "      <td>0.559861</td>\n",
       "      <td>0.124018</td>\n",
       "      <td>0.063076</td>\n",
       "      <td>0.310951</td>\n",
       "      <td>0.276912</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.227400</td>\n",
       "      <td>0.313180</td>\n",
       "      <td>0.454067</td>\n",
       "      <td>0.313180</td>\n",
       "      <td>0.559625</td>\n",
       "      <td>0.129392</td>\n",
       "      <td>0.063864</td>\n",
       "      <td>0.310381</td>\n",
       "      <td>0.276178</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.242700</td>\n",
       "      <td>0.313197</td>\n",
       "      <td>0.454457</td>\n",
       "      <td>0.313197</td>\n",
       "      <td>0.559640</td>\n",
       "      <td>0.127325</td>\n",
       "      <td>0.063814</td>\n",
       "      <td>0.310256</td>\n",
       "      <td>0.276122</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.248800</td>\n",
       "      <td>0.313216</td>\n",
       "      <td>0.454711</td>\n",
       "      <td>0.313216</td>\n",
       "      <td>0.559657</td>\n",
       "      <td>0.126085</td>\n",
       "      <td>0.063758</td>\n",
       "      <td>0.310383</td>\n",
       "      <td>0.276283</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.229000</td>\n",
       "      <td>0.313230</td>\n",
       "      <td>0.454768</td>\n",
       "      <td>0.313230</td>\n",
       "      <td>0.559669</td>\n",
       "      <td>0.126499</td>\n",
       "      <td>0.063717</td>\n",
       "      <td>0.310401</td>\n",
       "      <td>0.276325</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=6040, training_loss=0.2835080877834598, metrics={'train_runtime': 625.8194, 'train_samples_per_second': 154.23, 'train_steps_per_second': 9.651, 'total_flos': 13111580233440.0, 'train_loss': 0.2835080877834598, 'epoch': 10.0})"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc4e83df",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
