{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90b47c6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity'],\n",
      "    num_rows: 2470\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filter: 100%|██████████| 2470/2470 [00:00<00:00, 59369.57 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity'],\n",
      "    num_rows: 2466\n",
      "})\n",
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity'],\n",
      "    num_rows: 1972\n",
      "})\n",
      "Test Dataset: Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity'],\n",
      "    num_rows: 494\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "#load train data\n",
    "import pandas as pd\n",
    "cols = ['id', 'text', 'label', 'intensity']\n",
    "path = \"https://raw.githubusercontent.com/vinayakumarr/WASSA-2017/refs/heads/master/wassa/data/training/\"\n",
    "anger_train = pd.read_csv(StringIO(requests.get(path + 'anger-ratings-0to1.train.txt').text), header=None, sep='\\t', names=cols, index_col=0)\n",
    "fear_train = pd.read_csv(StringIO(requests.get(path + 'fear-ratings-0to1.train').text), header=None, sep='\\t', names=cols, index_col=0)\n",
    "sad_train = pd.read_csv(StringIO(requests.get(path + 'sadness-ratings-0to1.train.txt').text), header=None, sep='\\t', names=cols, index_col=0)\n",
    "joy_train = pd.read_csv(StringIO(requests.get(path + 'joy-ratings-0to1.train.txt').text), header=None, sep='\\t', names=cols, index_col=0)\n",
    "\n",
    "dataset = pd.concat([anger_train, fear_train, sad_train, joy_train], axis=0)\n",
    "\n",
    "# Reset index for the combined DataFrame (optional)\n",
    "dataset.reset_index(inplace=True)\n",
    "\n",
    "from datasets import Dataset\n",
    "import pandas as pd\n",
    "dataset = Dataset.from_pandas(dataset)\n",
    "\n",
    "\n",
    "# Shuffle the dataset\n",
    "dataset = dataset.shuffle(seed=42)\n",
    "\n",
    "# Inspect the dataset\n",
    "print(dataset)\n",
    "\n",
    "def is_valid_intensity(example):\n",
    "    if example['intensity'] is not None:\n",
    "        #print(example['intensity'])\n",
    "        try: \n",
    "            k = float(example['intensity'])\n",
    "            return True\n",
    "        except:\n",
    "        \n",
    "            return False\n",
    "    else:\n",
    "        return False\n",
    "\n",
    "# Filter the dataset\n",
    "dataset = dataset.filter(is_valid_intensity)\n",
    "print(dataset)\n",
    "# Split the shuffled dataset into train and test sets\n",
    "train_test_split = dataset.train_test_split(test_size=0.2, seed=42)\n",
    "\n",
    "# Access the train and test datasets\n",
    "train_data = train_test_split['train']\n",
    "val_data = train_test_split['test']\n",
    "\n",
    "# Inspect the datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Test Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "de228bb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/transformers/convert_slow_tokenizer.py:559: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "\n",
    "model_name = \"microsoft/deberta-v3-base\"\n",
    "\n",
    "#config.num_labels=2\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.padding_side = 'left'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed721fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 1972/1972 [00:00<00:00, 13807.46 examples/s]\n",
      "Map: 100%|██████████| 494/494 [00:00<00:00, 13342.95 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity', 'labels', 'input'],\n",
      "    num_rows: 1972\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['id', 'text', 'label', 'intensity', 'labels', 'input'],\n",
      "    num_rows: 494\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Generates a prompt for evaluating the humor intensity of an edited headline.\n",
    "    Args:\n",
    "        data_point (dict): A dictionary containing 'original', 'edit', and 'meanGrade'.\n",
    "    Returns:\n",
    "        str: The formatted prompt as a string.\n",
    "    \"\"\"\n",
    "    return f\"\"\"# Input: {data_point['text']} # Label: {data_point['label']} # Output: The intensity is{mask_token}\"\"\"  # noqa: E501\n",
    "\n",
    "\n",
    "# Assuming `dataset` is your DatasetDict\n",
    "def add_label_column(example):\n",
    "\n",
    "    example['labels'] = float(example['intensity'])\n",
    "  \n",
    "    example['input'] = generate_prompt(example)\n",
    "\n",
    "    \n",
    "    return example\n",
    "\n",
    "# Map the function over train and validation datasets\n",
    "train_data = train_data.map(add_label_column)\n",
    "val_data = val_data.map(add_label_column)\n",
    "\n",
    "# Remove unnecessary columns\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e33204c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a9fde6d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 1972/1972 [00:00<00:00, 28771.58 examples/s]\n",
      "Map: 100%|██████████| 494/494 [00:00<00:00, 26341.37 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "\n",
    "# col_to_delete = ['idx']\n",
    "col_to_delete = ['label', 'intensity','id', 'text']  # Update as per your dataset\n",
    "\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "   \n",
    "    return tokenizer(examples['input'], truncation=True, max_length=512)\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "# llama_tokenized_datasets = llama_tokenized_datasets.rename_column(\"target\", \"label\")\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1931ed6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] # Input: @TehShockwave turn that grumpy frown upside-down\\\\n\\\\nYou did something next to impossible today # Label: sadness # Output: The intensity is[MASK][SEP]'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_train_data['input_ids'][10])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abd6b985",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "25900f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'text', 'label', 'intensity', 'labels', 'input'],\n",
       "    num_rows: 494\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1fdaa612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "75"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_lengths = [len(ids) for ids in tokenized_train_data['input_ids']]\n",
    "mx = max(all_lengths)\n",
    "mx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6618d0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "count = sum(len(ids) > 512 for ids in tokenized_train_data['input_ids'])\n",
    "print(count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7a46cd19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DebertaV2ForMaskedLM were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import RobertaForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "from modeling import MLMSequenceClassification\n",
    "\n",
    "config = AutoConfig.from_pretrained(model_name)\n",
    "\n",
    "model = MLMSequenceClassification.from_pretrained(model_name, config=config, num_labels=1, mask_token_id=tokenizer.mask_token_id)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "159b238b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLMSequenceClassification(\n",
       "  (transformer): DebertaV2ForMaskedLM(\n",
       "    (deberta): DebertaV2Model(\n",
       "      (embeddings): DebertaV2Embeddings(\n",
       "        (word_embeddings): Embedding(128100, 768, padding_idx=0)\n",
       "        (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (encoder): DebertaV2Encoder(\n",
       "        (layer): ModuleList(\n",
       "          (0-11): 12 x DebertaV2Layer(\n",
       "            (attention): DebertaV2Attention(\n",
       "              (self): DisentangledSelfAttention(\n",
       "                (query_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (key_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (value_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (pos_dropout): Dropout(p=0.1, inplace=False)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "              (output): DebertaV2SelfOutput(\n",
       "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "                (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "                (dropout): Dropout(p=0.1, inplace=False)\n",
       "              )\n",
       "            )\n",
       "            (intermediate): DebertaV2Intermediate(\n",
       "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "              (intermediate_act_fn): GELUActivation()\n",
       "            )\n",
       "            (output): DebertaV2Output(\n",
       "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "        )\n",
       "        (rel_embeddings): Embedding(512, 768)\n",
       "        (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (cls): LegacyDebertaV2OnlyMLMHead(\n",
       "      (predictions): LegacyDebertaV2LMPredictionHead(\n",
       "        (transform): LegacyDebertaV2PredictionHeadTransform(\n",
       "          (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "          (transform_act_fn): GELUActivation()\n",
       "          (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "        )\n",
       "        (decoder): Linear(in_features=768, out_features=128100, bias=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (classifier): ClassificationHead(\n",
       "    (dense): Linear(in_features=128100, out_features=768, bias=True)\n",
       "    (dropout): Dropout(p=0.02, inplace=False)\n",
       "    (out_proj): Linear(in_features=768, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "864ccb2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import RoCoFT\n",
    "\n",
    "RoCoFT.PEFT(model, method='column', rank=3) \n",
    "#targets=['key', 'value', 'dense', 'query'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bef34afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    # If predictions are logits or have extra dimensions, squeeze\n",
    "    if predictions.ndim > 1:\n",
    "        predictions = predictions.squeeze()\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "    \n",
    "    # Define an \"accuracy\" for regression:\n",
    "    # Example: within some threshold tolerance\n",
    "    tolerance = 0.1  # you can change this\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7dbcf96a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/transformers/training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=6e-4,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=10,\n",
    "    weight_decay=0.20,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_total_limit=2,\n",
    "    save_steps=10000000,\n",
    "    logging_steps=100,\n",
    "   \n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "557cdbf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1240' max='1240' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1240/1240 02:51, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.027400</td>\n",
       "      <td>0.019679</td>\n",
       "      <td>0.112409</td>\n",
       "      <td>0.019679</td>\n",
       "      <td>0.140281</td>\n",
       "      <td>0.544534</td>\n",
       "      <td>0.447253</td>\n",
       "      <td>0.781615</td>\n",
       "      <td>0.782923</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.029500</td>\n",
       "      <td>0.021654</td>\n",
       "      <td>0.119358</td>\n",
       "      <td>0.021654</td>\n",
       "      <td>0.147154</td>\n",
       "      <td>0.483806</td>\n",
       "      <td>0.391764</td>\n",
       "      <td>0.784642</td>\n",
       "      <td>0.787047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.026100</td>\n",
       "      <td>0.021107</td>\n",
       "      <td>0.115012</td>\n",
       "      <td>0.021107</td>\n",
       "      <td>0.145284</td>\n",
       "      <td>0.526316</td>\n",
       "      <td>0.407128</td>\n",
       "      <td>0.789094</td>\n",
       "      <td>0.802312</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.024300</td>\n",
       "      <td>0.019044</td>\n",
       "      <td>0.109842</td>\n",
       "      <td>0.019044</td>\n",
       "      <td>0.137999</td>\n",
       "      <td>0.528340</td>\n",
       "      <td>0.465093</td>\n",
       "      <td>0.782379</td>\n",
       "      <td>0.802609</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.023700</td>\n",
       "      <td>0.021267</td>\n",
       "      <td>0.117478</td>\n",
       "      <td>0.021267</td>\n",
       "      <td>0.145834</td>\n",
       "      <td>0.497976</td>\n",
       "      <td>0.402631</td>\n",
       "      <td>0.813056</td>\n",
       "      <td>0.814139</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.021100</td>\n",
       "      <td>0.013921</td>\n",
       "      <td>0.091325</td>\n",
       "      <td>0.013921</td>\n",
       "      <td>0.117988</td>\n",
       "      <td>0.613360</td>\n",
       "      <td>0.608976</td>\n",
       "      <td>0.806349</td>\n",
       "      <td>0.809668</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.020100</td>\n",
       "      <td>0.027137</td>\n",
       "      <td>0.131891</td>\n",
       "      <td>0.027137</td>\n",
       "      <td>0.164733</td>\n",
       "      <td>0.451417</td>\n",
       "      <td>0.237764</td>\n",
       "      <td>0.808880</td>\n",
       "      <td>0.809335</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.019200</td>\n",
       "      <td>0.019436</td>\n",
       "      <td>0.110341</td>\n",
       "      <td>0.019436</td>\n",
       "      <td>0.139414</td>\n",
       "      <td>0.528340</td>\n",
       "      <td>0.454064</td>\n",
       "      <td>0.816020</td>\n",
       "      <td>0.818913</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.018100</td>\n",
       "      <td>0.018404</td>\n",
       "      <td>0.106294</td>\n",
       "      <td>0.018404</td>\n",
       "      <td>0.135662</td>\n",
       "      <td>0.528340</td>\n",
       "      <td>0.483059</td>\n",
       "      <td>0.817538</td>\n",
       "      <td>0.818047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.018800</td>\n",
       "      <td>0.016642</td>\n",
       "      <td>0.099727</td>\n",
       "      <td>0.016642</td>\n",
       "      <td>0.129003</td>\n",
       "      <td>0.587045</td>\n",
       "      <td>0.532562</td>\n",
       "      <td>0.820285</td>\n",
       "      <td>0.817922</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.018500</td>\n",
       "      <td>0.015061</td>\n",
       "      <td>0.094765</td>\n",
       "      <td>0.015061</td>\n",
       "      <td>0.122725</td>\n",
       "      <td>0.619433</td>\n",
       "      <td>0.576950</td>\n",
       "      <td>0.820160</td>\n",
       "      <td>0.817953</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.017900</td>\n",
       "      <td>0.017392</td>\n",
       "      <td>0.102593</td>\n",
       "      <td>0.017392</td>\n",
       "      <td>0.131880</td>\n",
       "      <td>0.570850</td>\n",
       "      <td>0.511477</td>\n",
       "      <td>0.819751</td>\n",
       "      <td>0.818542</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1240, training_loss=0.022004768588850573, metrics={'train_runtime': 171.9378, 'train_samples_per_second': 114.693, 'train_steps_per_second': 7.212, 'total_flos': 6896683617024.0, 'train_loss': 0.022004768588850573, 'epoch': 10.0})"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d54c97e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eac787d",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Column to remove ['corpus', 'complexity', 'token', 'sentence'] not in the dataset. Current columns in the dataset: ['id', 'text', 'label', 'intensity', 'labels', 'input']",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 26\u001b[0m\n\u001b[1;32m     22\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mpreprocessing_function\u001b[39m(examples):\n\u001b[1;32m     24\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m tokenizer(examples[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput\u001b[39m\u001b[38;5;124m'\u001b[39m], truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m510\u001b[39m)\n\u001b[0;32m---> 26\u001b[0m tokenized_train_data1 \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_data\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpreprocessing_function\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mremove_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcol_to_delete\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     28\u001b[0m data_collator \u001b[38;5;241m=\u001b[39m DataCollatorWithPadding(tokenizer\u001b[38;5;241m=\u001b[39mtokenizer, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmax_length\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m     32\u001b[0m dataloader \u001b[38;5;241m=\u001b[39m DataLoader(\n\u001b[1;32m     33\u001b[0m     tokenized_train_data1,\n\u001b[1;32m     34\u001b[0m     batch_size\u001b[38;5;241m=\u001b[39mbatch_size,\n\u001b[1;32m     35\u001b[0m     collate_fn\u001b[38;5;241m=\u001b[39mdata_collator,\n\u001b[1;32m     36\u001b[0m     shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m     37\u001b[0m )\n",
      "File \u001b[0;32m~/anaconda3/envs/MD/lib/python3.10/site-packages/datasets/arrow_dataset.py:562\u001b[0m, in \u001b[0;36mtransmit_format.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    555\u001b[0m self_format \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m    556\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_type,\n\u001b[1;32m    557\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mformat_kwargs\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_kwargs,\n\u001b[1;32m    558\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_columns,\n\u001b[1;32m    559\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_all_columns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_output_all_columns,\n\u001b[1;32m    560\u001b[0m }\n\u001b[1;32m    561\u001b[0m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 562\u001b[0m out: Union[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDatasetDict\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    563\u001b[0m datasets: List[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(out\u001b[38;5;241m.\u001b[39mvalues()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[1;32m    564\u001b[0m \u001b[38;5;66;03m# re-apply format to the output\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/MD/lib/python3.10/site-packages/datasets/arrow_dataset.py:3000\u001b[0m, in \u001b[0;36mDataset.map\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)\u001b[0m\n\u001b[1;32m   2998\u001b[0m     missing_columns \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m(remove_columns) \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mset\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data\u001b[38;5;241m.\u001b[39mcolumn_names)\n\u001b[1;32m   2999\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m missing_columns:\n\u001b[0;32m-> 3000\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m   3001\u001b[0m             \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mColumn to remove \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlist\u001b[39m(missing_columns)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not in the dataset. Current columns in the dataset: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data\u001b[38;5;241m.\u001b[39mcolumn_names\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   3002\u001b[0m         )\n\u001b[1;32m   3004\u001b[0m load_from_cache_file \u001b[38;5;241m=\u001b[39m load_from_cache_file \u001b[38;5;28;01mif\u001b[39;00m load_from_cache_file \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m is_caching_enabled()\n\u001b[1;32m   3006\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fn_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[0;31mValueError\u001b[0m: Column to remove ['corpus', 'complexity', 'token', 'sentence'] not in the dataset. Current columns in the dataset: ['id', 'text', 'label', 'intensity', 'labels', 'input']"
     ]
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87f6524c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7232])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_tensor.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dade7e4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fea88e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 768])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Z_layer_outputs[0][-1].mean(dim=1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5c6aaa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc4ebf09",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d81afac6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MD",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
