{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90b47c6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Downloading data: 100%|██████████| 1.62M/1.62M [00:00<00:00, 9.80MB/s]\n",
      "Generating train split: 100%|██████████| 9652/9652 [00:00<00:00, 37601.88 examples/s]\n",
      "Generating test split: 100%|██████████| 3024/3024 [00:00<00:00, 34724.60 examples/s]\n",
      "Generating validation split: 100%|██████████| 2419/2419 [00:00<00:00, 35566.88 examples/s]\n",
      "Generating funlines split: 100%|██████████| 8248/8248 [00:00<00:00, 37463.58 examples/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "#load train data\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "raw_datasets  = load_dataset('SemEvalWorkshop/humicroedit', 'subtask-1')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "de228bb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py:559: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "model_name = \"microsoft/deberta-v3-base\"\n",
    "\n",
    "#config.num_labels=2\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.padding_side = 'left'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed721fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 9652/9652 [00:00<00:00, 15533.95 examples/s]\n",
      "Map: 100%|██████████| 2419/2419 [00:00<00:00, 15694.93 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
      "    num_rows: 9652\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
      "    num_rows: 2419\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Generates a prompt for evaluating the humor intensity of an edited headline.\n",
    "    Args:\n",
    "        data_point (dict): A dictionary containing 'original', 'edit', and 'meanGrade'.\n",
    "    Returns:\n",
    "        str: The formatted prompt as a string.\n",
    "    \"\"\"\n",
    "    return f\"\"\"# Original Headline: {data_point['original']}. # Edited Headline: {data_point['edit']} # Output: The mean funniness score is\"\"\"  # noqa: E501\n",
    "\n",
    "\n",
    "# Assuming `dataset` is your DatasetDict\n",
    "def add_label_column(example):\n",
    "\n",
    "    example['labels'] = float(example['meanGrade'])\n",
    "  \n",
    "    example['input'] = generate_prompt(example)\n",
    "\n",
    "    \n",
    "    return example\n",
    "\n",
    "# Map the function over train and validation datasets\n",
    "\n",
    "train_data = raw_datasets['train'].map(add_label_column)\n",
    "val_data = raw_datasets['validation'].map(add_label_column)\n",
    "\n",
    "# Remove unnecessary columns\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9e33204c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
       "    num_rows: 9652\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a9fde6d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  73%|███████▎  | 7000/9652 [00:00<00:00, 24132.29 examples/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 9652/9652 [00:00<00:00, 16962.77 examples/s]\n",
      "Map: 100%|██████████| 2419/2419 [00:00<00:00, 23279.82 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "\n",
    "# col_to_delete = ['idx']\n",
    "col_to_delete =  ['id', 'original', 'edit', 'grades', 'meanGrade', 'input']\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "   \n",
    "    return tokenizer(examples['input'], truncation=True, max_length=512)\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "# llama_tokenized_datasets = llama_tokenized_datasets.rename_column(\"target\", \"label\")\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1931ed6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] # Original Headline: Dozens dead in possible gas <attack/> in Syria ; regime denies allegation. # Edited Headline: bloating # Output: The mean funniness score is[SEP]'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_train_data['input_ids'][10])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abd6b985",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "25900f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
       "    num_rows: 2419\n",
       "})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1fdaa612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_lengths = [len(ids) for ids in tokenized_train_data['input_ids']]\n",
    "mx = max(all_lengths)\n",
    "mx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d6618d0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "count = sum(len(ids) > 512 for ids in tokenized_train_data['input_ids'])\n",
    "print(count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7a46cd19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import RobertaForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "# from modeling import MLMSequenceClassification\n",
    "\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "config = AutoConfig.from_pretrained(model_name)\n",
    "config.mask_token_id = tokenizer.mask_token_id\n",
    "config.num_labels = 1\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "159b238b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DebertaV2ForSequenceClassification(\n",
       "  (deberta): DebertaV2Model(\n",
       "    (embeddings): DebertaV2Embeddings(\n",
       "      (word_embeddings): Embedding(128100, 768, padding_idx=0)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): DebertaV2Encoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x DebertaV2Layer(\n",
       "          (attention): DebertaV2Attention(\n",
       "            (self): DisentangledSelfAttention(\n",
       "              (query_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (pos_dropout): Dropout(p=0.1, inplace=False)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): DebertaV2SelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): DebertaV2Intermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): DebertaV2Output(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (rel_embeddings): Embedding(512, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "    )\n",
       "  )\n",
       "  (pooler): ContextPooler(\n",
       "    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "    (dropout): Dropout(p=0, inplace=False)\n",
       "  )\n",
       "  (classifier): Linear(in_features=768, out_features=1, bias=True)\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "864ccb2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import RoCoFT\n",
    "\n",
    "RoCoFT.PEFT(model, method='column', rank=3) \n",
    "#targets=['key', 'value', 'dense', 'query'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bef34afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    # If predictions are logits or have extra dimensions, squeeze\n",
    "    if predictions.ndim > 1:\n",
    "        predictions = predictions.squeeze()\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "    \n",
    "    # Define an \"accuracy\" for regression:\n",
    "    # Example: within some threshold tolerance\n",
    "    tolerance = 0.1  # you can change this\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7dbcf96a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=6e-4,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=10,\n",
    "    weight_decay=0.20,\n",
    "    eval_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_total_limit=2,\n",
    "    save_steps=10000000,\n",
    "    logging_steps=100,\n",
    "   \n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "557cdbf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6040' max='6040' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6040/6040 10:31, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.562800</td>\n",
       "      <td>0.357727</td>\n",
       "      <td>0.480622</td>\n",
       "      <td>0.357727</td>\n",
       "      <td>0.598103</td>\n",
       "      <td>0.131459</td>\n",
       "      <td>-0.069293</td>\n",
       "      <td>0.007580</td>\n",
       "      <td>-0.009517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.360100</td>\n",
       "      <td>0.337524</td>\n",
       "      <td>0.475128</td>\n",
       "      <td>0.337524</td>\n",
       "      <td>0.580969</td>\n",
       "      <td>0.121951</td>\n",
       "      <td>-0.008904</td>\n",
       "      <td>0.054009</td>\n",
       "      <td>0.101394</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.354600</td>\n",
       "      <td>0.343872</td>\n",
       "      <td>0.484414</td>\n",
       "      <td>0.343872</td>\n",
       "      <td>0.586406</td>\n",
       "      <td>0.120711</td>\n",
       "      <td>-0.027878</td>\n",
       "      <td>0.052548</td>\n",
       "      <td>0.052492</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.352600</td>\n",
       "      <td>0.327242</td>\n",
       "      <td>0.469167</td>\n",
       "      <td>0.327242</td>\n",
       "      <td>0.572051</td>\n",
       "      <td>0.125672</td>\n",
       "      <td>0.021832</td>\n",
       "      <td>0.180424</td>\n",
       "      <td>0.165235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.356100</td>\n",
       "      <td>0.361905</td>\n",
       "      <td>0.502217</td>\n",
       "      <td>0.361905</td>\n",
       "      <td>0.601586</td>\n",
       "      <td>0.109549</td>\n",
       "      <td>-0.081782</td>\n",
       "      <td>0.197774</td>\n",
       "      <td>0.211751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.343600</td>\n",
       "      <td>0.358618</td>\n",
       "      <td>0.499657</td>\n",
       "      <td>0.358618</td>\n",
       "      <td>0.598847</td>\n",
       "      <td>0.108309</td>\n",
       "      <td>-0.071954</td>\n",
       "      <td>0.195424</td>\n",
       "      <td>0.226457</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.342400</td>\n",
       "      <td>0.333881</td>\n",
       "      <td>0.477608</td>\n",
       "      <td>0.333881</td>\n",
       "      <td>0.577824</td>\n",
       "      <td>0.120298</td>\n",
       "      <td>0.001988</td>\n",
       "      <td>0.196861</td>\n",
       "      <td>0.196511</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.331400</td>\n",
       "      <td>0.326635</td>\n",
       "      <td>0.473547</td>\n",
       "      <td>0.326635</td>\n",
       "      <td>0.571520</td>\n",
       "      <td>0.121951</td>\n",
       "      <td>0.023645</td>\n",
       "      <td>0.300312</td>\n",
       "      <td>0.301599</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.328600</td>\n",
       "      <td>0.333295</td>\n",
       "      <td>0.479549</td>\n",
       "      <td>0.333295</td>\n",
       "      <td>0.577317</td>\n",
       "      <td>0.117817</td>\n",
       "      <td>0.003737</td>\n",
       "      <td>0.302363</td>\n",
       "      <td>0.309395</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.336600</td>\n",
       "      <td>0.339596</td>\n",
       "      <td>0.485026</td>\n",
       "      <td>0.339596</td>\n",
       "      <td>0.582748</td>\n",
       "      <td>0.119057</td>\n",
       "      <td>-0.015095</td>\n",
       "      <td>0.298808</td>\n",
       "      <td>0.310726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.329400</td>\n",
       "      <td>0.320240</td>\n",
       "      <td>0.469015</td>\n",
       "      <td>0.320240</td>\n",
       "      <td>0.565898</td>\n",
       "      <td>0.126085</td>\n",
       "      <td>0.042761</td>\n",
       "      <td>0.334186</td>\n",
       "      <td>0.341135</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.309100</td>\n",
       "      <td>0.326890</td>\n",
       "      <td>0.475022</td>\n",
       "      <td>0.326890</td>\n",
       "      <td>0.571743</td>\n",
       "      <td>0.122365</td>\n",
       "      <td>0.022885</td>\n",
       "      <td>0.341125</td>\n",
       "      <td>0.348398</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.315900</td>\n",
       "      <td>0.356775</td>\n",
       "      <td>0.497581</td>\n",
       "      <td>0.356775</td>\n",
       "      <td>0.597307</td>\n",
       "      <td>0.115337</td>\n",
       "      <td>-0.066447</td>\n",
       "      <td>0.343121</td>\n",
       "      <td>0.355379</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.294600</td>\n",
       "      <td>0.344659</td>\n",
       "      <td>0.487914</td>\n",
       "      <td>0.344659</td>\n",
       "      <td>0.587077</td>\n",
       "      <td>0.126085</td>\n",
       "      <td>-0.030231</td>\n",
       "      <td>0.356270</td>\n",
       "      <td>0.364508</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.320300</td>\n",
       "      <td>0.287337</td>\n",
       "      <td>0.433080</td>\n",
       "      <td>0.287337</td>\n",
       "      <td>0.536038</td>\n",
       "      <td>0.133113</td>\n",
       "      <td>0.141113</td>\n",
       "      <td>0.378392</td>\n",
       "      <td>0.379918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.315800</td>\n",
       "      <td>0.296867</td>\n",
       "      <td>0.447402</td>\n",
       "      <td>0.296867</td>\n",
       "      <td>0.544855</td>\n",
       "      <td>0.132286</td>\n",
       "      <td>0.112628</td>\n",
       "      <td>0.374985</td>\n",
       "      <td>0.378995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.278200</td>\n",
       "      <td>0.293083</td>\n",
       "      <td>0.443957</td>\n",
       "      <td>0.293083</td>\n",
       "      <td>0.541372</td>\n",
       "      <td>0.131046</td>\n",
       "      <td>0.123937</td>\n",
       "      <td>0.376248</td>\n",
       "      <td>0.380697</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.292500</td>\n",
       "      <td>0.339816</td>\n",
       "      <td>0.481944</td>\n",
       "      <td>0.339816</td>\n",
       "      <td>0.582937</td>\n",
       "      <td>0.123191</td>\n",
       "      <td>-0.015752</td>\n",
       "      <td>0.375834</td>\n",
       "      <td>0.383033</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.279100</td>\n",
       "      <td>0.395813</td>\n",
       "      <td>0.527470</td>\n",
       "      <td>0.395813</td>\n",
       "      <td>0.629136</td>\n",
       "      <td>0.111616</td>\n",
       "      <td>-0.183135</td>\n",
       "      <td>0.380868</td>\n",
       "      <td>0.386106</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.294400</td>\n",
       "      <td>0.383224</td>\n",
       "      <td>0.515308</td>\n",
       "      <td>0.383224</td>\n",
       "      <td>0.619051</td>\n",
       "      <td>0.109549</td>\n",
       "      <td>-0.145505</td>\n",
       "      <td>0.372499</td>\n",
       "      <td>0.378501</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.280700</td>\n",
       "      <td>0.294093</td>\n",
       "      <td>0.443673</td>\n",
       "      <td>0.294093</td>\n",
       "      <td>0.542304</td>\n",
       "      <td>0.130632</td>\n",
       "      <td>0.120917</td>\n",
       "      <td>0.396285</td>\n",
       "      <td>0.393124</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.289200</td>\n",
       "      <td>0.297150</td>\n",
       "      <td>0.449041</td>\n",
       "      <td>0.297151</td>\n",
       "      <td>0.545115</td>\n",
       "      <td>0.128979</td>\n",
       "      <td>0.111779</td>\n",
       "      <td>0.406440</td>\n",
       "      <td>0.404736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.282100</td>\n",
       "      <td>0.326436</td>\n",
       "      <td>0.475337</td>\n",
       "      <td>0.326436</td>\n",
       "      <td>0.571345</td>\n",
       "      <td>0.122365</td>\n",
       "      <td>0.024242</td>\n",
       "      <td>0.405133</td>\n",
       "      <td>0.411197</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.296600</td>\n",
       "      <td>0.311931</td>\n",
       "      <td>0.463096</td>\n",
       "      <td>0.311931</td>\n",
       "      <td>0.558508</td>\n",
       "      <td>0.127739</td>\n",
       "      <td>0.067598</td>\n",
       "      <td>0.414918</td>\n",
       "      <td>0.417715</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.261900</td>\n",
       "      <td>0.292395</td>\n",
       "      <td>0.441337</td>\n",
       "      <td>0.292394</td>\n",
       "      <td>0.540735</td>\n",
       "      <td>0.133526</td>\n",
       "      <td>0.125996</td>\n",
       "      <td>0.415523</td>\n",
       "      <td>0.418392</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.274900</td>\n",
       "      <td>0.310981</td>\n",
       "      <td>0.459930</td>\n",
       "      <td>0.310981</td>\n",
       "      <td>0.557657</td>\n",
       "      <td>0.125258</td>\n",
       "      <td>0.070437</td>\n",
       "      <td>0.412664</td>\n",
       "      <td>0.411689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.266200</td>\n",
       "      <td>0.277561</td>\n",
       "      <td>0.424355</td>\n",
       "      <td>0.277561</td>\n",
       "      <td>0.526840</td>\n",
       "      <td>0.139314</td>\n",
       "      <td>0.170336</td>\n",
       "      <td>0.423544</td>\n",
       "      <td>0.419850</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.269400</td>\n",
       "      <td>0.365742</td>\n",
       "      <td>0.504723</td>\n",
       "      <td>0.365742</td>\n",
       "      <td>0.604766</td>\n",
       "      <td>0.114097</td>\n",
       "      <td>-0.093248</td>\n",
       "      <td>0.405798</td>\n",
       "      <td>0.408361</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.265600</td>\n",
       "      <td>0.340625</td>\n",
       "      <td>0.484279</td>\n",
       "      <td>0.340625</td>\n",
       "      <td>0.583631</td>\n",
       "      <td>0.122778</td>\n",
       "      <td>-0.018171</td>\n",
       "      <td>0.419048</td>\n",
       "      <td>0.417946</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.266800</td>\n",
       "      <td>0.347162</td>\n",
       "      <td>0.488704</td>\n",
       "      <td>0.347162</td>\n",
       "      <td>0.589205</td>\n",
       "      <td>0.121951</td>\n",
       "      <td>-0.037713</td>\n",
       "      <td>0.415452</td>\n",
       "      <td>0.414500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.261200</td>\n",
       "      <td>0.329054</td>\n",
       "      <td>0.474320</td>\n",
       "      <td>0.329054</td>\n",
       "      <td>0.573632</td>\n",
       "      <td>0.129806</td>\n",
       "      <td>0.016417</td>\n",
       "      <td>0.413065</td>\n",
       "      <td>0.413089</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.255500</td>\n",
       "      <td>0.283794</td>\n",
       "      <td>0.433467</td>\n",
       "      <td>0.283794</td>\n",
       "      <td>0.532723</td>\n",
       "      <td>0.134766</td>\n",
       "      <td>0.151704</td>\n",
       "      <td>0.420404</td>\n",
       "      <td>0.418974</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.271800</td>\n",
       "      <td>0.308820</td>\n",
       "      <td>0.458718</td>\n",
       "      <td>0.308820</td>\n",
       "      <td>0.555716</td>\n",
       "      <td>0.129392</td>\n",
       "      <td>0.076897</td>\n",
       "      <td>0.423527</td>\n",
       "      <td>0.419581</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.249500</td>\n",
       "      <td>0.358793</td>\n",
       "      <td>0.494407</td>\n",
       "      <td>0.358793</td>\n",
       "      <td>0.598994</td>\n",
       "      <td>0.112857</td>\n",
       "      <td>-0.072479</td>\n",
       "      <td>0.419675</td>\n",
       "      <td>0.418974</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.256900</td>\n",
       "      <td>0.301394</td>\n",
       "      <td>0.450924</td>\n",
       "      <td>0.301394</td>\n",
       "      <td>0.548994</td>\n",
       "      <td>0.133940</td>\n",
       "      <td>0.099094</td>\n",
       "      <td>0.424078</td>\n",
       "      <td>0.418217</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.248900</td>\n",
       "      <td>0.338376</td>\n",
       "      <td>0.481955</td>\n",
       "      <td>0.338376</td>\n",
       "      <td>0.581701</td>\n",
       "      <td>0.117817</td>\n",
       "      <td>-0.011450</td>\n",
       "      <td>0.427821</td>\n",
       "      <td>0.424987</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.260700</td>\n",
       "      <td>0.353654</td>\n",
       "      <td>0.494863</td>\n",
       "      <td>0.353654</td>\n",
       "      <td>0.594688</td>\n",
       "      <td>0.110790</td>\n",
       "      <td>-0.057117</td>\n",
       "      <td>0.419366</td>\n",
       "      <td>0.418345</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.235800</td>\n",
       "      <td>0.332969</td>\n",
       "      <td>0.476800</td>\n",
       "      <td>0.332969</td>\n",
       "      <td>0.577035</td>\n",
       "      <td>0.112443</td>\n",
       "      <td>0.004712</td>\n",
       "      <td>0.426167</td>\n",
       "      <td>0.423400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.243600</td>\n",
       "      <td>0.326841</td>\n",
       "      <td>0.471816</td>\n",
       "      <td>0.326841</td>\n",
       "      <td>0.571700</td>\n",
       "      <td>0.119884</td>\n",
       "      <td>0.023029</td>\n",
       "      <td>0.422430</td>\n",
       "      <td>0.419743</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.239300</td>\n",
       "      <td>0.356700</td>\n",
       "      <td>0.495919</td>\n",
       "      <td>0.356700</td>\n",
       "      <td>0.597244</td>\n",
       "      <td>0.114510</td>\n",
       "      <td>-0.066223</td>\n",
       "      <td>0.421684</td>\n",
       "      <td>0.420115</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.246000</td>\n",
       "      <td>0.297662</td>\n",
       "      <td>0.446521</td>\n",
       "      <td>0.297662</td>\n",
       "      <td>0.545584</td>\n",
       "      <td>0.132286</td>\n",
       "      <td>0.110252</td>\n",
       "      <td>0.426397</td>\n",
       "      <td>0.420390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.270700</td>\n",
       "      <td>0.353787</td>\n",
       "      <td>0.492725</td>\n",
       "      <td>0.353787</td>\n",
       "      <td>0.594800</td>\n",
       "      <td>0.117404</td>\n",
       "      <td>-0.057516</td>\n",
       "      <td>0.421707</td>\n",
       "      <td>0.419516</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.246800</td>\n",
       "      <td>0.339823</td>\n",
       "      <td>0.481915</td>\n",
       "      <td>0.339823</td>\n",
       "      <td>0.582943</td>\n",
       "      <td>0.119884</td>\n",
       "      <td>-0.015773</td>\n",
       "      <td>0.426389</td>\n",
       "      <td>0.422389</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.215600</td>\n",
       "      <td>0.329840</td>\n",
       "      <td>0.473681</td>\n",
       "      <td>0.329840</td>\n",
       "      <td>0.574317</td>\n",
       "      <td>0.122365</td>\n",
       "      <td>0.014066</td>\n",
       "      <td>0.426549</td>\n",
       "      <td>0.421213</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.240100</td>\n",
       "      <td>0.337735</td>\n",
       "      <td>0.481619</td>\n",
       "      <td>0.337735</td>\n",
       "      <td>0.581149</td>\n",
       "      <td>0.116164</td>\n",
       "      <td>-0.009532</td>\n",
       "      <td>0.426427</td>\n",
       "      <td>0.421953</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.247800</td>\n",
       "      <td>0.365920</td>\n",
       "      <td>0.502826</td>\n",
       "      <td>0.365920</td>\n",
       "      <td>0.604913</td>\n",
       "      <td>0.109136</td>\n",
       "      <td>-0.093782</td>\n",
       "      <td>0.422832</td>\n",
       "      <td>0.419641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.233100</td>\n",
       "      <td>0.330218</td>\n",
       "      <td>0.473415</td>\n",
       "      <td>0.330218</td>\n",
       "      <td>0.574646</td>\n",
       "      <td>0.120298</td>\n",
       "      <td>0.012937</td>\n",
       "      <td>0.424193</td>\n",
       "      <td>0.418948</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.245000</td>\n",
       "      <td>0.362129</td>\n",
       "      <td>0.499749</td>\n",
       "      <td>0.362129</td>\n",
       "      <td>0.601772</td>\n",
       "      <td>0.112443</td>\n",
       "      <td>-0.082450</td>\n",
       "      <td>0.419835</td>\n",
       "      <td>0.415749</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.235900</td>\n",
       "      <td>0.338315</td>\n",
       "      <td>0.479848</td>\n",
       "      <td>0.338315</td>\n",
       "      <td>0.581648</td>\n",
       "      <td>0.122365</td>\n",
       "      <td>-0.011266</td>\n",
       "      <td>0.421867</td>\n",
       "      <td>0.416990</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.233900</td>\n",
       "      <td>0.375497</td>\n",
       "      <td>0.509828</td>\n",
       "      <td>0.375497</td>\n",
       "      <td>0.612778</td>\n",
       "      <td>0.107069</td>\n",
       "      <td>-0.122407</td>\n",
       "      <td>0.419347</td>\n",
       "      <td>0.415685</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.247100</td>\n",
       "      <td>0.369126</td>\n",
       "      <td>0.504353</td>\n",
       "      <td>0.369127</td>\n",
       "      <td>0.607558</td>\n",
       "      <td>0.105829</td>\n",
       "      <td>-0.103366</td>\n",
       "      <td>0.419292</td>\n",
       "      <td>0.415582</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.221000</td>\n",
       "      <td>0.356401</td>\n",
       "      <td>0.495065</td>\n",
       "      <td>0.356401</td>\n",
       "      <td>0.596993</td>\n",
       "      <td>0.105829</td>\n",
       "      <td>-0.065329</td>\n",
       "      <td>0.420289</td>\n",
       "      <td>0.415208</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.231200</td>\n",
       "      <td>0.339322</td>\n",
       "      <td>0.480659</td>\n",
       "      <td>0.339322</td>\n",
       "      <td>0.582514</td>\n",
       "      <td>0.121124</td>\n",
       "      <td>-0.014277</td>\n",
       "      <td>0.422148</td>\n",
       "      <td>0.416733</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.235600</td>\n",
       "      <td>0.360432</td>\n",
       "      <td>0.496095</td>\n",
       "      <td>0.360432</td>\n",
       "      <td>0.600360</td>\n",
       "      <td>0.109549</td>\n",
       "      <td>-0.077377</td>\n",
       "      <td>0.420680</td>\n",
       "      <td>0.416995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.222100</td>\n",
       "      <td>0.351586</td>\n",
       "      <td>0.489625</td>\n",
       "      <td>0.351586</td>\n",
       "      <td>0.592947</td>\n",
       "      <td>0.110790</td>\n",
       "      <td>-0.050935</td>\n",
       "      <td>0.421397</td>\n",
       "      <td>0.417311</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.242500</td>\n",
       "      <td>0.370960</td>\n",
       "      <td>0.504381</td>\n",
       "      <td>0.370960</td>\n",
       "      <td>0.609065</td>\n",
       "      <td>0.107896</td>\n",
       "      <td>-0.108848</td>\n",
       "      <td>0.420373</td>\n",
       "      <td>0.416966</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.218200</td>\n",
       "      <td>0.361409</td>\n",
       "      <td>0.497412</td>\n",
       "      <td>0.361409</td>\n",
       "      <td>0.601173</td>\n",
       "      <td>0.105829</td>\n",
       "      <td>-0.080297</td>\n",
       "      <td>0.420879</td>\n",
       "      <td>0.416997</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.242200</td>\n",
       "      <td>0.361680</td>\n",
       "      <td>0.497801</td>\n",
       "      <td>0.361680</td>\n",
       "      <td>0.601398</td>\n",
       "      <td>0.104589</td>\n",
       "      <td>-0.081106</td>\n",
       "      <td>0.420885</td>\n",
       "      <td>0.416968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.233800</td>\n",
       "      <td>0.362960</td>\n",
       "      <td>0.498749</td>\n",
       "      <td>0.362960</td>\n",
       "      <td>0.602462</td>\n",
       "      <td>0.104589</td>\n",
       "      <td>-0.084935</td>\n",
       "      <td>0.420851</td>\n",
       "      <td>0.416991</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.224900</td>\n",
       "      <td>0.363375</td>\n",
       "      <td>0.499059</td>\n",
       "      <td>0.363375</td>\n",
       "      <td>0.602806</td>\n",
       "      <td>0.106242</td>\n",
       "      <td>-0.086175</td>\n",
       "      <td>0.420804</td>\n",
       "      <td>0.416971</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=6040, training_loss=0.27832004566066315, metrics={'train_runtime': 632.2454, 'train_samples_per_second': 152.662, 'train_steps_per_second': 9.553, 'total_flos': 10367130800160.0, 'train_loss': 0.27832004566066315, 'epoch': 10.0})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc4e83df",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
