{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90b47c6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Pin the GPU before torch is imported. setdefault keeps an externally\n",
    "# exported CUDA_VISIBLE_DEVICES instead of silently overriding it.\n",
    "os.environ.setdefault(\"CUDA_VISIBLE_DEVICES\", \"3\")\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer, set_seed\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "# Seed python `random`, numpy and torch (via transformers.set_seed) so weight\n",
    "# initialisation and any other randomness below is reproducible on Restart & Run All.\n",
    "SEED = 42\n",
    "set_seed(SEED)\n",
    "\n",
    "# GLUE STS-B sentence-pair similarity data; the 'train' and 'validation' splits are used below.\n",
    "raw_datasets = load_dataset(\"glue\", 'stsb')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "de228bb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "\n",
    "# Backbone checkpoint; reused below for AutoConfig and MLMSequenceClassification.\n",
    "model_name = \"answerdotai/ModernBERT-base\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# Pad on the left: in a padded batch the pad tokens come before the prompt.\n",
    "tokenizer.padding_side = 'left'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed721fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 5749/5749 [00:00<00:00, 20683.09 examples/s]\n",
      "Map: 100%|██████████| 1500/1500 [00:00<00:00, 20370.72 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['sentence1', 'sentence2', 'label', 'idx', 'labels', 'input'],\n",
      "    num_rows: 5749\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['sentence1', 'sentence2', 'label', 'idx', 'labels', 'input'],\n",
      "    num_rows: 1500\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "# Mask token string spliced into every prompt (e.g. '[MASK]' for ModernBERT).\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Build the masked-LM prompt for one STS-B sentence pair.\n",
    "    Args:\n",
    "        data_point (dict): A dictionary containing 'sentence1' and 'sentence2'.\n",
    "    Returns:\n",
    "        str: The formatted prompt, ending with the tokenizer's mask token.\n",
    "    \"\"\"\n",
    "    # NOTE(review): the template has a double colon after 'Sentence-1' and appends '.'\n",
    "    # after sentence1, which often already ends with '.' (decoded example shows 'stick..').\n",
    "    # Left unchanged so scores stay comparable with runs that used this exact prompt.\n",
    "    return f\"\"\"# Sentence-1:: {data_point['sentence1']}. # Sentence-2: {data_point['sentence2']} # Output: The similarity is{mask_token}\"\"\"  # noqa: E501\n",
    "\n",
    "\n",
    "def add_label_column(example):\n",
    "    \"\"\"Attach the float regression target ('labels') and the prompt text ('input') to one example.\"\"\"\n",
    "    example.update(labels=float(example['label']), input=generate_prompt(example))\n",
    "    return example\n",
    "\n",
    "# Apply the per-example transform to the two splits used in this notebook (train first).\n",
    "train_data, val_data = (raw_datasets[split].map(add_label_column) for split in ('train', 'validation'))\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e33204c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a9fde6d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 5749/5749 [00:00<00:00, 35110.51 examples/s]\n",
      "Map: 100%|██████████| 1500/1500 [00:00<00:00, 32565.65 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "# Upper bound on tokenized prompt length; longer prompts are truncated.\n",
    "MAX_LENGTH = 512\n",
    "\n",
    "# Raw text/label columns dropped after tokenization, leaving 'labels' plus the tokenizer outputs.\n",
    "col_to_delete = ['sentence1', 'sentence2', 'label', 'idx', 'input']\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "    \"\"\"Tokenize a batch of prompts from the 'input' column (no padding here; the collator pads).\"\"\"\n",
    "    return tokenizer(examples['input'], truncation=True, max_length=MAX_LENGTH)\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "\n",
    "# Every prompt ends with the mask token, so right-side truncation (the tokenizer default)\n",
    "# would silently drop it; the model is built around tokenizer.mask_token_id, so fail loudly.\n",
    "for split_name, split in ((\"train\", tokenized_train_data), (\"validation\", tokenized_val_data)):\n",
    "    n_missing = sum(tokenizer.mask_token_id not in ids for ids in split['input_ids'])\n",
    "    assert n_missing == 0, f\"{n_missing} {split_name} prompts lost the mask token to truncation\"\n",
    "\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1931ed6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS]# Sentence-1:: The man hit the other man with a stick.. # Sentence-2: The man spanked the other man with a stick. # Output: The similarity is[MASK][SEP]'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode one training example to eyeball the prompt; indexing the row first avoids\n",
    "# materialising the entire 'input_ids' column just to read element 10.\n",
    "tokenizer.decode(tokenized_train_data[10]['input_ids'])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abd6b985",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "25900f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['sentence1', 'sentence2', 'label', 'idx', 'labels', 'input'],\n",
       "    num_rows: 1500\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the validation split's columns and row count (prompt added, not yet tokenized).\n",
    "val_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1fdaa612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "145"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token count of every training prompt, and the longest one.\n",
    "all_lengths = list(map(len, tokenized_train_data['input_ids']))\n",
    "mx = max(all_lengths)\n",
    "mx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6618d0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "# Tokenization used truncation=True, max_length=512, so no sequence can exceed 512 and a\n",
    "# `> 512` test is always 0. A length of exactly 512 means the example hit the limit and\n",
    "# may have been truncated.\n",
    "count = sum(len(ids) >= 512 for ids in tokenized_train_data['input_ids'])\n",
    "print(count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7a46cd19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import RobertaForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "from modeling import MLMSequenceClassification\n",
    "\n",
    "config = AutoConfig.from_pretrained(model_name)\n",
    "\n",
    "# Single-output head (num_labels=1) for the float similarity target; the wrapper is\n",
    "# also given the id of the mask token that every prompt ends with.\n",
    "model = MLMSequenceClassification.from_pretrained(\n",
    "    model_name,\n",
    "    config=config,\n",
    "    num_labels=1,\n",
    "    mask_token_id=tokenizer.mask_token_id,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "159b238b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLMSequenceClassification(\n",
       "  (transformer): ModernBertForMaskedLM(\n",
       "    (model): ModernBertModel(\n",
       "      (embeddings): ModernBertEmbeddings(\n",
       "        (tok_embeddings): Embedding(50368, 768, padding_idx=50283)\n",
       "        (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (drop): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "      (layers): ModuleList(\n",
       "        (0): ModernBertEncoderLayer(\n",
       "          (attn_norm): Identity()\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (1-2): 2 x ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=10000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (3): ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (4-5): 2 x ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=10000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (6): ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (7-8): 2 x ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=10000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (9): ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (10-11): 2 x ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=10000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (12): ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (13-14): 2 x ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=10000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (15): ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (16-17): 2 x ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=10000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (18): ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (19-20): 2 x ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=10000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "        (21): ModernBertEncoderLayer(\n",
       "          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): ModernBertAttention(\n",
       "            (Wqkv): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (rotary_emb): ModernBertUnpaddedRotaryEmbedding(dim=64, base=160000.0, scale_base=None)\n",
       "            (Wo): Linear(in_features=768, out_features=768, bias=False)\n",
       "            (out_drop): Identity()\n",
       "          )\n",
       "          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): ModernBertMLP(\n",
       "            (Wi): Linear(in_features=768, out_features=2304, bias=False)\n",
       "            (act): GELUActivation()\n",
       "            (drop): Dropout(p=0.0, inplace=False)\n",
       "            (Wo): Linear(in_features=1152, out_features=768, bias=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (final_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (head): ModernBertPredictionHead(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=False)\n",
       "      (act): GELUActivation()\n",
       "      (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (decoder): Linear(in_features=768, out_features=50368, bias=True)\n",
       "  )\n",
       "  (classifier): ClassificationHead(\n",
       "    (dense): Linear(in_features=50368, out_features=768, bias=True)\n",
       "    (dropout): Dropout(p=0.02, inplace=False)\n",
       "    (out_proj): Linear(in_features=768, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the module tree: ModernBERT MLM backbone plus the single-output classification head.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "864ccb2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import RoCoFT\n",
    "\n",
    "# Apply RoCoFT parameter-efficient fine-tuning to `model` (method='column', rank=3).\n",
    "# The return value is discarded, so this presumably modifies `model` in place — confirm.\n",
    "RoCoFT.PEFT(model, method='column', rank=3) \n",
    "# NOTE(review): 'key'/'value'/'query' below are RoBERTa attention module names; ModernBERT's\n",
    "# attention/MLP linears are named Wqkv / Wo / Wi (see the printed model above).\n",
    "#targets=['key', 'value', 'dense', 'query'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bef34afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred, tolerance=0.1):\n",
    "    \"\"\"Regression metrics for the Trainer's evaluation loop.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: (predictions, labels) pair or EvalPrediction from the Trainer.\n",
    "            `predictions` may be a tuple when the model returns extra outputs;\n",
    "            its first element is taken as the predicted score.\n",
    "        tolerance: absolute error below which a prediction counts as 'accurate'.\n",
    "\n",
    "    Returns:\n",
    "        dict of plain Python floats: MAE, MSE, RMSE, Accuracy, R2, Pearson, Spearman's Rank.\n",
    "    \"\"\"\n",
    "    predictions, labels = eval_pred\n",
    "    # The Trainer may hand back a tuple when the model output has more than the logits.\n",
    "    if isinstance(predictions, (tuple, list)):\n",
    "        predictions = predictions[0]\n",
    "    # Flatten (N, 1) -> (N,). Unlike squeeze(), reshape(-1) keeps a 1-element batch 1-D,\n",
    "    # which pearsonr/spearmanr and the elementwise ops below require.\n",
    "    predictions = np.asarray(predictions, dtype=np.float64).reshape(-1)\n",
    "    labels = np.asarray(labels, dtype=np.float64).reshape(-1)\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "\n",
    "    # Regression 'accuracy': share of predictions within `tolerance` of the label.\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    # Cast to built-in floats so the logged metrics serialise cleanly.\n",
    "    return {\n",
    "        \"MAE\": float(mae),\n",
    "        \"MSE\": float(mse),\n",
    "        \"RMSE\": float(rmse),\n",
    "        \"Accuracy\": float(acc),\n",
    "        \"R2\": float(r2),\n",
    "        \"Pearson\": float(pearson_corr),\n",
    "        \"Spearman's Rank\": float(spearman_corr)\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7dbcf96a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/transformers/training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-04-27 23:23:18,737] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: cannot find -laio: No such file or directory\n",
      "collect2: error: ld returned 1 exit status\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: warning: libpthread.so.0, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: warning: libstdc++.so.6, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: warning: libm.so.6, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for bool@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_logic_error(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ostringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::logic_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::~locale()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(std::string const&, unsigned long, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_end_catch@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ofstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::logic_error::~logic_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for __cxxabiv1::__si_class_type_info@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<char, std::char_traits<char> >::_M_cache_locale(std::locale const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_stringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator new[](unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_leak_hard()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ifstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >::basic_streambuf(std::basic_streambuf<wchar_t, std::char_traits<wchar_t> > const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::append(char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(std::string const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned short@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::resize(unsigned long, char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for char const*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ctype<char>::_M_widen_init() const@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_invalid_argument(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::operator=(std::locale const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<wchar_t, std::char_traits<wchar_t> >::_M_cache_locale(std::locale const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_decrement(std::_Rb_tree_node_base const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_free_exception@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::condition_variable::notify_one()@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::Init::~Init()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::~basic_string()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_pure_virtual@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::flush()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for __cxxabiv1::__class_type_info@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_rethrow@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringbuf<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_fstream<char, std::char_traits<char> >::~basic_fstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::compare(char const*) const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ostringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::locale()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::system_clock::now()@GLIBCXX_3.4.19'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ifstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Hash_bytes(void const*, unsigned long, unsigned long)@CXXABI_1.3.5'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<long long>(long long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for char*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_Prime_rehash_policy::_M_need_rehash(unsigned long, unsigned long, unsigned long) const@GLIBCXX_3.4.18'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::out_of_range@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<unsigned long>(unsigned long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_increment(std::_Rb_tree_node_base const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::~ios_base()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::range_error::~range_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__basic_file<char>::~__basic_file()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_guard_acquire@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<bool>(bool)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::overflow_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_fstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::range_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ios<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_filebuf<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator delete[](void*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(unsigned long, char, std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_List_node_base::_M_transfer(std::__detail::_List_node_base*, std::__detail::_List_node_base*)@GLIBCXX_3.4.15'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::replace(unsigned long, unsigned long, char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for std::exception@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::_Rep::_M_destroy(std::allocator<wchar_t> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::istream& std::istream::_M_extract<double>(double&)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::close()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_fstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::basic_ifstream(char const*, std::_Ios_Openmode)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::append(std::string const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator new(unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_istringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned int@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::append(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::domain_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::find(char, unsigned long) const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::put(char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for int@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_bad_alloc()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_thread_atexit@CXXABI_1.3.7'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_increment(std::_Rb_tree_node_base*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::~basic_ifstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::Init::Init()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::condition_variable::condition_variable()@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::basic_filebuf()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_istringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::domain_error::~domain_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::cerr@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::find(char const*, unsigned long, unsigned long) const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_istringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringbuf<char, std::char_traits<char>, std::allocator<char> >::str() const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::invalid_argument@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for void*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::assign(std::string const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ostringstream<char, std::char_traits<char>, std::allocator<char> >::~basic_ostringstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_rebalance_for_erase(std::_Rb_tree_node_base*, std::_Rb_tree_node_base&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned long@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_List_node_base::_M_hook(std::__detail::_List_node_base*)@GLIBCXX_3.4.15'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_List_node_base::_M_unhook()@GLIBCXX_3.4.15'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ostringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringbuf<char, std::char_traits<char>, std::allocator<char> >::_M_sync(char*, unsigned long, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_iostream<char, std::char_traits<char> >::~basic_iostream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::locale(std::locale const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_istringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `log2f@GLIBC_2.2.5'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::operator<<(std::basic_streambuf<char, std::char_traits<char> >*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::exception::~exception()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_Rep::_S_create(unsigned long, unsigned long, std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__basic_file<char>::is_open() const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_istringstream<char, std::char_traits<char>, std::allocator<char> >::~basic_istringstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::swap(std::string&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned long*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ostringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<char, std::char_traits<char> >::basic_streambuf(std::basic_streambuf<char, std::char_traits<char> > const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<char, std::char_traits<char> >::init(std::basic_streambuf<char, std::char_traits<char> >*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_bad_cast()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<char, std::char_traits<char> >::clear(std::_Ios_Iostate)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >::operator=(std::basic_streambuf<wchar_t, std::char_traits<wchar_t> > const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator delete(void*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::operator<<(int)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_Rep::_S_empty_rep_storage@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_Rep::_M_destroy(std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_iostream<wchar_t, std::char_traits<wchar_t> >::~basic_iostream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::runtime_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ofstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_insert_and_rebalance(bool, std::_Rb_tree_node_base*, std::_Rb_tree_node_base*, std::_Rb_tree_node_base&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringstream<char, std::char_traits<char>, std::allocator<char> >::~basic_stringstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_stringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<long>(long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::istream::get()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned long long@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ostream<char, std::char_traits<char> >& std::operator<< <std::char_traits<char> >(std::basic_ostream<char, std::char_traits<char> >&, char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::out_of_range::~out_of_range()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::length_error::~length_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ostream<char, std::char_traits<char> >& std::__ostream_insert<char, std::char_traits<char> >(std::basic_ostream<char, std::char_traits<char> >&, char const*, long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::invalid_argument::~invalid_argument()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::swap(std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::cout@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<unsigned long long>(unsigned long long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for int*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<void const*>(void const*)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::underflow_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_streambuf<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for std::out_of_range@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_allocate_exception@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ios<wchar_t, std::char_traits<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for void const*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<wchar_t, std::char_traits<wchar_t> >::init(std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::reserve(unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_begin_catch@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for long@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::_Rep::_S_empty_rep_storage@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_leak()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::open(char const*, std::_Ios_Openmode)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringbuf<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::_M_sync(wchar_t*, unsigned long, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::istream::getline(char*, long, char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_istream<char, std::char_traits<char> >& std::getline<char, std::char_traits<char>, std::allocator<char> >(std::basic_istream<char, std::char_traits<char> >&, std::basic_string<char, std::char_traits<char>, std::allocator<char> >&, char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::condition_variable::~condition_variable()@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringbuf<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::insert(unsigned long, char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::assign(char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned char@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::ios_base()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_out_of_range(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::overflow_error::~overflow_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_length_error(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_system_error(int)@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ofstream<char, std::char_traits<char> >::close()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<double>(double)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<char, std::char_traits<char> >::operator=(std::basic_streambuf<char, std::char_traits<char> > const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for long long@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(char const*, unsigned long, std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::close()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_guard_release@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_throw@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::underflow_error::~underflow_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_decrement(std::_Rb_tree_node_base*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::length_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::~basic_filebuf()@GLIBCXX_3.4'\n",
      "collect2: error: ld returned 1 exit status\n"
     ]
    }
   ],
   "source": [
    "# Trainer setup for STS-B regression fine-tuning.\n",
    "# Depends on `model`, `tokenized_train_data`, `tokenized_val_data`, `data_collator`\n",
    "# and `compute_metrics`, which must be defined by earlier cells.\n",
    "import time\n",
    "\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "EVAL_EVERY = 100  # evaluate, checkpoint and log on the same step cadence\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=6e-4,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=20,\n",
    "    weight_decay=0.20,\n",
    "    # `eval_strategy` is the current name of the deprecated `evaluation_strategy` argument.\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=EVAL_EVERY,\n",
    "    # load_best_model_at_end can only restore a checkpoint that was actually written,\n",
    "    # so a checkpoint is saved at every evaluation (save_steps must be a multiple of eval_steps).\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=EVAL_EVERY,\n",
    "    save_total_limit=2,  # the best checkpoint is retained in addition to the most recent one\n",
    "    logging_steps=EVAL_EVERY,\n",
    "    load_best_model_at_end=True,\n",
    "    # The best model is chosen by eval loss (the default). To select on correlation instead, set\n",
    "    # metric_for_best_model to the matching compute_metrics key and greater_is_better=True.\n",
    "    lr_scheduler_type=\"cosine\",  # alternatives: 'linear', 'cosine_with_restarts', 'polynomial', ...\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "557cdbf4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got mask position:  tensor(-2, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:194: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
      "  warnings.warn(\n",
      "W0427 23:23:24.448000 1286876 torch/_inductor/utils.py:1137] [1/0] Not enough SMs to use max_autotune_gemm mode\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='7200' max='7200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [7200/7200 12:55, Epoch 20/20]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>7.154400</td>\n",
       "      <td>2.190254</td>\n",
       "      <td>1.243347</td>\n",
       "      <td>2.190254</td>\n",
       "      <td>1.479951</td>\n",
       "      <td>0.050667</td>\n",
       "      <td>0.026535</td>\n",
       "      <td>0.301591</td>\n",
       "      <td>0.290068</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>1.469000</td>\n",
       "      <td>1.053077</td>\n",
       "      <td>0.802079</td>\n",
       "      <td>1.053077</td>\n",
       "      <td>1.026196</td>\n",
       "      <td>0.094000</td>\n",
       "      <td>0.531957</td>\n",
       "      <td>0.787756</td>\n",
       "      <td>0.795585</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.823000</td>\n",
       "      <td>0.723052</td>\n",
       "      <td>0.669836</td>\n",
       "      <td>0.723052</td>\n",
       "      <td>0.850325</td>\n",
       "      <td>0.076000</td>\n",
       "      <td>0.678637</td>\n",
       "      <td>0.825526</td>\n",
       "      <td>0.826381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.602900</td>\n",
       "      <td>0.754829</td>\n",
       "      <td>0.674474</td>\n",
       "      <td>0.754829</td>\n",
       "      <td>0.868809</td>\n",
       "      <td>0.103333</td>\n",
       "      <td>0.664514</td>\n",
       "      <td>0.842459</td>\n",
       "      <td>0.844564</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.527500</td>\n",
       "      <td>0.647487</td>\n",
       "      <td>0.632452</td>\n",
       "      <td>0.647487</td>\n",
       "      <td>0.804666</td>\n",
       "      <td>0.104000</td>\n",
       "      <td>0.712222</td>\n",
       "      <td>0.855951</td>\n",
       "      <td>0.856283</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.583400</td>\n",
       "      <td>0.676577</td>\n",
       "      <td>0.642674</td>\n",
       "      <td>0.676577</td>\n",
       "      <td>0.822543</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>0.699293</td>\n",
       "      <td>0.854286</td>\n",
       "      <td>0.856457</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.542900</td>\n",
       "      <td>0.585194</td>\n",
       "      <td>0.616617</td>\n",
       "      <td>0.585194</td>\n",
       "      <td>0.764980</td>\n",
       "      <td>0.101333</td>\n",
       "      <td>0.739909</td>\n",
       "      <td>0.872589</td>\n",
       "      <td>0.871478</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.475900</td>\n",
       "      <td>0.551545</td>\n",
       "      <td>0.569348</td>\n",
       "      <td>0.551545</td>\n",
       "      <td>0.742661</td>\n",
       "      <td>0.136667</td>\n",
       "      <td>0.754864</td>\n",
       "      <td>0.871330</td>\n",
       "      <td>0.870365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.432700</td>\n",
       "      <td>0.531739</td>\n",
       "      <td>0.569416</td>\n",
       "      <td>0.531739</td>\n",
       "      <td>0.729204</td>\n",
       "      <td>0.117333</td>\n",
       "      <td>0.763667</td>\n",
       "      <td>0.873973</td>\n",
       "      <td>0.874463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.460600</td>\n",
       "      <td>0.547541</td>\n",
       "      <td>0.568494</td>\n",
       "      <td>0.547541</td>\n",
       "      <td>0.739960</td>\n",
       "      <td>0.124000</td>\n",
       "      <td>0.756644</td>\n",
       "      <td>0.873818</td>\n",
       "      <td>0.873360</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.423100</td>\n",
       "      <td>0.504307</td>\n",
       "      <td>0.550293</td>\n",
       "      <td>0.504307</td>\n",
       "      <td>0.710146</td>\n",
       "      <td>0.122667</td>\n",
       "      <td>0.775859</td>\n",
       "      <td>0.881509</td>\n",
       "      <td>0.879942</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.375400</td>\n",
       "      <td>0.651221</td>\n",
       "      <td>0.619447</td>\n",
       "      <td>0.651221</td>\n",
       "      <td>0.806982</td>\n",
       "      <td>0.120000</td>\n",
       "      <td>0.710563</td>\n",
       "      <td>0.877952</td>\n",
       "      <td>0.880401</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.406500</td>\n",
       "      <td>0.557294</td>\n",
       "      <td>0.566503</td>\n",
       "      <td>0.557294</td>\n",
       "      <td>0.746522</td>\n",
       "      <td>0.139333</td>\n",
       "      <td>0.752309</td>\n",
       "      <td>0.878486</td>\n",
       "      <td>0.879551</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.394500</td>\n",
       "      <td>0.515691</td>\n",
       "      <td>0.561052</td>\n",
       "      <td>0.515691</td>\n",
       "      <td>0.718116</td>\n",
       "      <td>0.107333</td>\n",
       "      <td>0.770799</td>\n",
       "      <td>0.879279</td>\n",
       "      <td>0.879801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.320600</td>\n",
       "      <td>0.521563</td>\n",
       "      <td>0.554561</td>\n",
       "      <td>0.521563</td>\n",
       "      <td>0.722193</td>\n",
       "      <td>0.124667</td>\n",
       "      <td>0.768190</td>\n",
       "      <td>0.879346</td>\n",
       "      <td>0.878359</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.326400</td>\n",
       "      <td>0.544826</td>\n",
       "      <td>0.571863</td>\n",
       "      <td>0.544826</td>\n",
       "      <td>0.738123</td>\n",
       "      <td>0.102667</td>\n",
       "      <td>0.757851</td>\n",
       "      <td>0.870881</td>\n",
       "      <td>0.870574</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.331500</td>\n",
       "      <td>0.484127</td>\n",
       "      <td>0.544579</td>\n",
       "      <td>0.484127</td>\n",
       "      <td>0.695792</td>\n",
       "      <td>0.112667</td>\n",
       "      <td>0.784828</td>\n",
       "      <td>0.886991</td>\n",
       "      <td>0.885255</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.317200</td>\n",
       "      <td>0.502000</td>\n",
       "      <td>0.558664</td>\n",
       "      <td>0.502000</td>\n",
       "      <td>0.708520</td>\n",
       "      <td>0.107333</td>\n",
       "      <td>0.776884</td>\n",
       "      <td>0.883438</td>\n",
       "      <td>0.882153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.284100</td>\n",
       "      <td>0.505208</td>\n",
       "      <td>0.558569</td>\n",
       "      <td>0.505208</td>\n",
       "      <td>0.710780</td>\n",
       "      <td>0.103333</td>\n",
       "      <td>0.775459</td>\n",
       "      <td>0.880907</td>\n",
       "      <td>0.881753</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.278300</td>\n",
       "      <td>0.526476</td>\n",
       "      <td>0.567629</td>\n",
       "      <td>0.526476</td>\n",
       "      <td>0.725587</td>\n",
       "      <td>0.103333</td>\n",
       "      <td>0.766006</td>\n",
       "      <td>0.885574</td>\n",
       "      <td>0.884904</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.264300</td>\n",
       "      <td>0.482404</td>\n",
       "      <td>0.536266</td>\n",
       "      <td>0.482404</td>\n",
       "      <td>0.694553</td>\n",
       "      <td>0.111333</td>\n",
       "      <td>0.785594</td>\n",
       "      <td>0.886581</td>\n",
       "      <td>0.885687</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.258000</td>\n",
       "      <td>0.510820</td>\n",
       "      <td>0.550182</td>\n",
       "      <td>0.510820</td>\n",
       "      <td>0.714717</td>\n",
       "      <td>0.121333</td>\n",
       "      <td>0.772964</td>\n",
       "      <td>0.882083</td>\n",
       "      <td>0.881782</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.248300</td>\n",
       "      <td>0.584664</td>\n",
       "      <td>0.572042</td>\n",
       "      <td>0.584664</td>\n",
       "      <td>0.764633</td>\n",
       "      <td>0.148667</td>\n",
       "      <td>0.740144</td>\n",
       "      <td>0.879496</td>\n",
       "      <td>0.881568</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.237200</td>\n",
       "      <td>0.529163</td>\n",
       "      <td>0.560523</td>\n",
       "      <td>0.529163</td>\n",
       "      <td>0.727436</td>\n",
       "      <td>0.118667</td>\n",
       "      <td>0.764812</td>\n",
       "      <td>0.877754</td>\n",
       "      <td>0.880645</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.263300</td>\n",
       "      <td>0.527309</td>\n",
       "      <td>0.557505</td>\n",
       "      <td>0.527309</td>\n",
       "      <td>0.726160</td>\n",
       "      <td>0.120000</td>\n",
       "      <td>0.765636</td>\n",
       "      <td>0.877129</td>\n",
       "      <td>0.877122</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.224200</td>\n",
       "      <td>0.506305</td>\n",
       "      <td>0.546039</td>\n",
       "      <td>0.506305</td>\n",
       "      <td>0.711551</td>\n",
       "      <td>0.120667</td>\n",
       "      <td>0.774971</td>\n",
       "      <td>0.881240</td>\n",
       "      <td>0.880124</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.199500</td>\n",
       "      <td>0.552447</td>\n",
       "      <td>0.564897</td>\n",
       "      <td>0.552447</td>\n",
       "      <td>0.743268</td>\n",
       "      <td>0.126667</td>\n",
       "      <td>0.754463</td>\n",
       "      <td>0.878646</td>\n",
       "      <td>0.879236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.203900</td>\n",
       "      <td>0.515472</td>\n",
       "      <td>0.551336</td>\n",
       "      <td>0.515472</td>\n",
       "      <td>0.717964</td>\n",
       "      <td>0.111333</td>\n",
       "      <td>0.770897</td>\n",
       "      <td>0.880083</td>\n",
       "      <td>0.879631</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.210200</td>\n",
       "      <td>0.492666</td>\n",
       "      <td>0.540866</td>\n",
       "      <td>0.492666</td>\n",
       "      <td>0.701902</td>\n",
       "      <td>0.122000</td>\n",
       "      <td>0.781033</td>\n",
       "      <td>0.884856</td>\n",
       "      <td>0.882916</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.177800</td>\n",
       "      <td>0.542559</td>\n",
       "      <td>0.561245</td>\n",
       "      <td>0.542559</td>\n",
       "      <td>0.736586</td>\n",
       "      <td>0.126000</td>\n",
       "      <td>0.758858</td>\n",
       "      <td>0.881206</td>\n",
       "      <td>0.880576</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.183000</td>\n",
       "      <td>0.533257</td>\n",
       "      <td>0.558243</td>\n",
       "      <td>0.533257</td>\n",
       "      <td>0.730244</td>\n",
       "      <td>0.125333</td>\n",
       "      <td>0.762992</td>\n",
       "      <td>0.878279</td>\n",
       "      <td>0.876630</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.183700</td>\n",
       "      <td>0.520696</td>\n",
       "      <td>0.560726</td>\n",
       "      <td>0.520696</td>\n",
       "      <td>0.721593</td>\n",
       "      <td>0.109333</td>\n",
       "      <td>0.768575</td>\n",
       "      <td>0.876916</td>\n",
       "      <td>0.876042</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.163000</td>\n",
       "      <td>0.514440</td>\n",
       "      <td>0.552218</td>\n",
       "      <td>0.514440</td>\n",
       "      <td>0.717245</td>\n",
       "      <td>0.112667</td>\n",
       "      <td>0.771355</td>\n",
       "      <td>0.880844</td>\n",
       "      <td>0.880647</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.152400</td>\n",
       "      <td>0.518377</td>\n",
       "      <td>0.547779</td>\n",
       "      <td>0.518377</td>\n",
       "      <td>0.719984</td>\n",
       "      <td>0.139333</td>\n",
       "      <td>0.769606</td>\n",
       "      <td>0.879742</td>\n",
       "      <td>0.877466</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.145400</td>\n",
       "      <td>0.525892</td>\n",
       "      <td>0.559663</td>\n",
       "      <td>0.525892</td>\n",
       "      <td>0.725184</td>\n",
       "      <td>0.112667</td>\n",
       "      <td>0.766266</td>\n",
       "      <td>0.876199</td>\n",
       "      <td>0.874319</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.161700</td>\n",
       "      <td>0.532127</td>\n",
       "      <td>0.556427</td>\n",
       "      <td>0.532127</td>\n",
       "      <td>0.729470</td>\n",
       "      <td>0.128667</td>\n",
       "      <td>0.763495</td>\n",
       "      <td>0.876428</td>\n",
       "      <td>0.875616</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.141100</td>\n",
       "      <td>0.542709</td>\n",
       "      <td>0.555228</td>\n",
       "      <td>0.542709</td>\n",
       "      <td>0.736688</td>\n",
       "      <td>0.145333</td>\n",
       "      <td>0.758792</td>\n",
       "      <td>0.877213</td>\n",
       "      <td>0.875979</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.134600</td>\n",
       "      <td>0.536413</td>\n",
       "      <td>0.556224</td>\n",
       "      <td>0.536413</td>\n",
       "      <td>0.732402</td>\n",
       "      <td>0.138000</td>\n",
       "      <td>0.761590</td>\n",
       "      <td>0.876099</td>\n",
       "      <td>0.874720</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.127200</td>\n",
       "      <td>0.531028</td>\n",
       "      <td>0.557004</td>\n",
       "      <td>0.531028</td>\n",
       "      <td>0.728716</td>\n",
       "      <td>0.125333</td>\n",
       "      <td>0.763983</td>\n",
       "      <td>0.875777</td>\n",
       "      <td>0.875262</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.117200</td>\n",
       "      <td>0.524613</td>\n",
       "      <td>0.556489</td>\n",
       "      <td>0.524613</td>\n",
       "      <td>0.724301</td>\n",
       "      <td>0.112667</td>\n",
       "      <td>0.766834</td>\n",
       "      <td>0.877367</td>\n",
       "      <td>0.874785</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.114100</td>\n",
       "      <td>0.528294</td>\n",
       "      <td>0.554472</td>\n",
       "      <td>0.528294</td>\n",
       "      <td>0.726839</td>\n",
       "      <td>0.126667</td>\n",
       "      <td>0.765198</td>\n",
       "      <td>0.875646</td>\n",
       "      <td>0.873788</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.119000</td>\n",
       "      <td>0.548619</td>\n",
       "      <td>0.562534</td>\n",
       "      <td>0.548619</td>\n",
       "      <td>0.740688</td>\n",
       "      <td>0.141333</td>\n",
       "      <td>0.756165</td>\n",
       "      <td>0.874175</td>\n",
       "      <td>0.873807</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.116300</td>\n",
       "      <td>0.552451</td>\n",
       "      <td>0.562544</td>\n",
       "      <td>0.552451</td>\n",
       "      <td>0.743271</td>\n",
       "      <td>0.124667</td>\n",
       "      <td>0.754461</td>\n",
       "      <td>0.873151</td>\n",
       "      <td>0.872409</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.100200</td>\n",
       "      <td>0.538988</td>\n",
       "      <td>0.559493</td>\n",
       "      <td>0.538988</td>\n",
       "      <td>0.734158</td>\n",
       "      <td>0.117333</td>\n",
       "      <td>0.760445</td>\n",
       "      <td>0.873354</td>\n",
       "      <td>0.871807</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.103300</td>\n",
       "      <td>0.547710</td>\n",
       "      <td>0.563381</td>\n",
       "      <td>0.547710</td>\n",
       "      <td>0.740074</td>\n",
       "      <td>0.129333</td>\n",
       "      <td>0.756569</td>\n",
       "      <td>0.873690</td>\n",
       "      <td>0.872097</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.100500</td>\n",
       "      <td>0.561588</td>\n",
       "      <td>0.566864</td>\n",
       "      <td>0.561588</td>\n",
       "      <td>0.749392</td>\n",
       "      <td>0.140667</td>\n",
       "      <td>0.750400</td>\n",
       "      <td>0.872079</td>\n",
       "      <td>0.871390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.090600</td>\n",
       "      <td>0.557820</td>\n",
       "      <td>0.565044</td>\n",
       "      <td>0.557820</td>\n",
       "      <td>0.746874</td>\n",
       "      <td>0.148667</td>\n",
       "      <td>0.752075</td>\n",
       "      <td>0.871800</td>\n",
       "      <td>0.871014</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.090400</td>\n",
       "      <td>0.547166</td>\n",
       "      <td>0.565690</td>\n",
       "      <td>0.547166</td>\n",
       "      <td>0.739707</td>\n",
       "      <td>0.115333</td>\n",
       "      <td>0.756810</td>\n",
       "      <td>0.871631</td>\n",
       "      <td>0.871481</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.084300</td>\n",
       "      <td>0.552771</td>\n",
       "      <td>0.564009</td>\n",
       "      <td>0.552771</td>\n",
       "      <td>0.743485</td>\n",
       "      <td>0.131333</td>\n",
       "      <td>0.754319</td>\n",
       "      <td>0.872622</td>\n",
       "      <td>0.870825</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.087300</td>\n",
       "      <td>0.552669</td>\n",
       "      <td>0.563479</td>\n",
       "      <td>0.552669</td>\n",
       "      <td>0.743417</td>\n",
       "      <td>0.130667</td>\n",
       "      <td>0.754365</td>\n",
       "      <td>0.873735</td>\n",
       "      <td>0.872839</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.078900</td>\n",
       "      <td>0.554395</td>\n",
       "      <td>0.564463</td>\n",
       "      <td>0.554395</td>\n",
       "      <td>0.744577</td>\n",
       "      <td>0.141333</td>\n",
       "      <td>0.753597</td>\n",
       "      <td>0.871577</td>\n",
       "      <td>0.870133</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.075400</td>\n",
       "      <td>0.558463</td>\n",
       "      <td>0.567063</td>\n",
       "      <td>0.558463</td>\n",
       "      <td>0.747303</td>\n",
       "      <td>0.130000</td>\n",
       "      <td>0.751790</td>\n",
       "      <td>0.870591</td>\n",
       "      <td>0.869781</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.078900</td>\n",
       "      <td>0.566578</td>\n",
       "      <td>0.570191</td>\n",
       "      <td>0.566578</td>\n",
       "      <td>0.752714</td>\n",
       "      <td>0.132667</td>\n",
       "      <td>0.748182</td>\n",
       "      <td>0.871142</td>\n",
       "      <td>0.869452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.078500</td>\n",
       "      <td>0.549590</td>\n",
       "      <td>0.562868</td>\n",
       "      <td>0.549590</td>\n",
       "      <td>0.741343</td>\n",
       "      <td>0.132000</td>\n",
       "      <td>0.755733</td>\n",
       "      <td>0.871426</td>\n",
       "      <td>0.869621</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.065300</td>\n",
       "      <td>0.567219</td>\n",
       "      <td>0.569186</td>\n",
       "      <td>0.567219</td>\n",
       "      <td>0.753139</td>\n",
       "      <td>0.141333</td>\n",
       "      <td>0.747898</td>\n",
       "      <td>0.869584</td>\n",
       "      <td>0.868101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.071100</td>\n",
       "      <td>0.562108</td>\n",
       "      <td>0.568580</td>\n",
       "      <td>0.562108</td>\n",
       "      <td>0.749739</td>\n",
       "      <td>0.132000</td>\n",
       "      <td>0.750169</td>\n",
       "      <td>0.872439</td>\n",
       "      <td>0.870155</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.068100</td>\n",
       "      <td>0.557507</td>\n",
       "      <td>0.565841</td>\n",
       "      <td>0.557507</td>\n",
       "      <td>0.746664</td>\n",
       "      <td>0.131333</td>\n",
       "      <td>0.752214</td>\n",
       "      <td>0.871167</td>\n",
       "      <td>0.869313</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.071500</td>\n",
       "      <td>0.561335</td>\n",
       "      <td>0.569306</td>\n",
       "      <td>0.561335</td>\n",
       "      <td>0.749223</td>\n",
       "      <td>0.135333</td>\n",
       "      <td>0.750513</td>\n",
       "      <td>0.868898</td>\n",
       "      <td>0.867728</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.062700</td>\n",
       "      <td>0.564593</td>\n",
       "      <td>0.568967</td>\n",
       "      <td>0.564593</td>\n",
       "      <td>0.751394</td>\n",
       "      <td>0.142000</td>\n",
       "      <td>0.749065</td>\n",
       "      <td>0.869317</td>\n",
       "      <td>0.867817</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.063700</td>\n",
       "      <td>0.572849</td>\n",
       "      <td>0.570399</td>\n",
       "      <td>0.572849</td>\n",
       "      <td>0.756868</td>\n",
       "      <td>0.145333</td>\n",
       "      <td>0.745395</td>\n",
       "      <td>0.870083</td>\n",
       "      <td>0.868827</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6100</td>\n",
       "      <td>0.068900</td>\n",
       "      <td>0.562078</td>\n",
       "      <td>0.567286</td>\n",
       "      <td>0.562078</td>\n",
       "      <td>0.749719</td>\n",
       "      <td>0.139333</td>\n",
       "      <td>0.750183</td>\n",
       "      <td>0.869748</td>\n",
       "      <td>0.868205</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6200</td>\n",
       "      <td>0.059200</td>\n",
       "      <td>0.562989</td>\n",
       "      <td>0.568681</td>\n",
       "      <td>0.562988</td>\n",
       "      <td>0.750326</td>\n",
       "      <td>0.140667</td>\n",
       "      <td>0.749778</td>\n",
       "      <td>0.870154</td>\n",
       "      <td>0.868086</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6300</td>\n",
       "      <td>0.061500</td>\n",
       "      <td>0.569294</td>\n",
       "      <td>0.570890</td>\n",
       "      <td>0.569294</td>\n",
       "      <td>0.754516</td>\n",
       "      <td>0.140667</td>\n",
       "      <td>0.746975</td>\n",
       "      <td>0.869102</td>\n",
       "      <td>0.867712</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6400</td>\n",
       "      <td>0.062200</td>\n",
       "      <td>0.564797</td>\n",
       "      <td>0.569034</td>\n",
       "      <td>0.564797</td>\n",
       "      <td>0.751530</td>\n",
       "      <td>0.143333</td>\n",
       "      <td>0.748974</td>\n",
       "      <td>0.869478</td>\n",
       "      <td>0.868010</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6500</td>\n",
       "      <td>0.058500</td>\n",
       "      <td>0.570633</td>\n",
       "      <td>0.570265</td>\n",
       "      <td>0.570633</td>\n",
       "      <td>0.755403</td>\n",
       "      <td>0.146667</td>\n",
       "      <td>0.746380</td>\n",
       "      <td>0.869145</td>\n",
       "      <td>0.867793</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6600</td>\n",
       "      <td>0.058100</td>\n",
       "      <td>0.568287</td>\n",
       "      <td>0.570281</td>\n",
       "      <td>0.568287</td>\n",
       "      <td>0.753848</td>\n",
       "      <td>0.141333</td>\n",
       "      <td>0.747423</td>\n",
       "      <td>0.868826</td>\n",
       "      <td>0.867286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6700</td>\n",
       "      <td>0.060900</td>\n",
       "      <td>0.567215</td>\n",
       "      <td>0.569894</td>\n",
       "      <td>0.567215</td>\n",
       "      <td>0.753137</td>\n",
       "      <td>0.141333</td>\n",
       "      <td>0.747900</td>\n",
       "      <td>0.869064</td>\n",
       "      <td>0.867645</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6800</td>\n",
       "      <td>0.057100</td>\n",
       "      <td>0.571982</td>\n",
       "      <td>0.570973</td>\n",
       "      <td>0.571982</td>\n",
       "      <td>0.756295</td>\n",
       "      <td>0.144667</td>\n",
       "      <td>0.745781</td>\n",
       "      <td>0.869263</td>\n",
       "      <td>0.867977</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6900</td>\n",
       "      <td>0.058200</td>\n",
       "      <td>0.568954</td>\n",
       "      <td>0.570190</td>\n",
       "      <td>0.568954</td>\n",
       "      <td>0.754291</td>\n",
       "      <td>0.139333</td>\n",
       "      <td>0.747126</td>\n",
       "      <td>0.869171</td>\n",
       "      <td>0.867770</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7000</td>\n",
       "      <td>0.057000</td>\n",
       "      <td>0.568842</td>\n",
       "      <td>0.569953</td>\n",
       "      <td>0.568842</td>\n",
       "      <td>0.754216</td>\n",
       "      <td>0.142000</td>\n",
       "      <td>0.747176</td>\n",
       "      <td>0.869090</td>\n",
       "      <td>0.867630</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7100</td>\n",
       "      <td>0.058700</td>\n",
       "      <td>0.569936</td>\n",
       "      <td>0.570627</td>\n",
       "      <td>0.569936</td>\n",
       "      <td>0.754941</td>\n",
       "      <td>0.142000</td>\n",
       "      <td>0.746690</td>\n",
       "      <td>0.869007</td>\n",
       "      <td>0.867488</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7200</td>\n",
       "      <td>0.059400</td>\n",
       "      <td>0.569495</td>\n",
       "      <td>0.570374</td>\n",
       "      <td>0.569495</td>\n",
       "      <td>0.754649</td>\n",
       "      <td>0.142667</td>\n",
       "      <td>0.746886</td>\n",
       "      <td>0.869080</td>\n",
       "      <td>0.867620</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=7200, training_loss=0.31619614442189536, metrics={'train_runtime': 783.8549, 'train_samples_per_second': 146.685, 'train_steps_per_second': 9.185, 'total_flos': 36318700243848.0, 'train_loss': 0.31619614442189536, 'epoch': 20.0})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run fine-tuning; the returned TrainOutput (global step, training loss, runtime metrics) is the cell's displayed result.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d54c97e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc4e83df",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MD",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
