{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90b47c6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Select the GPU before CUDA is initialised (i.e. before torch/transformers are imported).\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
    "\n",
    "import logging\n",
    "from io import StringIO\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import requests\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import Dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "\n",
    "# Social IQa splits published in the LLM-Adapters repository; the test split serves as validation.\n",
    "DATA_URL = 'https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/social_i_qa'\n",
    "\n",
    "# Each JSON file loads as a DatasetDict with a single 'train' split.\n",
    "raw_datasets = load_dataset('json', data_files=f'{DATA_URL}/train.json')\n",
    "val_datasets = load_dataset('json', data_files=f'{DATA_URL}/test.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7756a810",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['instruction', 'input', 'output', 'answer'],\n",
       "        num_rows: 1954\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_datasets  # DatasetDict whose single 'train' split holds the Social IQa test file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "769a952a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "# Encode the string answers (e.g. 'answer1'...'answer3') as integer class ids.\n",
    "# The encoder is fit on the training answers only; the validation split reuses that mapping\n",
    "# (transform raises if validation contains an answer string never seen in training).\n",
    "label_encoder = LabelEncoder()\n",
    "label_encoder.fit(raw_datasets['train']['answer'])\n",
    "\n",
    "\n",
    "def _with_labels(ds, labels):\n",
    "    \"\"\"Return `ds` with a fresh 'labels' column, replacing any existing one so this cell can be re-run.\"\"\"\n",
    "    if 'labels' in ds.column_names:\n",
    "        ds = ds.remove_columns('labels')\n",
    "    return ds.add_column('labels', labels)\n",
    "\n",
    "\n",
    "train_labels = label_encoder.transform(raw_datasets['train']['answer'])\n",
    "raw_datasets['train'] = _with_labels(raw_datasets['train'], train_labels)\n",
    "\n",
    "val_labels = label_encoder.transform(val_datasets['train']['answer'])\n",
    "val_datasets['train'] = _with_labels(val_datasets['train'], val_labels)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9badafa3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer3',\n",
       " 'answer1',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " 'answer1',\n",
       " 'answer2',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer3',\n",
       " 'answer2',\n",
       " 'answer2',\n",
       " ...]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Class balance of the validation answers (instead of dumping all 1954 labels).\n",
    "val_datasets['train'].to_pandas()['answer'].value_counts().sort_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de228bb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/transformers/convert_slow_tokenizer.py:559: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n",
    "\n",
    "# Checkpoint used for the tokenizer here (and presumably for the masked-LM model loaded later).\n",
    "model_name = \"microsoft/deberta-v3-base\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# Left padding keeps the prompt's trailing mask token at the same offset from the end\n",
    "# of every padded sequence.\n",
    "tokenizer.padding_side = 'left'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ed721fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['instruction', 'input', 'output', 'answer', 'labels'],\n",
      "    num_rows: 33410\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['instruction', 'input', 'output', 'answer', 'labels'],\n",
      "    num_rows: 1954\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"Build the masked-LM prompt for one example.\n",
    "\n",
    "    Keeps the instruction text before the first 'format:' (the whole instruction if it has none),\n",
    "    then appends ':' and the tokenizer's mask token, so the answer is to be filled in at the mask.\n",
    "    \"\"\"\n",
    "    question = data_point[\"instruction\"].split('format:')[0]\n",
    "    return f\"# input: {question}:{mask_token}\"\n",
    "\n",
    "\n",
    "def add_label_column(example):\n",
    "    \"\"\"Overwrite `example['input']` with the masked prompt; all other fields, including 'labels', are unchanged.\"\"\"\n",
    "    example['input'] = generate_prompt(example)\n",
    "    return example\n",
    "\n",
    "\n",
    "# Build prompts for both splits; the integer 'labels' column comes from the label-encoding cell.\n",
    "train_data = raw_datasets['train'].map(add_label_column)\n",
    "val_data = val_datasets['train'].map(add_label_column)\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9e33204c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# input: Please choose the correct answer to the question: Cameron decided to have a barbecue and gathered her friends together. How would Others feel as a result?\\n\\nAnswer1: like attending Answer2: like staying home Answer3: a good friend to have\\n\\nAnswer :[MASK]'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data[0]['input']  # first formatted prompt (row lookup avoids materialising the whole column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a9fde6d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "# Maximum tokenized prompt length.\n",
    "MAX_LENGTH = 512\n",
    "\n",
    "# Raw text columns dropped after tokenization; the integer 'labels' column is kept.\n",
    "col_to_delete = ['instruction', 'input', 'output', 'answer']\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "# Prompts end with the mask token, so if one is ever too long, cut tokens from the start instead of\n",
    "# the default right side; right-truncation would silently drop the [MASK] the model has to fill.\n",
    "tokenizer.truncation_side = 'left'\n",
    "\n",
    "\n",
    "def preprocessing_function(examples):\n",
    "    \"\"\"Tokenize a batch of prompts without padding (the data collator pads each batch).\"\"\"\n",
    "    return tokenizer(examples['input'], truncation=True, max_length=MAX_LENGTH)\n",
    "\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1931ed6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] # input: Please choose the correct answer to the question: Sydney was a school teacher and made sure their students learned well. How would you describe Sydney? Answer1: As someone that asked for a job Answer2: As someone that takes teaching seriously Answer3: Like a leader Answer :[MASK][SEP]'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode one tokenized example to check the special tokens ([CLS]/[SEP]) and the [MASK] slot\n",
    "tokenizer.decode(tokenized_train_data['input_ids'][10])\n"
   ]
  },
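  {
   "cell_type": "markdown",
   "id": "abd6b985",
   "metadata": {},
   "source": [
    "## Sanity checks on the tokenized data\n",
    "\n",
    "Inspect the validation split, the longest tokenized training prompt, and how many training prompts exceed the 512-token limit."
   ]
  },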
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "25900f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['instruction', 'input', 'output', 'answer', 'labels'],\n",
       "    num_rows: 1954\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Untokenized validation split: the raw text columns plus the encoded 'labels' column\n",
    "val_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1fdaa612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "176"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token count of every tokenized training prompt; display the longest one.\n",
    "all_lengths = list(map(len, tokenized_train_data['input_ids']))\n",
    "mx = max(all_lengths)\n",
    "mx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d6618d0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "# `tokenized_train_data` was built with truncation=True and max_length=512, so its sequences can\n",
    "# never be longer than 512 and counting them there would always give 0. Re-tokenize the raw\n",
    "# prompts WITHOUT truncation to count how many would actually have been cut.\n",
    "raw_train_lengths = [len(ids) for ids in tokenizer(train_data['input'], truncation=False)['input_ids']]\n",
    "count = sum(length > 512 for length in raw_train_lengths)\n",
    "print(count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f1005af8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total labels (classes): 3\n"
     ]
    }
   ],
   "source": [
    "# Number of distinct classes known to the label encoder\n",
    "num_labels = len(label_encoder.classes_)\n",
    "print(\"Total labels (classes): {}\".format(num_labels))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7a46cd19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DebertaV2ForMaskedLM were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoConfig, RobertaForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "from modeling import MLMSequenceClassification\n",
    "\n",
    "# Put num_labels on the config itself: when a config *object* is passed to from_pretrained,\n",
    "# extra kwargs are forwarded to the model's __init__ and do not update the config, so\n",
    "# config.num_labels would otherwise keep its default value.\n",
    "config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)\n",
    "\n",
    "# NOTE(review): the warning below reports the MLM head (cls.predictions.*) as newly initialized for\n",
    "# this checkpoint - confirm those weights are trainable after RoCoFT.PEFT, otherwise [MASK]\n",
    "# scores may come from a randomly initialized head.\n",
    "model = MLMSequenceClassification.from_pretrained(model_name, config=config, num_labels=num_labels, mask_token_id=tokenizer.mask_token_id)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "864ccb2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import RoCoFT\n",
    "\n",
    "# Apply RoCoFT parameter-efficient fine-tuning to `model` (method='column', rank=3).\n",
    "# NOTE(review): presumably only `rank` columns of the targeted weight matrices stay trainable - confirm in RoCoFT.\n",
    "RoCoFT.PEFT(model, method='column', rank=3) \n",
    "# Alternative left commented out: restrict the targeted modules with a `targets=` argument, e.g.\n",
    "#targets=['key', 'value', 'dense', 'query'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bef34afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import torch\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Macro-averaged precision/recall/F1 plus accuracy for the Trainer.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: (predictions, labels) pair passed in at evaluation time. `predictions`\n",
    "            may itself be a tuple when the model returns extra outputs; its first element\n",
    "            is then taken as the class logits.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys 'precision', 'recall', 'f1-score' and 'accuracy'.\n",
    "    \"\"\"\n",
    "    logits, labels = eval_pred\n",
    "    if isinstance(logits, (tuple, list)):\n",
    "        logits = logits[0]\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "\n",
    "    # zero_division=0 returns the same 0.0 sklearn uses by default for classes that are never\n",
    "    # predicted, without emitting an UndefinedMetricWarning on every evaluation.\n",
    "    precision = metrics.precision_score(labels, predictions, average=\"macro\", zero_division=0)\n",
    "    recall = metrics.recall_score(labels, predictions, average=\"macro\", zero_division=0)\n",
    "    f1 = metrics.f1_score(labels, predictions, average=\"macro\", zero_division=0)\n",
    "    accuracy = metrics.accuracy_score(labels, predictions)\n",
    "\n",
    "    return {\"precision\": precision, \"recall\": recall, \"f1-score\": f1, 'accuracy': accuracy}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7dbcf96a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/transformers/training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-04-28 11:40:43,746] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: cannot find -laio: No such file or directory\n",
      "collect2: error: ld returned 1 exit status\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: warning: libpthread.so.0, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: warning: libstdc++.so.6, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: warning: libm.so.6, needed by /usr/local/cuda/lib64/libcufile.so, not found (try using -rpath or -rpath-link)\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for bool@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_logic_error(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ostringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::logic_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::~locale()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(std::string const&, unsigned long, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_end_catch@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ofstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::logic_error::~logic_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for __cxxabiv1::__si_class_type_info@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<char, std::char_traits<char> >::_M_cache_locale(std::locale const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_stringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator new[](unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_leak_hard()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ifstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >::basic_streambuf(std::basic_streambuf<wchar_t, std::char_traits<wchar_t> > const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::append(char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(std::string const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned short@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::resize(unsigned long, char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for char const*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ctype<char>::_M_widen_init() const@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_invalid_argument(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::operator=(std::locale const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<wchar_t, std::char_traits<wchar_t> >::_M_cache_locale(std::locale const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_decrement(std::_Rb_tree_node_base const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_free_exception@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::condition_variable::notify_one()@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::Init::~Init()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::~basic_string()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_pure_virtual@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::flush()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for __cxxabiv1::__class_type_info@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_rethrow@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringbuf<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_fstream<char, std::char_traits<char> >::~basic_fstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::compare(char const*) const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ostringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::locale()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::system_clock::now()@GLIBCXX_3.4.19'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_ifstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Hash_bytes(void const*, unsigned long, unsigned long)@CXXABI_1.3.5'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<long long>(long long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for char*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_Prime_rehash_policy::_M_need_rehash(unsigned long, unsigned long, unsigned long) const@GLIBCXX_3.4.18'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::out_of_range@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<unsigned long>(unsigned long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_increment(std::_Rb_tree_node_base const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::~ios_base()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::range_error::~range_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__basic_file<char>::~__basic_file()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_guard_acquire@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<bool>(bool)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::overflow_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_fstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::range_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ios<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_filebuf<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator delete[](void*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(unsigned long, char, std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_List_node_base::_M_transfer(std::__detail::_List_node_base*, std::__detail::_List_node_base*)@GLIBCXX_3.4.15'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::replace(unsigned long, unsigned long, char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for std::exception@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::_Rep::_M_destroy(std::allocator<wchar_t> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::istream& std::istream::_M_extract<double>(double&)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::close()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_fstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::basic_ifstream(char const*, std::_Ios_Openmode)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::append(std::string const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator new(unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_istringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned int@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::append(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::domain_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::find(char, unsigned long) const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::put(char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for int@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_bad_alloc()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_thread_atexit@CXXABI_1.3.7'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_increment(std::_Rb_tree_node_base*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::~basic_ifstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::Init::Init()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::condition_variable::condition_variable()@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::basic_filebuf()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_istringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::domain_error::~domain_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::cerr@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::find(char const*, unsigned long, unsigned long) const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_istringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringbuf<char, std::char_traits<char>, std::allocator<char> >::str() const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::invalid_argument@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for void*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::assign(std::string const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ostringstream<char, std::char_traits<char>, std::allocator<char> >::~basic_ostringstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_rebalance_for_erase(std::_Rb_tree_node_base*, std::_Rb_tree_node_base&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned long@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_List_node_base::_M_hook(std::__detail::_List_node_base*)@GLIBCXX_3.4.15'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__detail::_List_node_base::_M_unhook()@GLIBCXX_3.4.15'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ostringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringbuf<char, std::char_traits<char>, std::allocator<char> >::_M_sync(char*, unsigned long, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_iostream<char, std::char_traits<char> >::~basic_iostream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::locale::locale(std::locale const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_istringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `log2f@GLIBC_2.2.5'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::operator<<(std::basic_streambuf<char, std::char_traits<char> >*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::exception::~exception()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_Rep::_S_create(unsigned long, unsigned long, std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__basic_file<char>::is_open() const@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_istringstream<char, std::char_traits<char>, std::allocator<char> >::~basic_istringstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::swap(std::string&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned long*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ostringstream<char, std::char_traits<char>, std::allocator<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<char, std::char_traits<char> >::basic_streambuf(std::basic_streambuf<char, std::char_traits<char> > const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<char, std::char_traits<char> >::init(std::basic_streambuf<char, std::char_traits<char> >*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_bad_cast()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<char, std::char_traits<char> >::clear(std::_Ios_Iostate)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >::operator=(std::basic_streambuf<wchar_t, std::char_traits<wchar_t> > const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `operator delete(void*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::operator<<(int)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_Rep::_S_empty_rep_storage@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_Rep::_M_destroy(std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_iostream<wchar_t, std::char_traits<wchar_t> >::~basic_iostream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::runtime_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ofstream<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_insert_and_rebalance(bool, std::_Rb_tree_node_base*, std::_Rb_tree_node_base*, std::_Rb_tree_node_base&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringstream<char, std::char_traits<char>, std::allocator<char> >::~basic_stringstream()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `VTT for std::basic_stringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<long>(long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::istream::get()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned long long@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ostream<char, std::char_traits<char> >& std::operator<< <std::char_traits<char> >(std::basic_ostream<char, std::char_traits<char> >&, char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::out_of_range::~out_of_range()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::length_error::~length_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ostream<char, std::char_traits<char> >& std::__ostream_insert<char, std::char_traits<char> >(std::basic_ostream<char, std::char_traits<char> >&, char const*, long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::invalid_argument::~invalid_argument()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::swap(std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::cout@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<unsigned long long>(unsigned long long)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for int*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<void const*>(void const*)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::underflow_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_streambuf<char, std::char_traits<char> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for std::out_of_range@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_allocate_exception@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_ios<wchar_t, std::char_traits<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for void const*@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ios<wchar_t, std::char_traits<wchar_t> >::init(std::basic_streambuf<wchar_t, std::char_traits<wchar_t> >*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::reserve(unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_begin_catch@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for long@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::_Rep::_S_empty_rep_storage@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_leak()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::open(char const*, std::_Ios_Openmode)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_stringbuf<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >::_M_sync(wchar_t*, unsigned long, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::istream::getline(char*, long, char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_istream<char, std::char_traits<char> >& std::getline<char, std::char_traits<char>, std::allocator<char> >(std::basic_istream<char, std::char_traits<char> >&, std::basic_string<char, std::char_traits<char>, std::allocator<char> >&, char)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringstream<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::condition_variable::~condition_variable()@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::basic_stringbuf<wchar_t, std::char_traits<wchar_t>, std::allocator<wchar_t> >@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::insert(unsigned long, char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::assign(char const*, unsigned long)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for unsigned char@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ios_base::ios_base()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_out_of_range(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::overflow_error::~overflow_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_length_error(char const*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::__throw_system_error(int)@GLIBCXX_3.4.11'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ofstream<char, std::char_traits<char> >::close()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream& std::ostream::_M_insert<double>(double)@GLIBCXX_3.4.9'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_streambuf<char, std::char_traits<char> >::operator=(std::basic_streambuf<char, std::char_traits<char> > const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for long long@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_string<char, std::char_traits<char>, std::allocator<char> >::basic_string(char const*, unsigned long, std::allocator<char> const&)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_ifstream<char, std::char_traits<char> >::close()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_guard_release@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__cxa_throw@CXXABI_1.3'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::underflow_error::~underflow_error()@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::_Rb_tree_decrement(std::_Rb_tree_node_base*)@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `vtable for std::length_error@GLIBCXX_3.4'\n",
      "/home/guangyu/anaconda3/envs/MD/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::basic_filebuf<char, std::char_traits<char> >::~basic_filebuf()@GLIBCXX_3.4'\n",
      "collect2: error: ld returned 1 exit status\n"
     ]
    }
   ],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=3e-4,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=4,\n",
    "    weight_decay=0.0,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_total_limit=2,\n",
    "    save_steps=10000000,\n",
    "    logging_steps=100,\n",
    "   \n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "557cdbf4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got mask position:  tensor(-2, device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='8356' max='8356' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [8356/8356 31:20, Epoch 4/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>1.114700</td>\n",
       "      <td>1.107114</td>\n",
       "      <td>0.342712</td>\n",
       "      <td>0.333627</td>\n",
       "      <td>0.187013</td>\n",
       "      <td>0.329580</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>1.114100</td>\n",
       "      <td>1.105584</td>\n",
       "      <td>0.111566</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.167178</td>\n",
       "      <td>0.334698</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>1.109800</td>\n",
       "      <td>1.109040</td>\n",
       "      <td>0.109690</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.165062</td>\n",
       "      <td>0.329069</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>1.105700</td>\n",
       "      <td>1.099431</td>\n",
       "      <td>0.329972</td>\n",
       "      <td>0.335382</td>\n",
       "      <td>0.304918</td>\n",
       "      <td>0.334698</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>1.105100</td>\n",
       "      <td>1.100817</td>\n",
       "      <td>0.225850</td>\n",
       "      <td>0.330564</td>\n",
       "      <td>0.226273</td>\n",
       "      <td>0.332139</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>1.105200</td>\n",
       "      <td>1.099334</td>\n",
       "      <td>0.300828</td>\n",
       "      <td>0.324112</td>\n",
       "      <td>0.257522</td>\n",
       "      <td>0.325998</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>1.107300</td>\n",
       "      <td>1.101185</td>\n",
       "      <td>0.241783</td>\n",
       "      <td>0.340378</td>\n",
       "      <td>0.207851</td>\n",
       "      <td>0.341351</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>1.101900</td>\n",
       "      <td>1.100718</td>\n",
       "      <td>0.202071</td>\n",
       "      <td>0.331835</td>\n",
       "      <td>0.176444</td>\n",
       "      <td>0.334698</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>1.105400</td>\n",
       "      <td>1.098710</td>\n",
       "      <td>0.335203</td>\n",
       "      <td>0.327924</td>\n",
       "      <td>0.311808</td>\n",
       "      <td>0.329069</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>1.100500</td>\n",
       "      <td>1.102134</td>\n",
       "      <td>0.213597</td>\n",
       "      <td>0.327011</td>\n",
       "      <td>0.249164</td>\n",
       "      <td>0.324463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>1.103300</td>\n",
       "      <td>1.101851</td>\n",
       "      <td>0.214190</td>\n",
       "      <td>0.333324</td>\n",
       "      <td>0.170968</td>\n",
       "      <td>0.334698</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>1.101400</td>\n",
       "      <td>1.099119</td>\n",
       "      <td>0.314787</td>\n",
       "      <td>0.329059</td>\n",
       "      <td>0.271701</td>\n",
       "      <td>0.328557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>1.096800</td>\n",
       "      <td>1.103292</td>\n",
       "      <td>0.109690</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.165062</td>\n",
       "      <td>0.329069</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>1.102800</td>\n",
       "      <td>1.104506</td>\n",
       "      <td>0.302957</td>\n",
       "      <td>0.335407</td>\n",
       "      <td>0.172305</td>\n",
       "      <td>0.338280</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>1.100500</td>\n",
       "      <td>1.101326</td>\n",
       "      <td>0.109690</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.165062</td>\n",
       "      <td>0.329069</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>1.100700</td>\n",
       "      <td>1.101160</td>\n",
       "      <td>0.219659</td>\n",
       "      <td>0.336124</td>\n",
       "      <td>0.189228</td>\n",
       "      <td>0.338792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>1.102900</td>\n",
       "      <td>1.100033</td>\n",
       "      <td>0.227134</td>\n",
       "      <td>0.339657</td>\n",
       "      <td>0.262444</td>\n",
       "      <td>0.340328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>1.099300</td>\n",
       "      <td>1.098276</td>\n",
       "      <td>0.338877</td>\n",
       "      <td>0.332502</td>\n",
       "      <td>0.284251</td>\n",
       "      <td>0.334186</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>1.101300</td>\n",
       "      <td>1.098462</td>\n",
       "      <td>0.462577</td>\n",
       "      <td>0.337996</td>\n",
       "      <td>0.182549</td>\n",
       "      <td>0.339304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>1.100500</td>\n",
       "      <td>1.099126</td>\n",
       "      <td>0.224912</td>\n",
       "      <td>0.340171</td>\n",
       "      <td>0.269867</td>\n",
       "      <td>0.338280</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>1.099800</td>\n",
       "      <td>1.098133</td>\n",
       "      <td>0.348394</td>\n",
       "      <td>0.335843</td>\n",
       "      <td>0.286755</td>\n",
       "      <td>0.337769</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>1.101000</td>\n",
       "      <td>1.098229</td>\n",
       "      <td>0.343807</td>\n",
       "      <td>0.342097</td>\n",
       "      <td>0.295083</td>\n",
       "      <td>0.343398</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>1.099400</td>\n",
       "      <td>1.102143</td>\n",
       "      <td>0.229565</td>\n",
       "      <td>0.334392</td>\n",
       "      <td>0.184753</td>\n",
       "      <td>0.337257</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>1.097500</td>\n",
       "      <td>1.103527</td>\n",
       "      <td>0.356469</td>\n",
       "      <td>0.333358</td>\n",
       "      <td>0.170571</td>\n",
       "      <td>0.336233</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>1.097300</td>\n",
       "      <td>1.101490</td>\n",
       "      <td>0.325414</td>\n",
       "      <td>0.351953</td>\n",
       "      <td>0.283754</td>\n",
       "      <td>0.351075</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>1.100400</td>\n",
       "      <td>1.098439</td>\n",
       "      <td>0.410744</td>\n",
       "      <td>0.344195</td>\n",
       "      <td>0.247111</td>\n",
       "      <td>0.340839</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>1.090400</td>\n",
       "      <td>1.102054</td>\n",
       "      <td>0.240258</td>\n",
       "      <td>0.346806</td>\n",
       "      <td>0.246546</td>\n",
       "      <td>0.349539</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>1.088400</td>\n",
       "      <td>1.101629</td>\n",
       "      <td>0.405307</td>\n",
       "      <td>0.355697</td>\n",
       "      <td>0.259890</td>\n",
       "      <td>0.357216</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>1.079800</td>\n",
       "      <td>1.086074</td>\n",
       "      <td>0.421017</td>\n",
       "      <td>0.403616</td>\n",
       "      <td>0.365924</td>\n",
       "      <td>0.403787</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>1.046500</td>\n",
       "      <td>1.105456</td>\n",
       "      <td>0.432838</td>\n",
       "      <td>0.409104</td>\n",
       "      <td>0.386019</td>\n",
       "      <td>0.409928</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>1.024300</td>\n",
       "      <td>1.074913</td>\n",
       "      <td>0.468523</td>\n",
       "      <td>0.437343</td>\n",
       "      <td>0.422828</td>\n",
       "      <td>0.437052</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.984500</td>\n",
       "      <td>1.071261</td>\n",
       "      <td>0.459032</td>\n",
       "      <td>0.455694</td>\n",
       "      <td>0.451951</td>\n",
       "      <td>0.455476</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.978300</td>\n",
       "      <td>1.057697</td>\n",
       "      <td>0.472143</td>\n",
       "      <td>0.464368</td>\n",
       "      <td>0.458090</td>\n",
       "      <td>0.464176</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.939800</td>\n",
       "      <td>1.024078</td>\n",
       "      <td>0.474954</td>\n",
       "      <td>0.475142</td>\n",
       "      <td>0.470859</td>\n",
       "      <td>0.474411</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.908000</td>\n",
       "      <td>1.015553</td>\n",
       "      <td>0.493254</td>\n",
       "      <td>0.492083</td>\n",
       "      <td>0.486958</td>\n",
       "      <td>0.491300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.899500</td>\n",
       "      <td>1.010137</td>\n",
       "      <td>0.509944</td>\n",
       "      <td>0.507113</td>\n",
       "      <td>0.504882</td>\n",
       "      <td>0.507165</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.882700</td>\n",
       "      <td>1.006877</td>\n",
       "      <td>0.509465</td>\n",
       "      <td>0.507841</td>\n",
       "      <td>0.499405</td>\n",
       "      <td>0.506653</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.849200</td>\n",
       "      <td>1.007204</td>\n",
       "      <td>0.512360</td>\n",
       "      <td>0.511370</td>\n",
       "      <td>0.510916</td>\n",
       "      <td>0.511259</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.860000</td>\n",
       "      <td>0.991405</td>\n",
       "      <td>0.527033</td>\n",
       "      <td>0.525314</td>\n",
       "      <td>0.523553</td>\n",
       "      <td>0.525077</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.854900</td>\n",
       "      <td>0.996846</td>\n",
       "      <td>0.526313</td>\n",
       "      <td>0.526874</td>\n",
       "      <td>0.523072</td>\n",
       "      <td>0.526100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.877600</td>\n",
       "      <td>0.961710</td>\n",
       "      <td>0.534377</td>\n",
       "      <td>0.533050</td>\n",
       "      <td>0.530294</td>\n",
       "      <td>0.532753</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.833100</td>\n",
       "      <td>1.000749</td>\n",
       "      <td>0.530769</td>\n",
       "      <td>0.532501</td>\n",
       "      <td>0.528673</td>\n",
       "      <td>0.531730</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.790000</td>\n",
       "      <td>1.001935</td>\n",
       "      <td>0.541156</td>\n",
       "      <td>0.540374</td>\n",
       "      <td>0.535086</td>\n",
       "      <td>0.539406</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.803100</td>\n",
       "      <td>0.966206</td>\n",
       "      <td>0.553270</td>\n",
       "      <td>0.553915</td>\n",
       "      <td>0.550442</td>\n",
       "      <td>0.553224</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.790700</td>\n",
       "      <td>1.007397</td>\n",
       "      <td>0.552591</td>\n",
       "      <td>0.547855</td>\n",
       "      <td>0.539988</td>\n",
       "      <td>0.547083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.782700</td>\n",
       "      <td>0.972360</td>\n",
       "      <td>0.560247</td>\n",
       "      <td>0.559690</td>\n",
       "      <td>0.555061</td>\n",
       "      <td>0.558854</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.771200</td>\n",
       "      <td>0.985450</td>\n",
       "      <td>0.559614</td>\n",
       "      <td>0.560160</td>\n",
       "      <td>0.559736</td>\n",
       "      <td>0.559877</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.790800</td>\n",
       "      <td>0.940414</td>\n",
       "      <td>0.562806</td>\n",
       "      <td>0.561347</td>\n",
       "      <td>0.561611</td>\n",
       "      <td>0.561412</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.769800</td>\n",
       "      <td>0.961144</td>\n",
       "      <td>0.574850</td>\n",
       "      <td>0.570976</td>\n",
       "      <td>0.569730</td>\n",
       "      <td>0.571136</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.771700</td>\n",
       "      <td>0.949668</td>\n",
       "      <td>0.575828</td>\n",
       "      <td>0.575901</td>\n",
       "      <td>0.573310</td>\n",
       "      <td>0.575230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.753400</td>\n",
       "      <td>0.962670</td>\n",
       "      <td>0.581772</td>\n",
       "      <td>0.582323</td>\n",
       "      <td>0.580961</td>\n",
       "      <td>0.581883</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.750400</td>\n",
       "      <td>0.933733</td>\n",
       "      <td>0.582377</td>\n",
       "      <td>0.580613</td>\n",
       "      <td>0.577926</td>\n",
       "      <td>0.580348</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.756500</td>\n",
       "      <td>0.926525</td>\n",
       "      <td>0.592616</td>\n",
       "      <td>0.590760</td>\n",
       "      <td>0.589025</td>\n",
       "      <td>0.590583</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.749900</td>\n",
       "      <td>0.934821</td>\n",
       "      <td>0.597239</td>\n",
       "      <td>0.597446</td>\n",
       "      <td>0.597286</td>\n",
       "      <td>0.597236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.748300</td>\n",
       "      <td>0.934959</td>\n",
       "      <td>0.595444</td>\n",
       "      <td>0.595953</td>\n",
       "      <td>0.595544</td>\n",
       "      <td>0.595701</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.784700</td>\n",
       "      <td>0.910240</td>\n",
       "      <td>0.597845</td>\n",
       "      <td>0.597269</td>\n",
       "      <td>0.597330</td>\n",
       "      <td>0.597236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.743000</td>\n",
       "      <td>0.924299</td>\n",
       "      <td>0.590544</td>\n",
       "      <td>0.589004</td>\n",
       "      <td>0.589043</td>\n",
       "      <td>0.589048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.753900</td>\n",
       "      <td>0.927779</td>\n",
       "      <td>0.594146</td>\n",
       "      <td>0.587061</td>\n",
       "      <td>0.586783</td>\n",
       "      <td>0.587513</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.742800</td>\n",
       "      <td>0.919476</td>\n",
       "      <td>0.598632</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.598389</td>\n",
       "      <td>0.598260</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.768000</td>\n",
       "      <td>0.904323</td>\n",
       "      <td>0.596177</td>\n",
       "      <td>0.593029</td>\n",
       "      <td>0.592259</td>\n",
       "      <td>0.593142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6100</td>\n",
       "      <td>0.722900</td>\n",
       "      <td>0.910157</td>\n",
       "      <td>0.598922</td>\n",
       "      <td>0.598554</td>\n",
       "      <td>0.597227</td>\n",
       "      <td>0.598260</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6200</td>\n",
       "      <td>0.709900</td>\n",
       "      <td>0.937090</td>\n",
       "      <td>0.601711</td>\n",
       "      <td>0.601498</td>\n",
       "      <td>0.600881</td>\n",
       "      <td>0.601331</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6300</td>\n",
       "      <td>0.750000</td>\n",
       "      <td>0.907592</td>\n",
       "      <td>0.602002</td>\n",
       "      <td>0.601918</td>\n",
       "      <td>0.601783</td>\n",
       "      <td>0.601842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6400</td>\n",
       "      <td>0.716500</td>\n",
       "      <td>0.913611</td>\n",
       "      <td>0.602345</td>\n",
       "      <td>0.601455</td>\n",
       "      <td>0.600753</td>\n",
       "      <td>0.601331</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6500</td>\n",
       "      <td>0.707800</td>\n",
       "      <td>0.922591</td>\n",
       "      <td>0.599195</td>\n",
       "      <td>0.599613</td>\n",
       "      <td>0.598852</td>\n",
       "      <td>0.599284</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6600</td>\n",
       "      <td>0.727700</td>\n",
       "      <td>0.917236</td>\n",
       "      <td>0.603368</td>\n",
       "      <td>0.603027</td>\n",
       "      <td>0.602471</td>\n",
       "      <td>0.602866</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6700</td>\n",
       "      <td>0.716400</td>\n",
       "      <td>0.919325</td>\n",
       "      <td>0.607982</td>\n",
       "      <td>0.607782</td>\n",
       "      <td>0.606707</td>\n",
       "      <td>0.607472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6800</td>\n",
       "      <td>0.696900</td>\n",
       "      <td>0.924073</td>\n",
       "      <td>0.604797</td>\n",
       "      <td>0.605226</td>\n",
       "      <td>0.604585</td>\n",
       "      <td>0.604913</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6900</td>\n",
       "      <td>0.723000</td>\n",
       "      <td>0.913512</td>\n",
       "      <td>0.607953</td>\n",
       "      <td>0.607204</td>\n",
       "      <td>0.606053</td>\n",
       "      <td>0.606960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7000</td>\n",
       "      <td>0.726300</td>\n",
       "      <td>0.904799</td>\n",
       "      <td>0.607437</td>\n",
       "      <td>0.607183</td>\n",
       "      <td>0.606405</td>\n",
       "      <td>0.606960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7100</td>\n",
       "      <td>0.690400</td>\n",
       "      <td>0.910915</td>\n",
       "      <td>0.610242</td>\n",
       "      <td>0.609361</td>\n",
       "      <td>0.607813</td>\n",
       "      <td>0.609007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7200</td>\n",
       "      <td>0.713000</td>\n",
       "      <td>0.906952</td>\n",
       "      <td>0.609058</td>\n",
       "      <td>0.609233</td>\n",
       "      <td>0.608667</td>\n",
       "      <td>0.609007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7300</td>\n",
       "      <td>0.701100</td>\n",
       "      <td>0.906767</td>\n",
       "      <td>0.607002</td>\n",
       "      <td>0.607075</td>\n",
       "      <td>0.606935</td>\n",
       "      <td>0.606960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7400</td>\n",
       "      <td>0.721200</td>\n",
       "      <td>0.907522</td>\n",
       "      <td>0.606442</td>\n",
       "      <td>0.606689</td>\n",
       "      <td>0.606068</td>\n",
       "      <td>0.606448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7500</td>\n",
       "      <td>0.675200</td>\n",
       "      <td>0.907358</td>\n",
       "      <td>0.608911</td>\n",
       "      <td>0.608194</td>\n",
       "      <td>0.607265</td>\n",
       "      <td>0.607984</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7600</td>\n",
       "      <td>0.746300</td>\n",
       "      <td>0.903351</td>\n",
       "      <td>0.608200</td>\n",
       "      <td>0.607193</td>\n",
       "      <td>0.606105</td>\n",
       "      <td>0.606960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7700</td>\n",
       "      <td>0.726900</td>\n",
       "      <td>0.903585</td>\n",
       "      <td>0.608557</td>\n",
       "      <td>0.607720</td>\n",
       "      <td>0.606504</td>\n",
       "      <td>0.607472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7800</td>\n",
       "      <td>0.685200</td>\n",
       "      <td>0.904943</td>\n",
       "      <td>0.606794</td>\n",
       "      <td>0.606209</td>\n",
       "      <td>0.605011</td>\n",
       "      <td>0.605937</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7900</td>\n",
       "      <td>0.716300</td>\n",
       "      <td>0.903770</td>\n",
       "      <td>0.606582</td>\n",
       "      <td>0.606128</td>\n",
       "      <td>0.605394</td>\n",
       "      <td>0.605937</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8000</td>\n",
       "      <td>0.695400</td>\n",
       "      <td>0.904205</td>\n",
       "      <td>0.608059</td>\n",
       "      <td>0.607642</td>\n",
       "      <td>0.607111</td>\n",
       "      <td>0.607472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8100</td>\n",
       "      <td>0.700900</td>\n",
       "      <td>0.904932</td>\n",
       "      <td>0.606936</td>\n",
       "      <td>0.606633</td>\n",
       "      <td>0.606029</td>\n",
       "      <td>0.606448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8200</td>\n",
       "      <td>0.707600</td>\n",
       "      <td>0.904926</td>\n",
       "      <td>0.607461</td>\n",
       "      <td>0.607143</td>\n",
       "      <td>0.606569</td>\n",
       "      <td>0.606960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8300</td>\n",
       "      <td>0.716900</td>\n",
       "      <td>0.904935</td>\n",
       "      <td>0.607461</td>\n",
       "      <td>0.607143</td>\n",
       "      <td>0.606569</td>\n",
       "      <td>0.606960</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/guangyu/anaconda3/envs/MD/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=8356, training_loss=0.8914895836048117, metrics={'train_runtime': 1881.2191, 'train_samples_per_second': 71.039, 'train_steps_per_second': 4.442, 'total_flos': 62512865722368.0, 'train_loss': 0.8914895836048117, 'epoch': 4.0})"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "82b8833b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 2000/2000 [00:00<00:00, 20983.85 examples/s]\n",
      "100%|██████████| 500/500 [00:49<00:00, 10.14it/s]\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import pickle\n",
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "# Truncation length; the collator also pads every batch to exactly this length.\n",
    "MAX_LENGTH = 512\n",
    "# Number of leading training examples to extract features for.\n",
    "N_EXAMPLES = 2000\n",
    "\n",
    "# NOTE(review): left padding shifts the position ids of real tokens in encoders\n",
    "# with learned absolute positions; confirm this matches how the tokenizer was\n",
    "# configured when the model was trained.\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "def preprocessing_function(examples):\n",
    "    \"\"\"Tokenize the 'input' text column, truncating to MAX_LENGTH (no padding here).\"\"\"\n",
    "    return tokenizer(\n",
    "        examples['input'],\n",
    "        padding=False,  # Padding handled by collator\n",
    "        truncation=True,\n",
    "        max_length=MAX_LENGTH\n",
    "    )\n",
    "\n",
    "def masked_mean(hidden, mask):\n",
    "    \"\"\"Average `hidden` over dim=1, counting only positions where `mask` is nonzero.\n",
    "\n",
    "    Assumes `hidden` is (batch, seq, dim) and `mask` is the matching (batch, seq)\n",
    "    attention mask (1 = real token, 0 = padding). A row with no real tokens\n",
    "    yields zeros rather than NaN.\n",
    "    \"\"\"\n",
    "    weights = mask.unsqueeze(-1).to(hidden.dtype)\n",
    "    summed = (hidden * weights).sum(dim=1)\n",
    "    counts = weights.sum(dim=1).clamp(min=1)\n",
    "    return summed / counts\n",
    "\n",
    "# Drop the raw text columns; any remaining columns (the loop below reads 'labels') are kept.\n",
    "col_to_delete = ['instruction', 'input', 'output', 'answer']\n",
    "tokenized_train_data1 = train_data.select(range(N_EXAMPLES)).map(\n",
    "    preprocessing_function,\n",
    "    batched=True,\n",
    "    remove_columns=col_to_delete\n",
    ")\n",
    "\n",
    "# Pad every batch to MAX_LENGTH (matches the tokenizer's truncation length).\n",
    "data_collator = DataCollatorWithPadding(\n",
    "    tokenizer=tokenizer,\n",
    "    padding=\"max_length\",\n",
    "    max_length=MAX_LENGTH\n",
    ")\n",
    "\n",
    "# shuffle=False keeps feature rows in the same order as train_data.\n",
    "dataloader = DataLoader(\n",
    "    tokenized_train_data1,\n",
    "    batch_size=4,\n",
    "    collate_fn=data_collator,\n",
    "    shuffle=False\n",
    ")\n",
    "\n",
    "model.eval()\n",
    "model.cuda()\n",
    "\n",
    "X_list = []\n",
    "Z_layer_outputs = [[] for _ in range(model.config.num_hidden_layers)]\n",
    "Y_list = []\n",
    "\n",
    "for batch in tqdm(dataloader):\n",
    "    input_ids = batch['input_ids'].cuda()\n",
    "    attention_mask = batch['attention_mask'].cuda()\n",
    "    labels = batch['labels']\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)\n",
    "        hidden_states = outputs.hidden_states\n",
    "\n",
    "        # Every sequence is padded to MAX_LENGTH, so a plain .mean(dim=1) would be\n",
    "        # dominated by pad positions; pool over real tokens only.\n",
    "        # hidden_states[0]: embedding output; hidden_states[1:]: one entry per layer.\n",
    "        X_list.append(masked_mean(hidden_states[0], attention_mask).cpu())\n",
    "        for i, layer_out in enumerate(hidden_states[1:]):\n",
    "            Z_layer_outputs[i].append(masked_mean(layer_out, attention_mask).cpu())\n",
    "\n",
    "        Y_list.append(labels)\n",
    "\n",
    "# Stack all tensors (assuming (batch, seq, dim) hidden states):\n",
    "# X is (N, dim), Z is (num_layers, N, dim), Y has N rows.\n",
    "X_tensor = torch.cat(X_list, dim=0)\n",
    "Z_tensors = [torch.cat(layer, dim=0) for layer in Z_layer_outputs]\n",
    "Z_tensor = torch.stack(Z_tensors, dim=0)\n",
    "Y_tensor = torch.cat(Y_list, dim=0)\n",
    "assert X_tensor.shape[0] == Z_tensor.shape[1] == Y_tensor.shape[0] == N_EXAMPLES\n",
    "\n",
    "# Save dataset\n",
    "dataset_dict = {'X': X_tensor, 'Z': Z_tensor, 'Y': Y_tensor}\n",
    "with open('full_dataset_hws_mlm.pkl', 'wb') as f:\n",
    "    pickle.dump(dataset_dict, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MD",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
