{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Matplotlib created a temporary cache directory at /tmp/matplotlib-d0344w_a because the default path (/teamspace/studios/this_studio/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd \n",
    "import os\n",
    "from pathlib import Path\n",
    "from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer\n",
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "import torch.nn as nn\n",
    "import tqdm\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from tqdm import tqdm\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 42\n",
    "# from _models.huggingface.huggingface import * "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000, 3)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sample1</th>\n",
       "      <th>sample2</th>\n",
       "      <th>pos</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>and true become identical, so if I can convin...</td>\n",
       "      <td>supporting paragraphs, conclusion. The conclu...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ood and nbad are the number of nonspam and spa...</td>\n",
       "      <td>spam folder. And strangely enough, the better...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             sample1  \\\n",
       "0   and true become identical, so if I can convin...   \n",
       "1  ood and nbad are the number of nonspam and spa...   \n",
       "\n",
       "                                             sample2  pos  \n",
       "0   supporting paragraphs, conclusion. The conclu...    1  \n",
       "1   spam folder. And strangely enough, the better...    1  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "DATA_DIR = \"embedding-text-structure/project/data\"\n",
    "\n",
    "# df = pd.read_csv(DATA_DIR + \"/questions.csv\")\n",
    "df = pd.read_pickle(DATA_DIR + \"/truncate_data_combined.pkl\")\n",
    "print(df.shape)\n",
    "df.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class TextPairDataset(Dataset):\n",
    "#     def __init__(self, data, tokenizer, max_length=128):\n",
    "#         self.data = data\n",
    "#         self.tokenizer = tokenizer\n",
    "#         self.max_length = max_length\n",
    "\n",
    "#     def __len__(self):\n",
    "#         return len(self.data)\n",
    "\n",
    "#     def __getitem__(self, idx):\n",
    "#         row = self.data.iloc[idx]\n",
    "#         sample1 = row['sample1']\n",
    "#         sample2 = row['sample2']\n",
    "#         label = row['pos']\n",
    "        \n",
    "#         # print(self.tokenizer)\n",
    "#         encoding1 = self.tokenizer(sample1, max_length=self.max_length, truncation=True, padding=True, return_tensors='pt',)\n",
    "#         encoding2 = self.tokenizer(sample2, max_length=self.max_length, truncation=True, padding=True, return_tensors='pt',)\n",
    "\n",
    "#         return {\n",
    "#             # 'input_ids1': encoding1['input_ids'].squeeze(),\n",
    "#             # 'attention_mask1': encoding1['attention_mask'].squeeze(),\n",
    "#             # 'input_ids2': encoding2['input_ids'].squeeze(),\n",
    "#             # 'attention_mask2': encoding2['attention_mask'].squeeze(),\n",
    "#             'encoded_input1': encoding1,\n",
    "#             'encoded_input2': encoding2,\n",
    "#             'label': torch.tensor(1.0 if label else -1.0, dtype=torch.float)  # CosineEmbeddingLoss expects 1 or -1\n",
    "#         }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class EmbeddingModel(nn.Module):\n",
    "#     def __init__(self, model_name):\n",
    "#         super(EmbeddingModel, self).__init__()\n",
    "#         # self.encoder = AutoModel.from_pretrained(model_name)\n",
    "#         self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "#         self.encoder = AutoModel.from_pretrained(model_name, trust_remote_code=True)\n",
    "        \n",
    "#         # self.fc = nn.Linear(self.encoder.config.hidden_size, 128)  # Embedding size\n",
    "\n",
    "#     def forward(self, prompt):\n",
    "#         # outputs = self.encoder(input_ids, attention_mask=attention_mask)\n",
    "#         # pooled_output = outputs.last_hidden_state[:, 0]  # CLS token\n",
    "#         # embedding = self.fc(pooled_output)\n",
    "#         # return embedding\n",
    "#         self.encoder = self.encoder.to(device)\n",
    "\n",
    "#         encoded_input = self.tokenizer(\n",
    "#             prompt, padding=True, truncation=True, return_tensors=\"pt\"\n",
    "#         ).to(device)\n",
    "#         # with torch.no_grad():\n",
    "#         model_output = self.encoder(**encoded_input)\n",
    "#         sentence_embeddings = model_output[0][:, 0]\n",
    "#         sentence_embeddings = torch.nn.functional.normalize(\n",
    "#             sentence_embeddings, p=2, dim=1\n",
    "#         )\n",
    "#         sentence_embeddings = sentence_embeddings[0]\n",
    "#         return sentence_embeddings.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, tokenizer, data, criterion, optimizer, device, batch_size=4):\n",
    "    data = data.sample(frac=1, random_state=seed)\n",
    "    model.train()\n",
    "    # print(model)\n",
    "    total_loss = 0.0\n",
    "    num_batches = max(int(len(data) / batch_size + 0.99), 1)\n",
    "    embeddings = []\n",
    "    for i in tqdm(range(num_batches), disable=not True):\n",
    "        batch = data[i * batch_size : (i + 1) * batch_size]\n",
    "        # input_ids1 = batch['input_ids1'].to(device)\n",
    "        # attention_mask1 = batch['attention_mask1'].to(device)\n",
    "        # input_ids2 = batch['input_ids2'].to(device)\n",
    "        # attention_mask2 = batch['attention_mask2'].to(device)\n",
    "        # labels = batch['label'].to(device)\n",
    "        # print(batch)\n",
    "        encoded_input1 = tokenizer(\n",
    "            list(batch['sample1']), padding=True, truncation=True, return_tensors=\"pt\"\n",
    "        ).to(device)\n",
    "        \n",
    "        # print(f\"index: {i}, {list(batch['sample2'])}\")\n",
    "        encoded_input2 = tokenizer(\n",
    "            list(batch['sample2']), padding=True, truncation=True, return_tensors=\"pt\"\n",
    "        ).to(device)\n",
    "        labels = torch.tensor(batch['pos'].values).to(device)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        model_output1 = model(**encoded_input1)\n",
    "        model_output2 = model(**encoded_input2)\n",
    "\n",
    "        embeddings1 = model_output1[0][:, 0]\n",
    "        embeddings1 = torch.nn.functional.normalize(\n",
    "            embeddings1, p=2, dim=1\n",
    "        )\n",
    "        \n",
    "        embeddings2 = model_output2[0][:, 0]\n",
    "        embeddings2 = torch.nn.functional.normalize(\n",
    "            embeddings2, p=2, dim=1\n",
    "        )\n",
    "        loss = criterion(embeddings1, embeddings2, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "\n",
    "        # Move tensors to CPU to free up GPU memory\n",
    "        del encoded_input1, encoded_input2, labels, model_output1, model_output2, embeddings1, embeddings2\n",
    "        torch.cuda.empty_cache()\n",
    "    \n",
    "    return total_loss / len(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  5%|▌         | 54/1000 [00:21<06:10,  2.55it/s]\n"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 392.00 MiB. GPU ",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[7], line 36\u001b[0m\n\u001b[1;32m     34\u001b[0m num_epochs \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m3\u001b[39m\n\u001b[1;32m     35\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n\u001b[0;32m---> 36\u001b[0m     loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcriterion\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     37\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_epochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mloss\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.4f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m     39\u001b[0m \u001b[38;5;66;03m# main()\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[6], line 28\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m(model, tokenizer, data, criterion, optimizer, device, batch_size)\u001b[0m\n\u001b[1;32m     24\u001b[0m labels \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mtensor(batch[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpos\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mvalues)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m     26\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 28\u001b[0m model_output1 \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mencoded_input1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     29\u001b[0m model_output2 \u001b[38;5;241m=\u001b[39m model(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mencoded_input2)\n\u001b[1;32m     31\u001b[0m embeddings1 \u001b[38;5;241m=\u001b[39m model_output1[\u001b[38;5;241m0\u001b[39m][:, \u001b[38;5;241m0\u001b[39m]\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py:830\u001b[0m, in \u001b[0;36mXLMRobertaModel.forward\u001b[0;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    821\u001b[0m head_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_head_mask(head_mask, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mnum_hidden_layers)\n\u001b[1;32m    823\u001b[0m embedding_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(\n\u001b[1;32m    824\u001b[0m     input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m    825\u001b[0m     position_ids\u001b[38;5;241m=\u001b[39mposition_ids,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    828\u001b[0m     past_key_values_length\u001b[38;5;241m=\u001b[39mpast_key_values_length,\n\u001b[1;32m    829\u001b[0m )\n\u001b[0;32m--> 830\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    831\u001b[0m \u001b[43m    \u001b[49m\u001b[43membedding_output\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    832\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    833\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    834\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    835\u001b[0m \u001b[43m    \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_extended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    836\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    837\u001b[0m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    838\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    839\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    840\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    841\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    842\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m encoder_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    843\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler(sequence_output) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py:518\u001b[0m, in \u001b[0;36mXLMRobertaEncoder.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    507\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m    508\u001b[0m         layer_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m    509\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    515\u001b[0m         output_attentions,\n\u001b[1;32m    516\u001b[0m     )\n\u001b[1;32m    517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 518\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mlayer_module\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    519\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    520\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    521\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    522\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    523\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    524\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    525\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    526\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    528\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    529\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py:407\u001b[0m, in \u001b[0;36mXLMRobertaLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    395\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m    396\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m    397\u001b[0m     hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    404\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[torch\u001b[38;5;241m.\u001b[39mTensor]:\n\u001b[1;32m    405\u001b[0m     \u001b[38;5;66;03m# decoder uni-directional self-attention cached key/values tuple is at positions 1,2\u001b[39;00m\n\u001b[1;32m    406\u001b[0m     self_attn_past_key_value \u001b[38;5;241m=\u001b[39m past_key_value[:\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m past_key_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 407\u001b[0m     self_attention_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    408\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    409\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    410\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    411\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    412\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_attn_past_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    413\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    414\u001b[0m     attention_output \u001b[38;5;241m=\u001b[39m self_attention_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    416\u001b[0m     \u001b[38;5;66;03m# if decoder, the last output is tuple of self-attn cache\u001b[39;00m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py:334\u001b[0m, in \u001b[0;36mXLMRobertaAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    324\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m    325\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m    326\u001b[0m     hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    332\u001b[0m     output_attentions: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m    333\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[torch\u001b[38;5;241m.\u001b[39mTensor]:\n\u001b[0;32m--> 334\u001b[0m     self_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    335\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    336\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    337\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    338\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    339\u001b[0m \u001b[43m        \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    340\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    341\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    342\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    343\u001b[0m     attention_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput(self_outputs[\u001b[38;5;241m0\u001b[39m], hidden_states)\n\u001b[1;32m    344\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m (attention_output,) \u001b[38;5;241m+\u001b[39m self_outputs[\u001b[38;5;241m1\u001b[39m:]  \u001b[38;5;66;03m# add attentions if we output them\u001b[39;00m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1530\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1539\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1540\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1544\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/models/xlm_roberta/modeling_xlm_roberta.py:230\u001b[0m, in \u001b[0;36mXLMRobertaSelfAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m    227\u001b[0m     past_key_value \u001b[38;5;241m=\u001b[39m (key_layer, value_layer)\n\u001b[1;32m    229\u001b[0m \u001b[38;5;66;03m# Take the dot product between \"query\" and \"key\" to get the raw attention scores.\u001b[39;00m\n\u001b[0;32m--> 230\u001b[0m attention_scores \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmatmul\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery_layer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey_layer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    232\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mposition_embedding_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrelative_key\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mposition_embedding_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrelative_key_query\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m    233\u001b[0m     query_length, key_length \u001b[38;5;241m=\u001b[39m query_layer\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m], key_layer\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m2\u001b[39m]\n",
      "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 392.00 MiB. GPU "
     ]
    }
   ],
   "source": [
    "# def main():\n",
    "    # Load dataset\n",
    "data = pd.read_pickle(DATA_DIR + \"/truncate_data_combined.pkl\")\n",
    "\n",
    "# Load tokenizer\n",
    "# model_name = \"BAAI/bge-small-en-v1.5\"\n",
    "model_name = \"BAAI/bge-m3\"\n",
    "\n",
    "# Prepare dataset and dataloader\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# dataset = TextPairDataset(data, tokenizer)\n",
    "# train_loader = DataLoader(dataset, batch_size=4, shuffle=True)\n",
    "\n",
    "# Prepare model\n",
    "if torch.backends.mps.is_available():\n",
    "    device = torch.device('mps')\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else: \n",
    "    device = torch.device(\"cpu\")\n",
    "# device = torch.device(\"cpu\")\n",
    "# device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')\n",
    "model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)\n",
    "# tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# model = EmbeddingModel(model_name).to(device)\n",
    "assert model\n",
    "# raise\n",
    "\n",
    "# Define loss and optimizer\n",
    "criterion = nn.CosineEmbeddingLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=2e-5)\n",
    "\n",
    "# Train the model\n",
    "num_epochs = 3\n",
    "for epoch in range(num_epochs):\n",
    "    loss = train(model, tokenizer, data, criterion, optimizer, device, batch_size=1)\n",
    "    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}')\n",
    "\n",
    "# main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('./models/tokenizer_config.json',\n",
       " './models/special_tokens_map.json',\n",
       " './models/vocab.txt',\n",
       " './models/added_tokens.json',\n",
       " './models/tokenizer.json')"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "name = 'bge-m3-finetune-combined'\n",
    "save_directory = f\"./models/{name}\"\n",
    "model.save_pretrained(save_directory)\n",
    "tokenizer.save_pretrained(save_directory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def get_huggingface_embeddings_batched(\n",
    "#     prompts,\n",
    "#     batch_size=4,\n",
    "#     model_name=\"BAAI/bge-small-en-v1.5\",\n",
    "#     pbar=True,\n",
    "#     use_gpu=True,\n",
    "# ):\n",
    "#     device = torch.device(\"cuda\" if torch.cuda.is_available() and use_gpu else \"cpu\")\n",
    "#     model, tokenizer = load_model(model_name, device)\n",
    "\n",
    "#     num_batches = max(int(len(prompts) / batch_size + 0.99), 1)\n",
    "#     embeddings = []\n",
    "\n",
    "#     for i in tqdm(range(num_batches), disable=not pbar):\n",
    "#         batch = prompts[i * batch_size : (i + 1) * batch_size]\n",
    "#         encoded_input = tokenizer(\n",
    "#             batch, padding=True, truncation=True, return_tensors=\"pt\"\n",
    "#         ).to(device)\n",
    "#         with torch.no_grad():\n",
    "#             model_output = model(**encoded_input)\n",
    "#             sentence_embeddings = model_output[0][:, 0]\n",
    "#             sentence_embeddings = torch.nn.functional.normalize(\n",
    "#                 sentence_embeddings, p=2, dim=1\n",
    "#             )\n",
    "#             embeddings.append(\n",
    "#                 sentence_embeddings.cpu()\n",
    "#             )  # Move embeddings to CPU after computation\n",
    "\n",
    "#         # Optional: Clear the CUDA cache after processing each batch\n",
    "#         torch.cuda.empty_cache()\n",
    "\n",
    "#     return torch.vstack(embeddings).detach().numpy()\n",
    "\n",
    "\n",
    "# def get_huggingface_embedding(prompt: str, model=\"BAAI/bge-small-en-v1.5\"):\n",
    "#     model, tokenizer = load_model(model)\n",
    "#     model = model.to(device)\n",
    "\n",
    "#     encoded_input = tokenizer(\n",
    "#         prompt, padding=True, truncation=True, return_tensors=\"pt\"\n",
    "#     ).to(device)\n",
    "#     with torch.no_grad():\n",
    "#         model_output = model(**encoded_input)\n",
    "#         sentence_embeddings = model_output[0][:, 0]\n",
    "#         sentence_embeddings = torch.nn.functional.normalize(\n",
    "#             sentence_embeddings, p=2, dim=1\n",
    "#         )\n",
    "#         sentence_embeddings = sentence_embeddings[0]\n",
    "#     return sentence_embeddings.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "research",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
