{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "49275bae-af33-47c2-9b7d-d45737bbc6a4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments\n",
    "from datasets import load_dataset\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "af28d239-92f2-4560-b7ab-a3d129904fa6",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def tokenize_sst2(example):\n",
    "    \"\"\"Tokenize the `sentence` column of a GLUE example (or batch).\n",
    "\n",
    "    Uses whichever `tokenizer` is currently bound at notebook level, so run this only\n",
    "    after loading the matching checkpoint. Also used for CoLA, which shares the\n",
    "    `sentence` column name.\n",
    "    \"\"\"\n",
    "    sentences = example[\"sentence\"]\n",
    "    return tokenizer(sentences, truncation=True)\n",
    "\n",
    "\n",
    "def tokenize_agnews(example):\n",
    "    \"\"\"Tokenize the `text` column of an AG News example (or batch) with the current `tokenizer`.\"\"\"\n",
    "    texts = example[\"text\"]\n",
    "    return tokenizer(texts, truncation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae7aa9ad-60d4-47d8-b3dc-c22fd88e39c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def shuffle_weights(model, N):\n",
    "    \"\"\"Randomly permute N% of the weights in `model`.\n",
    "\n",
    "    This is a fast approximation of re-initializing the weights of a model.\n",
    "\n",
    "    Assumes weights are distributed independently of the dimensions of the weight tensors\n",
    "      (i.e., the weights have the same distribution along each dimension).\n",
    "\n",
    "    :param Model model: Modify the weights of the given model.\n",
    "    \"\"\"\n",
    "    names, weights = get_weights(model)\n",
    "\n",
    "    perm_weights = [np.random.permutation(w.flat).reshape(w.shape) for w in weights]\n",
    "    # Faster, but less random: only permutes along the first dimension\n",
    "    # weights = [np.random.permutation(w) for w in weights]\n",
    "    set_weights(model, names, perm_weights, weights, N)\n",
    "    \n",
    "def get_weights(model):\n",
    "    ws = []\n",
    "    names = []\n",
    "    for i in model.named_parameters():\n",
    "        name = i[0]\n",
    "        if 'weight' in name:\n",
    "            names.append(name)\n",
    "            ws.append(i[1].data.detach().cpu().numpy())\n",
    "    return names, ws\n",
    "\n",
    "def set_weights(model, names, perm_w, w, N):\n",
    "    model_dict = dict(model.named_parameters())\n",
    "    num_to_perm = int(len(w) * N)\n",
    "    print(num_to_perm)\n",
    "    perm_idx = np.random.choice(len(w), num_to_perm, replace=False)\n",
    "    for i, name in enumerate(names):\n",
    "        if i in perm_idx:\n",
    "            model_dict[name].data.copy_(torch.tensor(perm_w[i]).to('cuda:0'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82415e76-d0fa-4d87-af82-2ce3129d8327",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(pred):\n",
    "    \"\"\"Compute accuracy and weighted precision / recall / F1 for `Trainer.evaluate`.\n",
    "\n",
    "    :param pred: `EvalPrediction` with `predictions` (logits, possibly wrapped in a\n",
    "        tuple together with other model outputs) and `label_ids`.\n",
    "    :returns: dict with keys 'accuracy', 'precision', 'recall' and 'f1'.\n",
    "    \"\"\"\n",
    "    labels = pred.label_ids\n",
    "    logits = pred.predictions\n",
    "    # When the model returns extra outputs, predictions is a tuple; the logits are\n",
    "    # assumed to be its first element.\n",
    "    if isinstance(logits, tuple):\n",
    "        logits = logits[0]\n",
    "    preds = logits.argmax(-1)\n",
    "\n",
    "    # Calculate accuracy\n",
    "    accuracy = accuracy_score(labels, preds)\n",
    "\n",
    "    # Weighted precision, recall and F1. zero_division=0 yields the same 0.0 sklearn\n",
    "    # already used for never-predicted classes, minus the UndefinedMetricWarning.\n",
    "    # Note: weighted recall is mathematically identical to accuracy.\n",
    "    precision = precision_score(labels, preds, average='weighted', zero_division=0)\n",
    "    recall = recall_score(labels, preds, average='weighted', zero_division=0)\n",
    "    f1 = f1_score(labels, preds, average='weighted', zero_division=0)\n",
    "\n",
    "    return {\n",
    "        'accuracy': accuracy,\n",
    "        'precision': precision,\n",
    "        'recall': recall,\n",
    "        'f1': f1\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "344d8d90-5b39-44f3-ab46-c90d7edbbc49",
   "metadata": {
    "user_expressions": []
   },
   "source": [
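    "# Weight-permutation robustness of fine-tuned BERT classifiers\n",
    "\n",
    "For each task (SST-2, AG News, CoLA) we evaluate a fine-tuned `textattack/bert-base-uncased-*` checkpoint, randomly permute a fraction of its weight tensors with `shuffle_weights`, save the result, and re-evaluate the reloaded permuted model."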
   ]
  },
  {
   "cell_type": "markdown",
   "id": "721c5f8d-4e51-4b62-8349-924d085b186d",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "# SST2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "8c0a88ef-6fd9-490b-8249-afb4f46d3a68",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "SST2_CHECKPOINT = \"textattack/bert-base-uncased-SST-2\"\n",
    "\n",
    "# Tokenizer and model must come from the same checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained(SST2_CHECKPOINT)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(SST2_CHECKPOINT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35b6921e-6277-434a-ab22-78f1b4b09207",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# GLUE SST-2; its `train` and `validation` splits carry a `sentence` text column.\n",
    "raw_sst2 = load_dataset(\"glue\", name=\"sst2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94bbda39-a292-475f-87e0-81a258216fec",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tokenize every split of SST-2 (tokenize_sst2 reads the `sentence` column).\n",
    "tokenized_datasets = raw_sst2.map(tokenize_sst2, batched=True)\n",
    "\n",
    "# Pads each batch dynamically to its longest sequence.\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "8e3a5cde-574d-4598-9a30-d450d5cf9a5e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Only evaluation is run in this notebook; output_dir is where Trainer would write checkpoints.\n",
    "training_args = TrainingArguments(output_dir=\"sst2-finetuned-model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "75ea072a-efde-4e9c-be09-84acabc36914",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    # Pass tokenized splits: the raw ones have no `input_ids`, so they could not be\n",
    "    # fed to the model (e.g. by `trainer.train()` or a bare `trainer.evaluate()`).\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8f044ee-e762-418a-9380-fe2b70afbb2b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Baseline metrics for the unmodified checkpoint. Run before `shuffle_weights` below,\n",
    "# which mutates `model` in place.\n",
    "evaluation_results = trainer.evaluate(eval_dataset=tokenized_datasets[\"validation\"])\n",
    "print(evaluation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "cd97b877-36ba-4dab-b0a6-19bf3e642585",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5\n"
     ]
    }
   ],
   "source": [
    "# Permute 5% of the weight tensors of `model` in place. The baseline evaluation\n",
    "# above is only a true baseline if it was run before this cell.\n",
    "shuffle_weights(model, N=0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "010599c6-e565-4828-93eb-23628cf7bd55",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# `trainer.model` is `model`, so this writes the permuted weights to disk.\n",
    "trainer.save_model(output_dir=\"perm5perc_sst2_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7fde046-3489-40f5-acb1-8c9f76572589",
   "metadata": {
    "tags": [],
    "user_expressions": []
   },
   "source": [
    "### Evaluate permuted model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "0b35d645-d8d4-4c0d-a2e0-536b676a7e23",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Reload the saved permuted checkpoint as an independent model.\n",
    "perm_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=\"perm5perc_sst2_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "37e94c75-e156-403e-8fc6-aae3c9481aae",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=perm_model,\n",
    "    args=training_args,\n",
    "    # Pass tokenized splits: the raw ones have no `input_ids`, so they could not be\n",
    "    # fed to the model (e.g. by `trainer.train()` or a bare `trainer.evaluate()`).\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "07b3e357-8b4c-43b4-b750-b886932b47a2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='109' max='109' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [109/109 00:02]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.6473379731178284, 'eval_accuracy': 0.6089449541284404, 'eval_precision': 0.705911057072326, 'eval_recall': 0.6089449541284404, 'eval_f1': 0.5515552067785796, 'eval_runtime': 2.3478, 'eval_samples_per_second': 371.411, 'eval_steps_per_second': 46.426}\n"
     ]
    }
   ],
   "source": [
    "# Metrics for the reloaded permuted model, to compare against the baseline.\n",
    "evaluation_results = trainer.evaluate(eval_dataset=tokenized_datasets[\"validation\"])\n",
    "print(evaluation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "696827c5-d002-4ae0-8d87-5164d6da1a28",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3c8817b1-22d4-42cb-88de-ba03389b1abc",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "# AG News"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "id": "ab135974-4e02-45e2-aeaf-890940a56221",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "AGNEWS_CHECKPOINT = \"textattack/bert-base-uncased-ag-news\"\n",
    "\n",
    "# Tokenizer and model must come from the same checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained(AGNEWS_CHECKPOINT)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(AGNEWS_CHECKPOINT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d5f40ad-d809-4f40-8301-ef9490f2b526",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "raw_agnews = load_dataset(\"ag_news\")\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
    "\n",
    "# AG News stores its input in a `text` column, hence tokenize_agnews.\n",
    "tokenized_datasets = raw_agnews.map(tokenize_agnews, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "e45ccbe2-016f-4653-a112-da5b79ad585a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='950' max='950' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [950/950 00:37]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.839389443397522, 'eval_accuracy': 0.6667105263157894, 'eval_precision': 0.7894484623269422, 'eval_recall': 0.6667105263157894, 'eval_f1': 0.61325119110938, 'eval_runtime': 37.2097, 'eval_samples_per_second': 204.248, 'eval_steps_per_second': 25.531}\n"
     ]
    }
   ],
   "source": [
    "training_args = TrainingArguments(\"agnews-finetuned-model\")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    # Pass tokenized splits: the raw ones have no `input_ids`, so they could not be\n",
    "    # fed to the model (e.g. by `trainer.train()` or a bare `trainer.evaluate()`).\n",
    "    # AG News is evaluated on its `test` split.\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"test\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# Baseline metrics: valid only if this runs BEFORE `shuffle_weights(model, ...)` below,\n",
    "# which mutates `model` in place.\n",
    "evaluation_results = trainer.evaluate(tokenized_datasets[\"test\"])\n",
    "print(evaluation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "id": "1597471b-a73d-4ab8-87b2-41373d846dd2",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n"
     ]
    }
   ],
   "source": [
    "# Permute 10% of the weight tensors of `model` in place. The baseline evaluation\n",
    "# above is only a true baseline if it was run before this cell.\n",
    "shuffle_weights(model, N=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "4263417a-e436-4d4e-81a1-397b84c5b5c1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# `trainer.model` is `model`, so this writes the permuted weights to disk.\n",
    "trainer.save_model(output_dir=\"perm10perc_agnews_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb485a2a-c925-4567-a36b-7158c92cecbd",
   "metadata": {
    "user_expressions": []
   },
   "source": [
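    "### Evaluate permuted model\n",
    "\n",
    "Note: `shuffle_weights` mutates `model` in place, so the baseline evaluation above is only a true baseline if it runs *before* the shuffle cell. In the saved outputs it ran afterwards (execution count 99 for the shuffle, 100 for the baseline), which is why both evaluations report identical metrics."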
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "f20fa99a-c4fd-4ddb-a928-15732d02acc8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Reload the saved permuted checkpoint as an independent model.\n",
    "perm_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=\"perm10perc_agnews_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "3964f476-2572-452f-8a1a-a558ff65df1a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=perm_model,\n",
    "    args=training_args,\n",
    "    # Pass tokenized splits: the raw ones have no `input_ids`, so they could not be\n",
    "    # fed to the model (e.g. by `trainer.train()` or a bare `trainer.evaluate()`).\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"test\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "id": "1a339cec-d4a7-4ba1-a589-8dc4edab8974",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='950' max='950' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [950/950 00:36]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.839389443397522, 'eval_accuracy': 0.6667105263157894, 'eval_precision': 0.7894484623269422, 'eval_recall': 0.6667105263157894, 'eval_f1': 0.61325119110938, 'eval_runtime': 36.4451, 'eval_samples_per_second': 208.533, 'eval_steps_per_second': 26.067}\n"
     ]
    }
   ],
   "source": [
    "# Metrics for the reloaded permuted model, to compare against the baseline.\n",
    "evaluation_results = trainer.evaluate(eval_dataset=tokenized_datasets[\"test\"])\n",
    "print(evaluation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c49cbbf5-0ee4-425c-8d87-25adf863d325",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bedb2e6d-35ca-412a-85ad-d1da61a5bfa7",
   "metadata": {
    "user_expressions": []
   },
   "source": [
    "# CoLA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "23ef2302-a8a7-4cd3-b40c-7aa14a92f675",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b736bc659f2414480cbc019b6d71259",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)okenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "29d9204e9e2148a3ae1edcb773ccbbaf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)lve/main/config.json:   0%|          | 0.00/476 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5525b2ab6b024c8e8550c40f3f6b8942",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10e78b414c04471e99186cde94e6cb9f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ba17a4fe8e14d8c9fb3ed78f57d03cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "COLA_CHECKPOINT = \"textattack/bert-base-uncased-cola\"\n",
    "\n",
    "# Tokenizer and model must come from the same checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained(COLA_CHECKPOINT)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(COLA_CHECKPOINT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b63eabd5-b2cc-4ec6-9aa6-a44c04021bc7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "raw_cola = load_dataset(\"glue\", name=\"cola\")\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
    "\n",
    "# CoLA uses the same `sentence` column as SST-2, so tokenize_sst2 is reused as-is.\n",
    "tokenized_datasets = raw_cola.map(tokenize_sst2, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "c95b0278-556f-4b5c-b29a-ba1e0a0a879b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='131' max='131' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [131/131 00:01]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/linux/mambaforge-3.11/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1334: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.6283413767814636, 'eval_accuracy': 0.6912751677852349, 'eval_precision': 0.47786135759650467, 'eval_recall': 0.6912751677852349, 'eval_f1': 0.5650900181101524, 'eval_runtime': 1.4828, 'eval_samples_per_second': 703.416, 'eval_steps_per_second': 88.349}\n"
     ]
    }
   ],
   "source": [
    "training_args = TrainingArguments(\"cola-finetuned-model\")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    # Pass tokenized splits: the raw ones have no `input_ids`, so they could not be\n",
    "    # fed to the model (e.g. by `trainer.train()` or a bare `trainer.evaluate()`).\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# Baseline metrics: valid only if this runs BEFORE `shuffle_weights(model, ...)` below,\n",
    "# which mutates `model` in place.\n",
    "evaluation_results = trainer.evaluate(tokenized_datasets[\"validation\"])\n",
    "print(evaluation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "a08c460c-811e-4b70-991a-3f90c9520f18",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n"
     ]
    }
   ],
   "source": [
    "# Permute 10% of the weight tensors of `model` in place. The baseline evaluation\n",
    "# above is only a true baseline if it was run before this cell.\n",
    "shuffle_weights(model, N=0.10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "8a03a295-eee7-48ae-8129-f45859f7b567",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# `trainer.model` is `model`, so this writes the permuted weights to disk.\n",
    "trainer.save_model(output_dir=\"perm10perc_cola_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cbfe42e-2bd5-4d53-96f8-b047febd768e",
   "metadata": {
    "user_expressions": []
   },
   "source": [
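    "### Evaluate permuted model\n",
    "\n",
    "Note: `shuffle_weights` mutates `model` in place, so the baseline evaluation above is only a true baseline if it runs *before* the shuffle cell. In the saved outputs it ran afterwards (execution count 108 for the shuffle, 109 for the baseline). That is why both evaluations report identical metrics and the same precision warning: at least one label was never predicted."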
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "83fb5b07-589b-40e3-b84e-411440eb8f9a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Reload the saved permuted checkpoint as an independent model.\n",
    "perm_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=\"perm10perc_cola_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "id": "4b81c6b1-c3b2-4d28-bedf-d893dbcc4bb5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=perm_model,\n",
    "    args=training_args,\n",
    "    # Pass tokenized splits: the raw ones have no `input_ids`, so they could not be\n",
    "    # fed to the model (e.g. by `trainer.train()` or a bare `trainer.evaluate()`).\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "f61c15b0-8625-4672-a660-ff3de921f8df",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='131' max='131' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [131/131 00:01]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.6283413767814636, 'eval_accuracy': 0.6912751677852349, 'eval_precision': 0.47786135759650467, 'eval_recall': 0.6912751677852349, 'eval_f1': 0.5650900181101524, 'eval_runtime': 1.4766, 'eval_samples_per_second': 706.371, 'eval_steps_per_second': 88.72}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/linux/mambaforge-3.11/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1334: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "# Metrics for the reloaded permuted model, to compare against the baseline.\n",
    "evaluation_results = trainer.evaluate(eval_dataset=tokenized_datasets[\"validation\"])\n",
    "print(evaluation_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbc11cd8-561c-46f3-b946-44eb8c04118b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
