{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e27807b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "#load train data\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "raw_datasets  = load_dataset('RobZamp/sick')\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c0d83c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Log in using your Hugging Face token\n",
    "login(\"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\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8d0bb7c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8b6e00bb67843788ccbcca8ebf21bc5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "from modeling import CLMSequenceClassification\n",
    "\n",
    "\n",
    "#model_name = \"openai-community/gpt2-medium\"\n",
    "model_name ='meta-llama/Llama-2-7b-hf'\n",
    "#config.num_labels=2\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "#tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "model = CLMSequenceClassification.from_pretrained(model_name, num_labels=1).to('cuda')\n",
    "\n",
    "model.config.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "import RoCoFT\n",
    "\n",
    "RoCoFT.PEFT(model, method='row', rank=3) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "12f358b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'labels', 'input'],\n",
      "    num_rows: 4439\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'labels', 'input'],\n",
      "    num_rows: 495\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Generates a prompt for evaluating the humor intensity of an edited headline.\n",
    "    Args:\n",
    "        data_point (dict): A dictionary containing 'original', 'edit', and 'meanGrade'.\n",
    "    Returns:\n",
    "        str: The formatted prompt as a string.\n",
    "    \"\"\"\n",
    "    return f\"\"\"# Sentence-1:: {data_point['sentence_A']}. # Sentence-2: {data_point['sentence_B']} # Output: The similarity is\"\"\"  # noqa: E501\n",
    "\n",
    "\n",
    "# Assuming `dataset` is your DatasetDict\n",
    "def add_label_column(example):\n",
    "\n",
    "    example['labels'] = float(example['relatedness_score'])\n",
    "  \n",
    "    example['input'] = generate_prompt(example)\n",
    "\n",
    "    \n",
    "    return example\n",
    "\n",
    "# Map the function over train and validation datasets\n",
    "\n",
    "train_data = raw_datasets['train'].map(add_label_column)\n",
    "val_data = raw_datasets['validation'].map(add_label_column)\n",
    "\n",
    "# Remove unnecessary columns\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3544cae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "\n",
    "# col_to_delete = ['idx']\n",
    "col_to_delete =  ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'input']\n",
    "\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "   \n",
    "    return tokenizer(examples['input'], truncation=True, max_length=512)\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "# llama_tokenized_datasets = llama_tokenized_datasets.rename_column(\"target\", \"label\")\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a70b43e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<s> # Sentence-1:: A person on a black motorbike is doing tricks with a jacket. # Sentence-2: A person is riding the bicycle on one wheel # Output: The similarity is'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_train_data['input_ids'][10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f39dcdc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    # If predictions are logits or have extra dimensions, squeeze\n",
    "    if predictions.ndim > 1:\n",
    "        predictions = predictions.squeeze()\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "    \n",
    "    # Define an \"accuracy\" for regression:\n",
    "    # Example: within some threshold tolerance\n",
    "    tolerance = 0.1  # you can change this\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b64eb4e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    }
   ],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=5e-4,\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    "    gradient_accumulation_steps=10,\n",
    "    num_train_epochs=8,\n",
    "    weight_decay=0.20,\n",
    "    eval_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_total_limit=2,\n",
    "    save_steps=10000000,\n",
    "    logging_steps=500,\n",
    "   \n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5b333893",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3545' max='3544' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [3544/3544 1:29:36, Epoch 7.98/8]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.805200</td>\n",
       "      <td>0.385745</td>\n",
       "      <td>0.504549</td>\n",
       "      <td>0.385745</td>\n",
       "      <td>0.621084</td>\n",
       "      <td>0.103030</td>\n",
       "      <td>0.619425</td>\n",
       "      <td>0.899418</td>\n",
       "      <td>0.871772</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.280000</td>\n",
       "      <td>0.181458</td>\n",
       "      <td>0.323182</td>\n",
       "      <td>0.181458</td>\n",
       "      <td>0.425978</td>\n",
       "      <td>0.228283</td>\n",
       "      <td>0.820974</td>\n",
       "      <td>0.916636</td>\n",
       "      <td>0.884169</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.182600</td>\n",
       "      <td>0.206960</td>\n",
       "      <td>0.353360</td>\n",
       "      <td>0.206960</td>\n",
       "      <td>0.454929</td>\n",
       "      <td>0.167677</td>\n",
       "      <td>0.795813</td>\n",
       "      <td>0.918255</td>\n",
       "      <td>0.889152</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.129900</td>\n",
       "      <td>0.169919</td>\n",
       "      <td>0.303626</td>\n",
       "      <td>0.169919</td>\n",
       "      <td>0.412213</td>\n",
       "      <td>0.256566</td>\n",
       "      <td>0.832358</td>\n",
       "      <td>0.923134</td>\n",
       "      <td>0.896392</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.096900</td>\n",
       "      <td>0.158337</td>\n",
       "      <td>0.305653</td>\n",
       "      <td>0.158337</td>\n",
       "      <td>0.397916</td>\n",
       "      <td>0.236364</td>\n",
       "      <td>0.843785</td>\n",
       "      <td>0.922958</td>\n",
       "      <td>0.890008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.066500</td>\n",
       "      <td>0.153640</td>\n",
       "      <td>0.292694</td>\n",
       "      <td>0.153640</td>\n",
       "      <td>0.391969</td>\n",
       "      <td>0.238384</td>\n",
       "      <td>0.848419</td>\n",
       "      <td>0.922970</td>\n",
       "      <td>0.892428</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.052900</td>\n",
       "      <td>0.159080</td>\n",
       "      <td>0.294591</td>\n",
       "      <td>0.159080</td>\n",
       "      <td>0.398848</td>\n",
       "      <td>0.244444</td>\n",
       "      <td>0.843052</td>\n",
       "      <td>0.922281</td>\n",
       "      <td>0.890584</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "SafetensorError",
     "evalue": "Error while serializing: IoError(Os { code: 7, kind: ArgumentListTooLong, message: \"Argument list too long\" })",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mSafetensorError\u001b[0m                           Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/torch_env/lib/python3.10/site-packages/transformers/trainer.py:2241\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m   2239\u001b[0m         hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m   2240\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2241\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2242\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2243\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2244\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2245\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2246\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/torch_env/lib/python3.10/site-packages/transformers/trainer.py:2612\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m   2610\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mepoch \u001b[38;5;241m=\u001b[39m epoch \u001b[38;5;241m+\u001b[39m (step \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m steps_skipped) \u001b[38;5;241m/\u001b[39m steps_in_epoch\n\u001b[1;32m   2611\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 2612\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maybe_log_save_evaluate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2613\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtr_loss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_norm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_time\u001b[49m\n\u001b[1;32m   2614\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2615\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   2616\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_substep_end(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
      "File \u001b[0;32m~/.conda/envs/torch_env/lib/python3.10/site-packages/transformers/trainer.py:3092\u001b[0m, in \u001b[0;36mTrainer._maybe_log_save_evaluate\u001b[0;34m(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)\u001b[0m\n\u001b[1;32m   3089\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_save \u001b[38;5;241m=\u001b[39m is_new_best_metric\n\u001b[1;32m   3091\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[0;32m-> 3092\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_checkpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3093\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_save(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n",
      "File \u001b[0;32m~/.conda/envs/torch_env/lib/python3.10/site-packages/transformers/trainer.py:3190\u001b[0m, in \u001b[0;36mTrainer._save_checkpoint\u001b[0;34m(self, model, trial)\u001b[0m\n\u001b[1;32m   3188\u001b[0m run_dir \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_output_dir(trial\u001b[38;5;241m=\u001b[39mtrial)\n\u001b[1;32m   3189\u001b[0m output_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(run_dir, checkpoint_folder)\n\u001b[0;32m-> 3190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_internal_call\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m   3192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39msave_only_model:\n\u001b[1;32m   3193\u001b[0m     \u001b[38;5;66;03m# Save optimizer and scheduler\u001b[39;00m\n\u001b[1;32m   3194\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_save_optimizer_and_scheduler(output_dir)\n",
      "File \u001b[0;32m~/.conda/envs/torch_env/lib/python3.10/site-packages/transformers/trainer.py:3860\u001b[0m, in \u001b[0;36mTrainer.save_model\u001b[0;34m(self, output_dir, _internal_call)\u001b[0m\n\u001b[1;32m   3857\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped\u001b[38;5;241m.\u001b[39msave_checkpoint(output_dir)\n\u001b[1;32m   3859\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mshould_save:\n\u001b[0;32m-> 3860\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3862\u001b[0m \u001b[38;5;66;03m# Push to the Hub when `save_model` is called by the user.\u001b[39;00m\n\u001b[1;32m   3863\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpush_to_hub \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _internal_call:\n",
      "File \u001b[0;32m~/.conda/envs/torch_env/lib/python3.10/site-packages/transformers/trainer.py:3964\u001b[0m, in \u001b[0;36mTrainer._save\u001b[0;34m(self, output_dir, state_dict)\u001b[0m\n\u001b[1;32m   3962\u001b[0m             torch\u001b[38;5;241m.\u001b[39msave(state_dict, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(output_dir, WEIGHTS_NAME))\n\u001b[1;32m   3963\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 3964\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3965\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_dir\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msafe_serialization\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_safetensors\u001b[49m\n\u001b[1;32m   3966\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3968\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprocessing_class \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   3969\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprocessing_class\u001b[38;5;241m.\u001b[39msave_pretrained(output_dir)\n",
      "File \u001b[0;32m~/.conda/envs/torch_env/lib/python3.10/site-packages/transformers/modeling_utils.py:3032\u001b[0m, in \u001b[0;36mPreTrainedModel.save_pretrained\u001b[0;34m(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)\u001b[0m\n\u001b[1;32m   3027\u001b[0m     gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m   3029\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m safe_serialization:\n\u001b[1;32m   3030\u001b[0m     \u001b[38;5;66;03m# At some point we will need to deal better with save_function (used for TPU and other distributed\u001b[39;00m\n\u001b[1;32m   3031\u001b[0m     \u001b[38;5;66;03m# joyfulness), but for now this enough.\u001b[39;00m\n\u001b[0;32m-> 3032\u001b[0m     \u001b[43msafe_save_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_directory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mformat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3033\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   3034\u001b[0m     save_function(shard, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(save_directory, shard_file))\n",
      "File \u001b[0;32m~/.conda/envs/torch_env/lib/python3.10/site-packages/safetensors/torch.py:286\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(tensors, filename, metadata)\u001b[0m\n\u001b[1;32m    255\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msave_file\u001b[39m(\n\u001b[1;32m    256\u001b[0m     tensors: Dict[\u001b[38;5;28mstr\u001b[39m, torch\u001b[38;5;241m.\u001b[39mTensor],\n\u001b[1;32m    257\u001b[0m     filename: Union[\u001b[38;5;28mstr\u001b[39m, os\u001b[38;5;241m.\u001b[39mPathLike],\n\u001b[1;32m    258\u001b[0m     metadata: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    259\u001b[0m ):\n\u001b[1;32m    260\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    261\u001b[0m \u001b[38;5;124;03m    Saves a dictionary of tensors into raw bytes in safetensors format.\u001b[39;00m\n\u001b[1;32m    262\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    284\u001b[0m \u001b[38;5;124;03m    ```\u001b[39;00m\n\u001b[1;32m    285\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 286\u001b[0m     \u001b[43mserialize_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_flatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetadata\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mSafetensorError\u001b[0m: Error while serializing: IoError(Os { code: 7, kind: ArgumentListTooLong, message: \"Argument list too long\" })"
     ]
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d3bbea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4df3e58e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
