{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "48a3f438",
   "metadata": {},
   "source": [
    "# Qwen - test on kaggle\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a6bb274d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 18,489,344 || all params: 1,562,228,224 || trainable%: 1.1835\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3478561/2897814771.py:326: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n",
      "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='30762' max='30762' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [30762/30762 13:36:12, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Acc 16</th>\n",
       "      <th>Acc Ei</th>\n",
       "      <th>Acc Ns</th>\n",
       "      <th>Acc Tf</th>\n",
       "      <th>Acc Jp</th>\n",
       "      <th>Acc 4d</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>2.285600</td>\n",
       "      <td>1.825324</td>\n",
       "      <td>0.446024</td>\n",
       "      <td>0.786593</td>\n",
       "      <td>0.716327</td>\n",
       "      <td>0.737668</td>\n",
       "      <td>0.785675</td>\n",
       "      <td>0.446024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>1.907400</td>\n",
       "      <td>0.707887</td>\n",
       "      <td>0.805657</td>\n",
       "      <td>0.915813</td>\n",
       "      <td>0.939688</td>\n",
       "      <td>0.913756</td>\n",
       "      <td>0.887750</td>\n",
       "      <td>0.805657</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>1.691000</td>\n",
       "      <td>0.994345</td>\n",
       "      <td>0.732305</td>\n",
       "      <td>0.898439</td>\n",
       "      <td>0.934435</td>\n",
       "      <td>0.868797</td>\n",
       "      <td>0.852489</td>\n",
       "      <td>0.732305</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>1.841300</td>\n",
       "      <td>0.575598</td>\n",
       "      <td>0.810358</td>\n",
       "      <td>0.938439</td>\n",
       "      <td>0.960808</td>\n",
       "      <td>0.923453</td>\n",
       "      <td>0.888411</td>\n",
       "      <td>0.810358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>1.590400</td>\n",
       "      <td>0.645246</td>\n",
       "      <td>0.806575</td>\n",
       "      <td>0.915813</td>\n",
       "      <td>0.950340</td>\n",
       "      <td>0.906850</td>\n",
       "      <td>0.889734</td>\n",
       "      <td>0.806575</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>1.558200</td>\n",
       "      <td>0.553937</td>\n",
       "      <td>0.843269</td>\n",
       "      <td>0.938512</td>\n",
       "      <td>0.958457</td>\n",
       "      <td>0.928007</td>\n",
       "      <td>0.910854</td>\n",
       "      <td>0.843269</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>1.576500</td>\n",
       "      <td>0.514555</td>\n",
       "      <td>0.848191</td>\n",
       "      <td>0.939394</td>\n",
       "      <td>0.962645</td>\n",
       "      <td>0.932158</td>\n",
       "      <td>0.917833</td>\n",
       "      <td>0.848191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>1.515400</td>\n",
       "      <td>0.513039</td>\n",
       "      <td>0.850101</td>\n",
       "      <td>0.941965</td>\n",
       "      <td>0.958457</td>\n",
       "      <td>0.934068</td>\n",
       "      <td>0.912103</td>\n",
       "      <td>0.850101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>1.642300</td>\n",
       "      <td>0.489400</td>\n",
       "      <td>0.853186</td>\n",
       "      <td>0.942736</td>\n",
       "      <td>0.962314</td>\n",
       "      <td>0.932672</td>\n",
       "      <td>0.920037</td>\n",
       "      <td>0.853186</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>1.473600</td>\n",
       "      <td>0.496077</td>\n",
       "      <td>0.851570</td>\n",
       "      <td>0.943691</td>\n",
       "      <td>0.950303</td>\n",
       "      <td>0.933701</td>\n",
       "      <td>0.919229</td>\n",
       "      <td>0.851570</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>1.490400</td>\n",
       "      <td>0.517262</td>\n",
       "      <td>0.843269</td>\n",
       "      <td>0.937153</td>\n",
       "      <td>0.948981</td>\n",
       "      <td>0.932782</td>\n",
       "      <td>0.909715</td>\n",
       "      <td>0.843269</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>1.465700</td>\n",
       "      <td>0.521090</td>\n",
       "      <td>0.841726</td>\n",
       "      <td>0.935721</td>\n",
       "      <td>0.940312</td>\n",
       "      <td>0.937043</td>\n",
       "      <td>0.908577</td>\n",
       "      <td>0.841726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6500</td>\n",
       "      <td>1.597500</td>\n",
       "      <td>0.494333</td>\n",
       "      <td>0.853554</td>\n",
       "      <td>0.938806</td>\n",
       "      <td>0.957870</td>\n",
       "      <td>0.937447</td>\n",
       "      <td>0.920073</td>\n",
       "      <td>0.853554</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7000</td>\n",
       "      <td>1.620000</td>\n",
       "      <td>0.473490</td>\n",
       "      <td>0.863067</td>\n",
       "      <td>0.938329</td>\n",
       "      <td>0.956327</td>\n",
       "      <td>0.941892</td>\n",
       "      <td>0.923453</td>\n",
       "      <td>0.863067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7500</td>\n",
       "      <td>1.493700</td>\n",
       "      <td>0.469216</td>\n",
       "      <td>0.857925</td>\n",
       "      <td>0.948136</td>\n",
       "      <td>0.964665</td>\n",
       "      <td>0.939541</td>\n",
       "      <td>0.919449</td>\n",
       "      <td>0.857925</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8000</td>\n",
       "      <td>1.337200</td>\n",
       "      <td>0.472936</td>\n",
       "      <td>0.860496</td>\n",
       "      <td>0.945785</td>\n",
       "      <td>0.963967</td>\n",
       "      <td>0.941781</td>\n",
       "      <td>0.919045</td>\n",
       "      <td>0.860496</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8500</td>\n",
       "      <td>1.411000</td>\n",
       "      <td>0.479280</td>\n",
       "      <td>0.855647</td>\n",
       "      <td>0.943177</td>\n",
       "      <td>0.962424</td>\n",
       "      <td>0.936015</td>\n",
       "      <td>0.921469</td>\n",
       "      <td>0.855647</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9000</td>\n",
       "      <td>1.398500</td>\n",
       "      <td>0.479797</td>\n",
       "      <td>0.860496</td>\n",
       "      <td>0.945124</td>\n",
       "      <td>0.962534</td>\n",
       "      <td>0.939835</td>\n",
       "      <td>0.916988</td>\n",
       "      <td>0.860496</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9500</td>\n",
       "      <td>1.536600</td>\n",
       "      <td>0.462370</td>\n",
       "      <td>0.860790</td>\n",
       "      <td>0.930579</td>\n",
       "      <td>0.965032</td>\n",
       "      <td>0.941414</td>\n",
       "      <td>0.924628</td>\n",
       "      <td>0.860790</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10000</td>\n",
       "      <td>1.386100</td>\n",
       "      <td>0.402998</td>\n",
       "      <td>0.878641</td>\n",
       "      <td>0.952066</td>\n",
       "      <td>0.969550</td>\n",
       "      <td>0.947365</td>\n",
       "      <td>0.932305</td>\n",
       "      <td>0.878641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10500</td>\n",
       "      <td>1.197800</td>\n",
       "      <td>0.435109</td>\n",
       "      <td>0.873352</td>\n",
       "      <td>0.951185</td>\n",
       "      <td>0.962828</td>\n",
       "      <td>0.946410</td>\n",
       "      <td>0.931019</td>\n",
       "      <td>0.873352</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11000</td>\n",
       "      <td>1.121600</td>\n",
       "      <td>0.414403</td>\n",
       "      <td>0.875335</td>\n",
       "      <td>0.947805</td>\n",
       "      <td>0.969036</td>\n",
       "      <td>0.945234</td>\n",
       "      <td>0.931864</td>\n",
       "      <td>0.875335</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11500</td>\n",
       "      <td>1.282000</td>\n",
       "      <td>0.402505</td>\n",
       "      <td>0.878678</td>\n",
       "      <td>0.951589</td>\n",
       "      <td>0.960845</td>\n",
       "      <td>0.945455</td>\n",
       "      <td>0.934619</td>\n",
       "      <td>0.878678</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12000</td>\n",
       "      <td>1.059600</td>\n",
       "      <td>0.387114</td>\n",
       "      <td>0.883820</td>\n",
       "      <td>0.953682</td>\n",
       "      <td>0.967787</td>\n",
       "      <td>0.950854</td>\n",
       "      <td>0.936272</td>\n",
       "      <td>0.883820</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12500</td>\n",
       "      <td>1.267500</td>\n",
       "      <td>0.390446</td>\n",
       "      <td>0.884114</td>\n",
       "      <td>0.952764</td>\n",
       "      <td>0.968081</td>\n",
       "      <td>0.951221</td>\n",
       "      <td>0.933811</td>\n",
       "      <td>0.884114</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13000</td>\n",
       "      <td>1.062000</td>\n",
       "      <td>0.385322</td>\n",
       "      <td>0.885583</td>\n",
       "      <td>0.954968</td>\n",
       "      <td>0.960110</td>\n",
       "      <td>0.950376</td>\n",
       "      <td>0.937557</td>\n",
       "      <td>0.885583</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13500</td>\n",
       "      <td>1.205200</td>\n",
       "      <td>0.357577</td>\n",
       "      <td>0.894325</td>\n",
       "      <td>0.957098</td>\n",
       "      <td>0.969366</td>\n",
       "      <td>0.954931</td>\n",
       "      <td>0.939688</td>\n",
       "      <td>0.894325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14000</td>\n",
       "      <td>1.177000</td>\n",
       "      <td>0.356205</td>\n",
       "      <td>0.892966</td>\n",
       "      <td>0.957062</td>\n",
       "      <td>0.971423</td>\n",
       "      <td>0.955188</td>\n",
       "      <td>0.938659</td>\n",
       "      <td>0.892966</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14500</td>\n",
       "      <td>0.950500</td>\n",
       "      <td>0.347345</td>\n",
       "      <td>0.896088</td>\n",
       "      <td>0.959633</td>\n",
       "      <td>0.969183</td>\n",
       "      <td>0.955629</td>\n",
       "      <td>0.941157</td>\n",
       "      <td>0.896088</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15000</td>\n",
       "      <td>1.218100</td>\n",
       "      <td>0.362268</td>\n",
       "      <td>0.891423</td>\n",
       "      <td>0.948907</td>\n",
       "      <td>0.968191</td>\n",
       "      <td>0.955409</td>\n",
       "      <td>0.939284</td>\n",
       "      <td>0.891423</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15500</td>\n",
       "      <td>1.081900</td>\n",
       "      <td>0.346338</td>\n",
       "      <td>0.894839</td>\n",
       "      <td>0.960882</td>\n",
       "      <td>0.973701</td>\n",
       "      <td>0.953756</td>\n",
       "      <td>0.937557</td>\n",
       "      <td>0.894839</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16000</td>\n",
       "      <td>1.137200</td>\n",
       "      <td>0.337021</td>\n",
       "      <td>0.898439</td>\n",
       "      <td>0.958347</td>\n",
       "      <td>0.975464</td>\n",
       "      <td>0.953095</td>\n",
       "      <td>0.944720</td>\n",
       "      <td>0.898439</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16500</td>\n",
       "      <td>1.219700</td>\n",
       "      <td>0.314497</td>\n",
       "      <td>0.905381</td>\n",
       "      <td>0.965069</td>\n",
       "      <td>0.976566</td>\n",
       "      <td>0.958384</td>\n",
       "      <td>0.946483</td>\n",
       "      <td>0.905381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17000</td>\n",
       "      <td>1.276000</td>\n",
       "      <td>0.324319</td>\n",
       "      <td>0.901451</td>\n",
       "      <td>0.960918</td>\n",
       "      <td>0.973407</td>\n",
       "      <td>0.957172</td>\n",
       "      <td>0.943581</td>\n",
       "      <td>0.901451</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17500</td>\n",
       "      <td>1.026400</td>\n",
       "      <td>0.323578</td>\n",
       "      <td>0.903067</td>\n",
       "      <td>0.962241</td>\n",
       "      <td>0.972635</td>\n",
       "      <td>0.956547</td>\n",
       "      <td>0.947181</td>\n",
       "      <td>0.903067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18000</td>\n",
       "      <td>0.923800</td>\n",
       "      <td>0.299480</td>\n",
       "      <td>0.911405</td>\n",
       "      <td>0.962975</td>\n",
       "      <td>0.977300</td>\n",
       "      <td>0.964371</td>\n",
       "      <td>0.949532</td>\n",
       "      <td>0.911405</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18500</td>\n",
       "      <td>0.864600</td>\n",
       "      <td>0.289958</td>\n",
       "      <td>0.913352</td>\n",
       "      <td>0.963820</td>\n",
       "      <td>0.978659</td>\n",
       "      <td>0.964408</td>\n",
       "      <td>0.951074</td>\n",
       "      <td>0.913352</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19000</td>\n",
       "      <td>1.161700</td>\n",
       "      <td>0.286744</td>\n",
       "      <td>0.915225</td>\n",
       "      <td>0.966097</td>\n",
       "      <td>0.978659</td>\n",
       "      <td>0.964555</td>\n",
       "      <td>0.952948</td>\n",
       "      <td>0.915225</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19500</td>\n",
       "      <td>1.023300</td>\n",
       "      <td>0.281383</td>\n",
       "      <td>0.915592</td>\n",
       "      <td>0.965987</td>\n",
       "      <td>0.977521</td>\n",
       "      <td>0.965179</td>\n",
       "      <td>0.952580</td>\n",
       "      <td>0.915592</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20000</td>\n",
       "      <td>0.892800</td>\n",
       "      <td>0.272399</td>\n",
       "      <td>0.919155</td>\n",
       "      <td>0.969036</td>\n",
       "      <td>0.978770</td>\n",
       "      <td>0.968522</td>\n",
       "      <td>0.953976</td>\n",
       "      <td>0.919155</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20500</td>\n",
       "      <td>1.032100</td>\n",
       "      <td>0.264035</td>\n",
       "      <td>0.921616</td>\n",
       "      <td>0.970064</td>\n",
       "      <td>0.978586</td>\n",
       "      <td>0.967787</td>\n",
       "      <td>0.953903</td>\n",
       "      <td>0.921616</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21000</td>\n",
       "      <td>0.493900</td>\n",
       "      <td>0.227423</td>\n",
       "      <td>0.931350</td>\n",
       "      <td>0.972819</td>\n",
       "      <td>0.981084</td>\n",
       "      <td>0.973113</td>\n",
       "      <td>0.960220</td>\n",
       "      <td>0.931350</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21500</td>\n",
       "      <td>0.530000</td>\n",
       "      <td>0.210242</td>\n",
       "      <td>0.936676</td>\n",
       "      <td>0.975537</td>\n",
       "      <td>0.984022</td>\n",
       "      <td>0.975243</td>\n",
       "      <td>0.962718</td>\n",
       "      <td>0.936676</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22000</td>\n",
       "      <td>0.460900</td>\n",
       "      <td>0.198061</td>\n",
       "      <td>0.940569</td>\n",
       "      <td>0.977778</td>\n",
       "      <td>0.984169</td>\n",
       "      <td>0.976676</td>\n",
       "      <td>0.966134</td>\n",
       "      <td>0.940569</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22500</td>\n",
       "      <td>0.592200</td>\n",
       "      <td>0.194537</td>\n",
       "      <td>0.942185</td>\n",
       "      <td>0.977741</td>\n",
       "      <td>0.984353</td>\n",
       "      <td>0.976676</td>\n",
       "      <td>0.966354</td>\n",
       "      <td>0.942185</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23000</td>\n",
       "      <td>0.687900</td>\n",
       "      <td>0.185359</td>\n",
       "      <td>0.945087</td>\n",
       "      <td>0.979284</td>\n",
       "      <td>0.984720</td>\n",
       "      <td>0.978733</td>\n",
       "      <td>0.968118</td>\n",
       "      <td>0.945087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23500</td>\n",
       "      <td>0.479600</td>\n",
       "      <td>0.180393</td>\n",
       "      <td>0.947916</td>\n",
       "      <td>0.980716</td>\n",
       "      <td>0.985748</td>\n",
       "      <td>0.979871</td>\n",
       "      <td>0.969366</td>\n",
       "      <td>0.947916</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24000</td>\n",
       "      <td>0.667700</td>\n",
       "      <td>0.168317</td>\n",
       "      <td>0.950487</td>\n",
       "      <td>0.981671</td>\n",
       "      <td>0.986152</td>\n",
       "      <td>0.981047</td>\n",
       "      <td>0.971938</td>\n",
       "      <td>0.950487</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24500</td>\n",
       "      <td>0.516800</td>\n",
       "      <td>0.162847</td>\n",
       "      <td>0.952727</td>\n",
       "      <td>0.981524</td>\n",
       "      <td>0.986777</td>\n",
       "      <td>0.981708</td>\n",
       "      <td>0.974031</td>\n",
       "      <td>0.952727</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25000</td>\n",
       "      <td>0.589600</td>\n",
       "      <td>0.156793</td>\n",
       "      <td>0.954784</td>\n",
       "      <td>0.983140</td>\n",
       "      <td>0.988209</td>\n",
       "      <td>0.982112</td>\n",
       "      <td>0.974178</td>\n",
       "      <td>0.954784</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25500</td>\n",
       "      <td>0.530100</td>\n",
       "      <td>0.150895</td>\n",
       "      <td>0.956657</td>\n",
       "      <td>0.983655</td>\n",
       "      <td>0.988246</td>\n",
       "      <td>0.983067</td>\n",
       "      <td>0.975023</td>\n",
       "      <td>0.956657</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26000</td>\n",
       "      <td>0.692400</td>\n",
       "      <td>0.142637</td>\n",
       "      <td>0.959633</td>\n",
       "      <td>0.984279</td>\n",
       "      <td>0.988797</td>\n",
       "      <td>0.984830</td>\n",
       "      <td>0.976860</td>\n",
       "      <td>0.959633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26500</td>\n",
       "      <td>0.487300</td>\n",
       "      <td>0.137935</td>\n",
       "      <td>0.960918</td>\n",
       "      <td>0.984536</td>\n",
       "      <td>0.988907</td>\n",
       "      <td>0.985308</td>\n",
       "      <td>0.977521</td>\n",
       "      <td>0.960918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27000</td>\n",
       "      <td>0.478500</td>\n",
       "      <td>0.133302</td>\n",
       "      <td>0.962241</td>\n",
       "      <td>0.985455</td>\n",
       "      <td>0.989642</td>\n",
       "      <td>0.985638</td>\n",
       "      <td>0.978035</td>\n",
       "      <td>0.962241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27500</td>\n",
       "      <td>0.457600</td>\n",
       "      <td>0.130090</td>\n",
       "      <td>0.963526</td>\n",
       "      <td>0.985932</td>\n",
       "      <td>0.989936</td>\n",
       "      <td>0.985932</td>\n",
       "      <td>0.979100</td>\n",
       "      <td>0.963526</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28000</td>\n",
       "      <td>0.481200</td>\n",
       "      <td>0.126854</td>\n",
       "      <td>0.964371</td>\n",
       "      <td>0.986299</td>\n",
       "      <td>0.990340</td>\n",
       "      <td>0.986410</td>\n",
       "      <td>0.979651</td>\n",
       "      <td>0.964371</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28500</td>\n",
       "      <td>0.301700</td>\n",
       "      <td>0.123515</td>\n",
       "      <td>0.965399</td>\n",
       "      <td>0.986483</td>\n",
       "      <td>0.990707</td>\n",
       "      <td>0.987034</td>\n",
       "      <td>0.980092</td>\n",
       "      <td>0.965399</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29000</td>\n",
       "      <td>0.515500</td>\n",
       "      <td>0.122049</td>\n",
       "      <td>0.965767</td>\n",
       "      <td>0.986924</td>\n",
       "      <td>0.990781</td>\n",
       "      <td>0.987034</td>\n",
       "      <td>0.980239</td>\n",
       "      <td>0.965767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29500</td>\n",
       "      <td>0.498500</td>\n",
       "      <td>0.120924</td>\n",
       "      <td>0.966134</td>\n",
       "      <td>0.986740</td>\n",
       "      <td>0.990744</td>\n",
       "      <td>0.987144</td>\n",
       "      <td>0.980606</td>\n",
       "      <td>0.966134</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30000</td>\n",
       "      <td>0.405100</td>\n",
       "      <td>0.120360</td>\n",
       "      <td>0.966061</td>\n",
       "      <td>0.986850</td>\n",
       "      <td>0.990744</td>\n",
       "      <td>0.987291</td>\n",
       "      <td>0.980349</td>\n",
       "      <td>0.966061</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30500</td>\n",
       "      <td>0.628900</td>\n",
       "      <td>0.120258</td>\n",
       "      <td>0.966171</td>\n",
       "      <td>0.986961</td>\n",
       "      <td>0.990817</td>\n",
       "      <td>0.987181</td>\n",
       "      <td>0.980496</td>\n",
       "      <td>0.966171</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Eval on chosen dataset ===\n",
      "eval_loss: 0.1203\n",
      "eval_acc_16: 0.9662\n",
      "eval_acc_ei: 0.9870\n",
      "eval_acc_ns: 0.9908\n",
      "eval_acc_tf: 0.9872\n",
      "eval_acc_jp: 0.9805\n",
      "eval_acc_4D: 0.9662\n",
      "eval_runtime: 532.5159\n",
      "eval_samples_per_second: 51.1250\n",
      "eval_steps_per_second: 6.3920\n",
      "epoch: 3.0000\n",
      "\n",
      "=== Test accuracy on chosen dataset: 1.0000\n",
      "样例原标签: INFJ | 预测: INFJ\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "训练 + 评测（LoRA / 4bit / 单卡）\n",
    "- 基座: Qwen/Qwen2.5-1.5B-Instruct\n",
    "- 训练集: 两数据集合并 (mbti_sample_with_all_views.json + mbti_sample_with_all_views_pandora.json)\n",
    "- Eval/Test: 只在指定的数据集上评测 (默认 Pandora)\n",
    "- 输出: 指标 + 混淆矩阵 + ROC(micro/macro) + LoRA 适配器权重\n",
    "\"\"\"\n",
    "import os, json, random\n",
    "from typing import Dict, Any, List\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoConfig,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "from peft import (\n",
    "    LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel\n",
    ")\n",
    "\n",
    "# ================== 配置 ==================\n",
    "BASE_MODEL   = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "DATA_A       = \"mbti_sample_with_all_views.json\"          # 旧数据\n",
    "DATA_B       = \"mbti_sample_with_all_views_pandora.json\"  # Pandora\n",
    "EVAL_ON      = \"A\"  # 只在哪个集上做 eval/test: \"A\" or \"B\"\n",
    "OUTPUT_DIR   = \"qwen-test-on-pandora\"                      # 输出目录（含 lora）\n",
    "RESUME_ADAPTER_DIR = None  # 若已有 LoRA 断点，可填入目录；否则置为 None\n",
    "\n",
    "MAX_LEN      = 320\n",
    "USE_4BIT     = True\n",
    "SEED         = 42\n",
    "NUM_LABELS   = 16\n",
    "\n",
    "# LoRA 超参（可按需微调）\n",
    "LORA_R       = 16\n",
    "LORA_ALPHA   = 32\n",
    "LORA_DROPOUT = 0.05\n",
    "# Qwen2.5 常用目标模块\n",
    "LORA_TARGET_MODULES = [\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"]\n",
    "\n",
    "# 训练超参（按你的显存情况调整）\n",
    "BATCH_SIZE_PER_DEVICE_TRAIN = 8\n",
    "BATCH_SIZE_PER_DEVICE_EVAL  = 8\n",
    "GR_ACCUM_STEPS              = 1\n",
    "EPOCHS                      = 3\n",
    "LR                          = 2e-4\n",
    "WARMUP_RATIO                = 0.05\n",
    "LOGGING_STEPS               = 20\n",
    "SAVE_STEPS                  = 500\n",
    "EVAL_STEPS                  = 500\n",
    "\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "# 与训练一致的多视角 token 预算\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# ================== 工具函数 ==================\n",
    "def mbti_to_4d(m: str):\n",
    "    m = m.upper()\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    enc = tok(text or \"\", add_special_tokens=False)\n",
    "    ids = enc[\"input_ids\"][: budget]\n",
    "    return tok.decode(ids)\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    p_raw = item.get(\"posts_cleaned\", item.get(\"posts\",\"\")) or item.get(\"text\",\"\") or \"\"\n",
    "    sem   = item.get(\"semantic_view\",\"\")  or \"\"\n",
    "    sen   = item.get(\"sentiment_view\",\"\") or \"\"\n",
    "    lin   = item.get(\"linguistic_view\",\"\") or \"\"\n",
    "\n",
    "    p   = truncate_to_budget(tok, p_raw, BUDGET[\"posts_cleaned\"])\n",
    "    sem = truncate_to_budget(tok, sem,   BUDGET[\"semantic_view\"])\n",
    "    sen = truncate_to_budget(tok, sen,   BUDGET[\"sentiment_view\"])\n",
    "    lin = truncate_to_budget(tok, lin,   BUDGET[\"linguistic_view\"])\n",
    "\n",
    "    return (\n",
    "        f\"[POSTS]\\n{p}\\n[SEMANTIC]\\n{sem}\\n[SENTIMENT]\\n{sen}\\n[LINGUISTIC]\\n{lin}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "\n",
    "def load_rows(path: str):\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    clean = []\n",
    "    for r in rows:\n",
    "        t = (r.get(\"type\") or r.get(\"label\") or \"\").upper().strip()\n",
    "        if t in MBTI2ID:\n",
    "            r[\"type\"] = t\n",
    "            clean.append(r)\n",
    "    return clean\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok  = tokenizer\n",
    "        self.max_len = max_len\n",
    "    def __len__(self): return len(self.rows)\n",
    "    def __getitem__(self, idx):\n",
    "        it   = self.rows[idx]\n",
    "        text = build_input(it, self.tok)\n",
    "        y    = MBTI2ID[it[\"type\"]]\n",
    "        enc  = self.tok(text, truncation=True, max_length=self.max_len)\n",
    "        return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"], \"labels\": y}\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    preds, labels = (eval_pred if isinstance(eval_pred, tuple)\n",
    "                     else (eval_pred.predictions, eval_pred.label_ids))\n",
    "    if isinstance(preds, (list, tuple)): preds = preds[0]\n",
    "    preds = np.asarray(preds); labels = np.asarray(labels)\n",
    "    pred_ids = preds.argmax(-1)\n",
    "    acc16 = float((pred_ids == labels).mean())\n",
    "\n",
    "    pred_types = [MBTI_16[i] for i in pred_ids]\n",
    "    true_types = [MBTI_16[i] for i in labels]\n",
    "    c_ei=c_ns=c_tf=c_jp=c_all=0\n",
    "    for pt, tt in zip(pred_types, true_types):\n",
    "        pei,pns,ptf,pjp = mbti_to_4d(pt)\n",
    "        tei,tns,ttf,tjp = mbti_to_4d(tt)\n",
    "        c_ei += (pei==tei); c_ns += (pns==tns); c_tf += (ptf==ttf); c_jp += (pjp==tjp)\n",
    "        c_all+= (pei==tei and pns==tns and ptf==ttf and pjp==tjp)\n",
    "    n = len(labels)\n",
    "    return {\"acc_16\": acc16, \"acc_ei\": c_ei/n, \"acc_ns\": c_ns/n, \"acc_tf\": c_tf/n, \"acc_jp\": c_jp/n, \"acc_4D\": c_all/n}\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, tag=\"eval\"):\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix ({tag})\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"{tag}_confusion_matrix.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ROC：跳过评测集中没有正样本的类\n",
    "    Y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    valid = []\n",
    "    for i in range(len(class_names)):\n",
    "        if Y_true_bin[:, i].sum() == 0:\n",
    "            continue\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "        valid.append(i)\n",
    "    if len(valid) >= 2:\n",
    "        fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(\n",
    "            Y_true_bin[:, valid].ravel(), y_prob[:, valid].ravel()\n",
    "        )\n",
    "        roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in valid]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in valid:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(valid)\n",
    "        fpr[\"macro\"] = all_fpr; tpr[\"macro\"] = mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "        fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "        ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                    label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                    label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "        ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "        ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "        ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "        ax_roc.set_title(f\"Multiclass ROC ({tag})\")\n",
    "        ax_roc.legend(loc=\"lower right\")\n",
    "        fig_roc.tight_layout()\n",
    "        fig_roc.savefig(os.path.join(out_dir, f\"{tag}_roc_micro_macro.png\"))\n",
    "        plt.close(fig_roc)\n",
    "\n",
    "# ================== 主流程 ==================\n",
    "def main():\n",
    "    # 环境 & 种子\n",
    "    os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "    os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
    "    torch.cuda.set_device(0)\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    # tokenizer\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\n",
    "        BASE_MODEL, use_fast=True, trust_remote_code=True, **HF_KW\n",
    "    )\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # 量化\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    # 分类头：num_labels=16\n",
    "    base_cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True, **HF_KW)\n",
    "    base_cfg.num_labels = NUM_LABELS\n",
    "\n",
    "    # 基座\n",
    "    base = AutoModelForSequenceClassification.from_pretrained(\n",
    "        BASE_MODEL,\n",
    "        config=base_cfg,\n",
    "        device_map={\"\": \"cuda:0\"},\n",
    "        quantization_config=quant_cfg,\n",
    "        trust_remote_code=True,\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW,\n",
    "    )\n",
    "\n",
    "    # ========= LoRA：新训或续训 =========\n",
    "    if RESUME_ADAPTER_DIR:\n",
    "        # 从已训练的 LoRA 继续\n",
    "        model = PeftModel.from_pretrained(base, RESUME_ADAPTER_DIR, is_trainable=True)\n",
    "    else:\n",
    "        # 新建 LoRA\n",
    "        base = prepare_model_for_kbit_training(base)  # 4bit 可训练准备\n",
    "        lora_cfg = LoraConfig(\n",
    "            r=LORA_R,\n",
    "            lora_alpha=LORA_ALPHA,\n",
    "            target_modules=LORA_TARGET_MODULES,\n",
    "            lora_dropout=LORA_DROPOUT,\n",
    "            bias=\"none\",\n",
    "            task_type=\"SEQ_CLS\",\n",
    "        )\n",
    "        model = get_peft_model(base, lora_cfg)\n",
    "\n",
    "    model.config.use_cache = False\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.print_trainable_parameters()\n",
    "\n",
    "    # ========= 数据 =========\n",
    "    rows_A = load_rows(DATA_A)\n",
    "    rows_B = load_rows(DATA_B)\n",
    "\n",
    "    # 训练集 = A ∪ B\n",
    "    train_rows: List[Dict[str, Any]] = rows_A + rows_B\n",
    "    random.Random(SEED).shuffle(train_rows)\n",
    "\n",
    "    # eval/test 只用指定一个数据集（默认 B=Pandora）\n",
    "    if EVAL_ON.upper() == \"A\":\n",
    "        eval_rows = rows_A\n",
    "        eval_tag  = \"A_eval\"\n",
    "    else:\n",
    "        eval_rows = rows_B\n",
    "        eval_tag  = \"B_eval\"\n",
    "\n",
    "    # （可选）从 eval_rows 再划一个 test 子集；这里简单按 80/20 切\n",
    "    cut = int(0.8 * len(eval_rows)) if len(eval_rows) > 5 else len(eval_rows)\n",
    "    test_rows = eval_rows[cut:]\n",
    "    eval_rows = eval_rows[:cut] if cut > 0 else eval_rows\n",
    "\n",
    "    # 构建数据集\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    eval_ds  = MBTIDataset(eval_rows,  tokenizer, max_len=MAX_LEN)\n",
    "    test_ds  = MBTIDataset(test_rows,  tokenizer, max_len=MAX_LEN) if test_rows else None\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # ========= 训练参数 =========\n",
    "    from transformers import TrainingArguments\n",
    "\n",
    "    common_kwargs = dict(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BATCH_SIZE_PER_DEVICE_TRAIN,\n",
    "        per_device_eval_batch_size=BATCH_SIZE_PER_DEVICE_EVAL,\n",
    "        gradient_accumulation_steps=GR_ACCUM_STEPS,\n",
    "        learning_rate=LR,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        logging_steps=LOGGING_STEPS,\n",
    "        eval_steps=EVAL_STEPS,\n",
    "        save_steps=SAVE_STEPS,\n",
    "        save_total_limit=2,\n",
    "        lr_scheduler_type=\"cosine\",\n",
    "        report_to=\"none\",\n",
    "        fp16=False, bf16=False,\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_16\",\n",
    "        greater_is_better=True,\n",
    "        # 这两个在新老版本都存在，显式写上更稳\n",
    "        logging_strategy=\"steps\",\n",
    "        save_strategy=\"steps\",\n",
    "    )\n",
    "\n",
    "    # 依次尝试新/旧/远古命名，保证不同版本都能跑\n",
    "    try:\n",
    "        args = TrainingArguments(eval_strategy=\"steps\", **common_kwargs)\n",
    "    except TypeError:\n",
    "        try:\n",
    "            args = TrainingArguments(evaluation_strategy=\"steps\", **common_kwargs)\n",
    "        except TypeError:\n",
    "            # 超老版本（3.x）兜底\n",
    "            args = TrainingArguments(evaluate_during_training=True, **common_kwargs)\n",
    "\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=eval_ds,\n",
    "        tokenizer=tokenizer,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    # ========= 训练 =========\n",
    "    trainer.train()\n",
    "\n",
    "    # 保存 LoRA（适配器）\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    try:\n",
    "        model.save_pretrained(os.path.join(OUTPUT_DIR, \"lora_adapter\"))\n",
    "    except Exception as e:\n",
    "        print(\"Save adapter failed:\", e)\n",
    "\n",
    "    # ========= Eval（在指定集）=========\n",
    "    eval_output = trainer.predict(eval_ds)\n",
    "    logits = eval_output.predictions\n",
    "    if isinstance(logits, (list, tuple)):\n",
    "        logits = logits[0]\n",
    "    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = eval_output.label_ids\n",
    "    plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, tag=f\"{eval_tag}\")\n",
    "\n",
    "    metrics = trainer.evaluate(eval_dataset=eval_ds)\n",
    "    print(\"\\n=== Eval on chosen dataset ===\")\n",
    "    for k, v in metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except Exception:\n",
    "            print(k, v)\n",
    "\n",
    "    # ========= Test（同一数据集的 hold-out 部分）=========\n",
    "    if test_ds and len(test_ds) > 0:\n",
    "        test_output = trainer.predict(test_ds)\n",
    "        logits = test_output.predictions\n",
    "        if isinstance(logits, (list, tuple)):\n",
    "            logits = logits[0]\n",
    "        probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "        y_true = test_output.label_ids\n",
    "        plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, tag=f\"{eval_tag}_test\")\n",
    "\n",
    "        # 简单整体准确率\n",
    "        pred_ids = probs.argmax(-1)\n",
    "        acc = float((pred_ids == y_true).mean())\n",
    "        print(f\"\\n=== Test accuracy on chosen dataset: {acc:.4f}\")\n",
    "\n",
    "    # ========= 示例推理 =========\n",
    "    model.eval()\n",
    "    sample = (rows_B[0] if EVAL_ON.upper()==\"B\" else rows_A[0]) if (rows_A and rows_B) else (train_rows[0])\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(\"cuda:0\") for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        out = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(out, dim=-1))\n",
    "        print(\"样例原标签:\", sample[\"type\"], \"| 预测:\", MBTI_16[pred_id])\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "50846c41",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 18,489,344 || all params: 1,562,228,224 || trainable%: 1.1835\n",
      "[Split-A] Saved: train=27226, test=6806\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_253522/2329019196.py:374: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='28212' max='28212' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [28212/28212 12:37:33, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Acc 16</th>\n",
       "      <th>Acc Ei</th>\n",
       "      <th>Acc Ns</th>\n",
       "      <th>Acc Tf</th>\n",
       "      <th>Acc Jp</th>\n",
       "      <th>Acc 4d</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>2.174600</td>\n",
       "      <td>0.981938</td>\n",
       "      <td>0.741387</td>\n",
       "      <td>0.910710</td>\n",
       "      <td>0.931940</td>\n",
       "      <td>0.873063</td>\n",
       "      <td>0.859693</td>\n",
       "      <td>0.741387</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>2.042000</td>\n",
       "      <td>0.596887</td>\n",
       "      <td>0.828326</td>\n",
       "      <td>0.914787</td>\n",
       "      <td>0.931352</td>\n",
       "      <td>0.929002</td>\n",
       "      <td>0.910527</td>\n",
       "      <td>0.828326</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>1.971000</td>\n",
       "      <td>0.632225</td>\n",
       "      <td>0.829685</td>\n",
       "      <td>0.935319</td>\n",
       "      <td>0.923602</td>\n",
       "      <td>0.929259</td>\n",
       "      <td>0.904907</td>\n",
       "      <td>0.829685</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>1.968600</td>\n",
       "      <td>0.546939</td>\n",
       "      <td>0.835415</td>\n",
       "      <td>0.933666</td>\n",
       "      <td>0.942702</td>\n",
       "      <td>0.937449</td>\n",
       "      <td>0.905201</td>\n",
       "      <td>0.835415</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>1.742300</td>\n",
       "      <td>0.485310</td>\n",
       "      <td>0.858334</td>\n",
       "      <td>0.944281</td>\n",
       "      <td>0.944795</td>\n",
       "      <td>0.939470</td>\n",
       "      <td>0.922611</td>\n",
       "      <td>0.858334</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>1.541500</td>\n",
       "      <td>0.587856</td>\n",
       "      <td>0.830566</td>\n",
       "      <td>0.912033</td>\n",
       "      <td>0.949130</td>\n",
       "      <td>0.948799</td>\n",
       "      <td>0.922905</td>\n",
       "      <td>0.830566</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>1.734900</td>\n",
       "      <td>0.455448</td>\n",
       "      <td>0.865643</td>\n",
       "      <td>0.944281</td>\n",
       "      <td>0.952876</td>\n",
       "      <td>0.944612</td>\n",
       "      <td>0.924704</td>\n",
       "      <td>0.865643</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>1.581800</td>\n",
       "      <td>0.439966</td>\n",
       "      <td>0.867112</td>\n",
       "      <td>0.944648</td>\n",
       "      <td>0.958606</td>\n",
       "      <td>0.945383</td>\n",
       "      <td>0.923235</td>\n",
       "      <td>0.867112</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>1.725200</td>\n",
       "      <td>0.427858</td>\n",
       "      <td>0.870675</td>\n",
       "      <td>0.953794</td>\n",
       "      <td>0.968413</td>\n",
       "      <td>0.947366</td>\n",
       "      <td>0.924668</td>\n",
       "      <td>0.870675</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>1.741700</td>\n",
       "      <td>0.410553</td>\n",
       "      <td>0.879307</td>\n",
       "      <td>0.952545</td>\n",
       "      <td>0.969992</td>\n",
       "      <td>0.946705</td>\n",
       "      <td>0.929846</td>\n",
       "      <td>0.879307</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>1.396200</td>\n",
       "      <td>0.415666</td>\n",
       "      <td>0.877801</td>\n",
       "      <td>0.940057</td>\n",
       "      <td>0.949166</td>\n",
       "      <td>0.950195</td>\n",
       "      <td>0.934327</td>\n",
       "      <td>0.877801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>1.603500</td>\n",
       "      <td>0.409567</td>\n",
       "      <td>0.883163</td>\n",
       "      <td>0.954235</td>\n",
       "      <td>0.962205</td>\n",
       "      <td>0.946485</td>\n",
       "      <td>0.932564</td>\n",
       "      <td>0.883163</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6500</td>\n",
       "      <td>1.500700</td>\n",
       "      <td>0.394274</td>\n",
       "      <td>0.886248</td>\n",
       "      <td>0.954529</td>\n",
       "      <td>0.957100</td>\n",
       "      <td>0.952141</td>\n",
       "      <td>0.934989</td>\n",
       "      <td>0.886248</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7000</td>\n",
       "      <td>1.621800</td>\n",
       "      <td>0.362826</td>\n",
       "      <td>0.893705</td>\n",
       "      <td>0.952178</td>\n",
       "      <td>0.969808</td>\n",
       "      <td>0.955924</td>\n",
       "      <td>0.942261</td>\n",
       "      <td>0.893705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7500</td>\n",
       "      <td>1.563200</td>\n",
       "      <td>0.362498</td>\n",
       "      <td>0.891942</td>\n",
       "      <td>0.956622</td>\n",
       "      <td>0.968266</td>\n",
       "      <td>0.951958</td>\n",
       "      <td>0.938625</td>\n",
       "      <td>0.891942</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8000</td>\n",
       "      <td>1.487800</td>\n",
       "      <td>0.359640</td>\n",
       "      <td>0.894329</td>\n",
       "      <td>0.955484</td>\n",
       "      <td>0.969625</td>\n",
       "      <td>0.956365</td>\n",
       "      <td>0.942775</td>\n",
       "      <td>0.894329</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8500</td>\n",
       "      <td>1.417700</td>\n",
       "      <td>0.382565</td>\n",
       "      <td>0.886285</td>\n",
       "      <td>0.944208</td>\n",
       "      <td>0.963564</td>\n",
       "      <td>0.953941</td>\n",
       "      <td>0.938037</td>\n",
       "      <td>0.886285</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9000</td>\n",
       "      <td>1.593800</td>\n",
       "      <td>0.363487</td>\n",
       "      <td>0.891280</td>\n",
       "      <td>0.958202</td>\n",
       "      <td>0.968670</td>\n",
       "      <td>0.950599</td>\n",
       "      <td>0.942041</td>\n",
       "      <td>0.891280</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9500</td>\n",
       "      <td>1.277900</td>\n",
       "      <td>0.367975</td>\n",
       "      <td>0.893374</td>\n",
       "      <td>0.951480</td>\n",
       "      <td>0.966503</td>\n",
       "      <td>0.959083</td>\n",
       "      <td>0.940314</td>\n",
       "      <td>0.893374</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10000</td>\n",
       "      <td>1.342700</td>\n",
       "      <td>0.319569</td>\n",
       "      <td>0.903181</td>\n",
       "      <td>0.959597</td>\n",
       "      <td>0.974216</td>\n",
       "      <td>0.960038</td>\n",
       "      <td>0.946558</td>\n",
       "      <td>0.903181</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10500</td>\n",
       "      <td>1.325600</td>\n",
       "      <td>0.306704</td>\n",
       "      <td>0.907258</td>\n",
       "      <td>0.964776</td>\n",
       "      <td>0.975832</td>\n",
       "      <td>0.964336</td>\n",
       "      <td>0.947734</td>\n",
       "      <td>0.907258</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11000</td>\n",
       "      <td>1.367900</td>\n",
       "      <td>0.311648</td>\n",
       "      <td>0.907037</td>\n",
       "      <td>0.962977</td>\n",
       "      <td>0.971975</td>\n",
       "      <td>0.960699</td>\n",
       "      <td>0.948028</td>\n",
       "      <td>0.907037</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11500</td>\n",
       "      <td>1.418300</td>\n",
       "      <td>0.301925</td>\n",
       "      <td>0.911114</td>\n",
       "      <td>0.961581</td>\n",
       "      <td>0.975832</td>\n",
       "      <td>0.963711</td>\n",
       "      <td>0.951297</td>\n",
       "      <td>0.911114</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12000</td>\n",
       "      <td>1.272700</td>\n",
       "      <td>0.288189</td>\n",
       "      <td>0.912400</td>\n",
       "      <td>0.965915</td>\n",
       "      <td>0.972159</td>\n",
       "      <td>0.963454</td>\n",
       "      <td>0.951554</td>\n",
       "      <td>0.912400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12500</td>\n",
       "      <td>1.323700</td>\n",
       "      <td>0.280599</td>\n",
       "      <td>0.916697</td>\n",
       "      <td>0.964336</td>\n",
       "      <td>0.973702</td>\n",
       "      <td>0.966209</td>\n",
       "      <td>0.954529</td>\n",
       "      <td>0.916697</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13000</td>\n",
       "      <td>1.057000</td>\n",
       "      <td>0.288260</td>\n",
       "      <td>0.915191</td>\n",
       "      <td>0.961360</td>\n",
       "      <td>0.977558</td>\n",
       "      <td>0.965988</td>\n",
       "      <td>0.953133</td>\n",
       "      <td>0.915191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13500</td>\n",
       "      <td>1.190500</td>\n",
       "      <td>0.272291</td>\n",
       "      <td>0.918020</td>\n",
       "      <td>0.967311</td>\n",
       "      <td>0.977632</td>\n",
       "      <td>0.966576</td>\n",
       "      <td>0.954529</td>\n",
       "      <td>0.918020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14000</td>\n",
       "      <td>1.215700</td>\n",
       "      <td>0.277464</td>\n",
       "      <td>0.916624</td>\n",
       "      <td>0.962756</td>\n",
       "      <td>0.978183</td>\n",
       "      <td>0.967531</td>\n",
       "      <td>0.954859</td>\n",
       "      <td>0.916624</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14500</td>\n",
       "      <td>1.289000</td>\n",
       "      <td>0.263562</td>\n",
       "      <td>0.922721</td>\n",
       "      <td>0.968670</td>\n",
       "      <td>0.977595</td>\n",
       "      <td>0.967494</td>\n",
       "      <td>0.957724</td>\n",
       "      <td>0.922721</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15000</td>\n",
       "      <td>1.150700</td>\n",
       "      <td>0.255239</td>\n",
       "      <td>0.922207</td>\n",
       "      <td>0.971094</td>\n",
       "      <td>0.980460</td>\n",
       "      <td>0.966539</td>\n",
       "      <td>0.956659</td>\n",
       "      <td>0.922207</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15500</td>\n",
       "      <td>1.251800</td>\n",
       "      <td>0.247099</td>\n",
       "      <td>0.926541</td>\n",
       "      <td>0.970543</td>\n",
       "      <td>0.979505</td>\n",
       "      <td>0.970176</td>\n",
       "      <td>0.959267</td>\n",
       "      <td>0.926541</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16000</td>\n",
       "      <td>1.141300</td>\n",
       "      <td>0.252704</td>\n",
       "      <td>0.924117</td>\n",
       "      <td>0.968853</td>\n",
       "      <td>0.979542</td>\n",
       "      <td>0.965878</td>\n",
       "      <td>0.958642</td>\n",
       "      <td>0.924117</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16500</td>\n",
       "      <td>1.129500</td>\n",
       "      <td>0.237749</td>\n",
       "      <td>0.926174</td>\n",
       "      <td>0.972306</td>\n",
       "      <td>0.982223</td>\n",
       "      <td>0.969037</td>\n",
       "      <td>0.959010</td>\n",
       "      <td>0.926174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17000</td>\n",
       "      <td>1.011900</td>\n",
       "      <td>0.243705</td>\n",
       "      <td>0.927055</td>\n",
       "      <td>0.967788</td>\n",
       "      <td>0.979321</td>\n",
       "      <td>0.972930</td>\n",
       "      <td>0.959157</td>\n",
       "      <td>0.927055</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17500</td>\n",
       "      <td>1.172900</td>\n",
       "      <td>0.227873</td>\n",
       "      <td>0.930838</td>\n",
       "      <td>0.971277</td>\n",
       "      <td>0.981452</td>\n",
       "      <td>0.972416</td>\n",
       "      <td>0.960663</td>\n",
       "      <td>0.930838</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18000</td>\n",
       "      <td>0.849800</td>\n",
       "      <td>0.209897</td>\n",
       "      <td>0.935650</td>\n",
       "      <td>0.976713</td>\n",
       "      <td>0.984170</td>\n",
       "      <td>0.974803</td>\n",
       "      <td>0.963785</td>\n",
       "      <td>0.935650</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18500</td>\n",
       "      <td>1.007900</td>\n",
       "      <td>0.209617</td>\n",
       "      <td>0.936605</td>\n",
       "      <td>0.975722</td>\n",
       "      <td>0.982590</td>\n",
       "      <td>0.972894</td>\n",
       "      <td>0.964372</td>\n",
       "      <td>0.936605</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19000</td>\n",
       "      <td>0.502600</td>\n",
       "      <td>0.180701</td>\n",
       "      <td>0.945089</td>\n",
       "      <td>0.980166</td>\n",
       "      <td>0.986079</td>\n",
       "      <td>0.978293</td>\n",
       "      <td>0.968890</td>\n",
       "      <td>0.945089</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19500</td>\n",
       "      <td>0.579200</td>\n",
       "      <td>0.168246</td>\n",
       "      <td>0.948762</td>\n",
       "      <td>0.979395</td>\n",
       "      <td>0.986777</td>\n",
       "      <td>0.979652</td>\n",
       "      <td>0.970690</td>\n",
       "      <td>0.948762</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20000</td>\n",
       "      <td>0.402200</td>\n",
       "      <td>0.161481</td>\n",
       "      <td>0.952325</td>\n",
       "      <td>0.982957</td>\n",
       "      <td>0.988540</td>\n",
       "      <td>0.981598</td>\n",
       "      <td>0.972783</td>\n",
       "      <td>0.952325</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20500</td>\n",
       "      <td>0.634800</td>\n",
       "      <td>0.149950</td>\n",
       "      <td>0.955300</td>\n",
       "      <td>0.982407</td>\n",
       "      <td>0.988357</td>\n",
       "      <td>0.982002</td>\n",
       "      <td>0.975134</td>\n",
       "      <td>0.955300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21000</td>\n",
       "      <td>0.561600</td>\n",
       "      <td>0.143034</td>\n",
       "      <td>0.957173</td>\n",
       "      <td>0.983325</td>\n",
       "      <td>0.988834</td>\n",
       "      <td>0.983215</td>\n",
       "      <td>0.976016</td>\n",
       "      <td>0.957173</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21500</td>\n",
       "      <td>0.443100</td>\n",
       "      <td>0.137321</td>\n",
       "      <td>0.959304</td>\n",
       "      <td>0.984390</td>\n",
       "      <td>0.989789</td>\n",
       "      <td>0.984427</td>\n",
       "      <td>0.977264</td>\n",
       "      <td>0.959304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22000</td>\n",
       "      <td>0.532000</td>\n",
       "      <td>0.133058</td>\n",
       "      <td>0.961507</td>\n",
       "      <td>0.985418</td>\n",
       "      <td>0.990414</td>\n",
       "      <td>0.983766</td>\n",
       "      <td>0.978550</td>\n",
       "      <td>0.961507</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22500</td>\n",
       "      <td>0.359300</td>\n",
       "      <td>0.122615</td>\n",
       "      <td>0.963748</td>\n",
       "      <td>0.986373</td>\n",
       "      <td>0.990230</td>\n",
       "      <td>0.985933</td>\n",
       "      <td>0.978954</td>\n",
       "      <td>0.963748</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23000</td>\n",
       "      <td>0.329600</td>\n",
       "      <td>0.117573</td>\n",
       "      <td>0.965952</td>\n",
       "      <td>0.987365</td>\n",
       "      <td>0.991111</td>\n",
       "      <td>0.986373</td>\n",
       "      <td>0.979872</td>\n",
       "      <td>0.965952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23500</td>\n",
       "      <td>0.625200</td>\n",
       "      <td>0.110675</td>\n",
       "      <td>0.967715</td>\n",
       "      <td>0.987989</td>\n",
       "      <td>0.991699</td>\n",
       "      <td>0.987108</td>\n",
       "      <td>0.980937</td>\n",
       "      <td>0.967715</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24000</td>\n",
       "      <td>0.488300</td>\n",
       "      <td>0.106646</td>\n",
       "      <td>0.969037</td>\n",
       "      <td>0.988797</td>\n",
       "      <td>0.992177</td>\n",
       "      <td>0.987806</td>\n",
       "      <td>0.981672</td>\n",
       "      <td>0.969037</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24500</td>\n",
       "      <td>0.391700</td>\n",
       "      <td>0.101932</td>\n",
       "      <td>0.970139</td>\n",
       "      <td>0.989128</td>\n",
       "      <td>0.992764</td>\n",
       "      <td>0.988173</td>\n",
       "      <td>0.982113</td>\n",
       "      <td>0.970139</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25000</td>\n",
       "      <td>0.415400</td>\n",
       "      <td>0.098072</td>\n",
       "      <td>0.971131</td>\n",
       "      <td>0.989275</td>\n",
       "      <td>0.992948</td>\n",
       "      <td>0.988393</td>\n",
       "      <td>0.982957</td>\n",
       "      <td>0.971131</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25500</td>\n",
       "      <td>0.457100</td>\n",
       "      <td>0.094781</td>\n",
       "      <td>0.972673</td>\n",
       "      <td>0.990046</td>\n",
       "      <td>0.993278</td>\n",
       "      <td>0.989091</td>\n",
       "      <td>0.983692</td>\n",
       "      <td>0.972673</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26000</td>\n",
       "      <td>0.451900</td>\n",
       "      <td>0.091811</td>\n",
       "      <td>0.973665</td>\n",
       "      <td>0.990671</td>\n",
       "      <td>0.993352</td>\n",
       "      <td>0.989569</td>\n",
       "      <td>0.984463</td>\n",
       "      <td>0.973665</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26500</td>\n",
       "      <td>0.484100</td>\n",
       "      <td>0.089931</td>\n",
       "      <td>0.973959</td>\n",
       "      <td>0.990560</td>\n",
       "      <td>0.993462</td>\n",
       "      <td>0.989569</td>\n",
       "      <td>0.984610</td>\n",
       "      <td>0.973959</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27000</td>\n",
       "      <td>0.512900</td>\n",
       "      <td>0.089034</td>\n",
       "      <td>0.974253</td>\n",
       "      <td>0.990781</td>\n",
       "      <td>0.993756</td>\n",
       "      <td>0.989642</td>\n",
       "      <td>0.984831</td>\n",
       "      <td>0.974253</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27500</td>\n",
       "      <td>0.644400</td>\n",
       "      <td>0.088684</td>\n",
       "      <td>0.974106</td>\n",
       "      <td>0.990781</td>\n",
       "      <td>0.993683</td>\n",
       "      <td>0.989532</td>\n",
       "      <td>0.984941</td>\n",
       "      <td>0.974106</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28000</td>\n",
       "      <td>0.512300</td>\n",
       "      <td>0.088663</td>\n",
       "      <td>0.974253</td>\n",
       "      <td>0.990781</td>\n",
       "      <td>0.993683</td>\n",
       "      <td>0.989679</td>\n",
       "      <td>0.984941</td>\n",
       "      <td>0.974253</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/home/hli962/.virtualenvs/server/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:929: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Eval on chosen dataset ===\n",
      "eval_loss: 0.0890\n",
      "eval_acc_16: 0.9743\n",
      "eval_acc_ei: 0.9908\n",
      "eval_acc_ns: 0.9938\n",
      "eval_acc_tf: 0.9896\n",
      "eval_acc_jp: 0.9848\n",
      "eval_acc_4D: 0.9743\n",
      "eval_runtime: 542.2758\n",
      "eval_samples_per_second: 50.2070\n",
      "eval_steps_per_second: 6.2770\n",
      "epoch: 3.0000\n",
      "\n",
      "=== Test accuracy on chosen dataset: 0.8892\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'rows_B_all' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 437\u001b[0m\n\u001b[1;32m    434\u001b[0m         \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m样例原标签:\u001b[39m\u001b[38;5;124m\"\u001b[39m, sample[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m| 预测:\u001b[39m\u001b[38;5;124m\"\u001b[39m, MBTI_16[pred_id])\n\u001b[1;32m    436\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;18m__name__\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m__main__\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 437\u001b[0m     \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[1], line 427\u001b[0m, in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m    425\u001b[0m \u001b[38;5;66;03m# ========= 示例推理 =========\u001b[39;00m\n\u001b[1;32m    426\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m--> 427\u001b[0m sample \u001b[38;5;241m=\u001b[39m (rows_B_all[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mrows_B_all\u001b[49m \u001b[38;5;28;01melse\u001b[39;00m rows_A[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m    428\u001b[0m text \u001b[38;5;241m=\u001b[39m build_input(sample, tokenizer)\n\u001b[1;32m    429\u001b[0m batch \u001b[38;5;241m=\u001b[39m tokenizer(text, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m, truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, max_length\u001b[38;5;241m=\u001b[39mMAX_LEN)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'rows_B_all' is not defined"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "训练 + 评测（LoRA / 4bit / 单卡）\n",
    "- 基座: Qwen/Qwen2.5-1.5B-Instruct\n",
    "- 训练集: A ∪ B_train（分层8:2后B的80%）\n",
    "- Eval/Test: eval 在 B_train，test 在 B_test（B的20%）\n",
    "- 输出: 指标 + 混淆矩阵 + ROC(micro/macro) + LoRA 适配器权重\n",
    "\"\"\"\n",
    "import os, json, random\n",
    "from typing import Dict, Any, List\n",
    "from collections import defaultdict\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoConfig,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "from peft import (\n",
    "    LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel\n",
    ")\n",
    "\n",
    "# ================== 配置 ==================\n",
    "BASE_MODEL   = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "DATA_A       = \"mbti_sample_with_all_views.json\"          # 旧数据（A）\n",
    "DATA_B       = \"mbti_sample_with_all_views_pandora.json\"  # Pandora（B）\n",
    "EVAL_ON      = \"A\"  # 只在哪个集上做 eval/test: \"A\" or \"B\"\n",
    "OUTPUT_DIR   = \"qwen-test-on-pandora_new\"                     # 输出目录（含 LoRA）\n",
    "RESUME_ADAPTER_DIR = None  # 若已有 LoRA 断点，可填入目录；否则置为 None\n",
    "\n",
    "\n",
    "MAX_LEN      = 320\n",
    "USE_4BIT     = True\n",
    "SEED         = 42\n",
    "NUM_LABELS   = 16\n",
    "\n",
    "# LoRA 超参（可按需微调）\n",
    "LORA_R       = 16\n",
    "LORA_ALPHA   = 32\n",
    "LORA_DROPOUT = 0.05\n",
    "# Qwen2.5 常用目标模块\n",
    "LORA_TARGET_MODULES = [\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"]\n",
    "\n",
    "# 训练超参（按你的显存情况调整）\n",
    "BATCH_SIZE_PER_DEVICE_TRAIN = 8\n",
    "BATCH_SIZE_PER_DEVICE_EVAL  = 8\n",
    "GR_ACCUM_STEPS              = 1\n",
    "EPOCHS                      = 3\n",
    "LR                          = 2e-4\n",
    "WARMUP_RATIO                = 0.05\n",
    "LOGGING_STEPS               = 20\n",
    "SAVE_STEPS                  = 500\n",
    "EVAL_STEPS                  = 500\n",
    "\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "# 与训练一致的多视角 token 预算\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# ================== 工具函数 ==================\n",
    "def mbti_to_4d(m: str):\n",
    "    m = m.upper()\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    enc = tok(text or \"\", add_special_tokens=False)\n",
    "    ids = enc[\"input_ids\"][: budget]\n",
    "    return tok.decode(ids)\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    p_raw = item.get(\"posts_cleaned\", item.get(\"posts\",\"\")) or item.get(\"text\",\"\") or \"\"\n",
    "    sem   = item.get(\"semantic_view\",\"\")  or \"\"\n",
    "    sen   = item.get(\"sentiment_view\",\"\") or \"\"\n",
    "    lin   = item.get(\"linguistic_view\",\"\") or \"\"\n",
    "\n",
    "    p   = truncate_to_budget(tok, p_raw, BUDGET[\"posts_cleaned\"])\n",
    "    sem = truncate_to_budget(tok, sem,   BUDGET[\"semantic_view\"])\n",
    "    sen = truncate_to_budget(tok, sen,   BUDGET[\"sentiment_view\"])\n",
    "    lin = truncate_to_budget(tok, lin,   BUDGET[\"linguistic_view\"])\n",
    "\n",
    "    return (\n",
    "        f\"[POSTS]\\n{p}\\n[SEMANTIC]\\n{sem}\\n[SENTIMENT]\\n{sen}\\n[LINGUISTIC]\\n{lin}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "\n",
    "def load_rows(path: str):\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    clean = []\n",
    "    for r in rows:\n",
    "        t = (r.get(\"type\") or r.get(\"label\") or \"\").upper().strip()\n",
    "        if t in MBTI2ID:\n",
    "            r[\"type\"] = t\n",
    "            clean.append(r)\n",
    "    return clean\n",
    "\n",
    "def stratified_split_by_type(rows, ratio=0.8, seed=42):\n",
    "    \"\"\"按 16 类型分层切分 rows -> (train_part, test_part)\"\"\"\n",
    "    buckets = defaultdict(list)\n",
    "    for r in rows:\n",
    "        buckets[r[\"type\"]].append(r)\n",
    "\n",
    "    rng = random.Random(seed)\n",
    "    train, test = [], []\n",
    "    for t, lst in buckets.items():\n",
    "        rng.shuffle(lst)\n",
    "        n = len(lst)\n",
    "        if n <= 1:\n",
    "            train.extend(lst)              # 极小类：全进训练\n",
    "            continue\n",
    "        cut = int(round(n * ratio))\n",
    "        cut = min(max(1, cut), n - 1)      # 保证两边都有样本\n",
    "        train.extend(lst[:cut])\n",
    "        test.extend(lst[cut:])\n",
    "    return train, test\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok  = tokenizer\n",
    "        self.max_len = max_len\n",
    "    def __len__(self): return len(self.rows)\n",
    "    def __getitem__(self, idx):\n",
    "        it   = self.rows[idx]\n",
    "        text = build_input(it, self.tok)\n",
    "        y    = MBTI2ID[it[\"type\"]]\n",
    "        enc  = self.tok(text, truncation=True, max_length=self.max_len)\n",
    "        return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"], \"labels\": y}\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    preds, labels = (eval_pred if isinstance(eval_pred, tuple)\n",
    "                     else (eval_pred.predictions, eval_pred.label_ids))\n",
    "    if isinstance(preds, (list, tuple)): preds = preds[0]\n",
    "    preds = np.asarray(preds); labels = np.asarray(labels)\n",
    "    pred_ids = preds.argmax(-1)\n",
    "    acc16 = float((pred_ids == labels).mean())\n",
    "\n",
    "    pred_types = [MBTI_16[i] for i in pred_ids]\n",
    "    true_types = [MBTI_16[i] for i in labels]\n",
    "    c_ei=c_ns=c_tf=c_jp=c_all=0\n",
    "    for pt, tt in zip(pred_types, true_types):\n",
    "        pei,pns,ptf,pjp = mbti_to_4d(pt)\n",
    "        tei,tns,ttf,tjp = mbti_to_4d(tt)\n",
    "        c_ei += (pei==tei); c_ns += (pns==tns); c_tf += (ptf==ttf); c_jp += (pjp==tjp)\n",
    "        c_all+= (pei==tei and pns==tns and ptf==ttf and pjp==tjp)\n",
    "    n = len(labels)\n",
    "    return {\"acc_16\": acc16, \"acc_ei\": c_ei/n, \"acc_ns\": c_ns/n, \"acc_tf\": c_tf/n, \"acc_jp\": c_jp/n, \"acc_4D\": c_all/n}\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, tag=\"eval\"):\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix ({tag})\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"{tag}_confusion_matrix.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ROC：跳过评测集中没有正样本的类\n",
    "    Y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    valid = []\n",
    "    for i in range(len(class_names)):\n",
    "        if Y_true_bin[:, i].sum() == 0:\n",
    "            continue\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "        valid.append(i)\n",
    "    if len(valid) >= 2:\n",
    "        fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(\n",
    "            Y_true_bin[:, valid].ravel(), y_prob[:, valid].ravel()\n",
    "        )\n",
    "        roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in valid]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in valid:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(valid)\n",
    "        fpr[\"macro\"] = all_fpr; tpr[\"macro\"] = mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "        fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "        ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                    label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                    label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "        ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "        ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "        ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "        ax_roc.set_title(f\"Multiclass ROC ({tag})\")\n",
    "        ax_roc.legend(loc=\"lower right\")\n",
    "        fig_roc.tight_layout()\n",
    "        fig_roc.savefig(os.path.join(out_dir, f\"{tag}_roc_micro_macro.png\"))\n",
    "        plt.close(fig_roc)\n",
    "\n",
    "# ================== 主流程 ==================\n",
    "def main():\n",
    "    # 环境 & 种子\n",
    "    os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "    os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
    "    torch.cuda.set_device(0)\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    # tokenizer\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\n",
    "        BASE_MODEL, use_fast=True, trust_remote_code=True, **HF_KW\n",
    "    )\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # 量化\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    # 分类头：num_labels=16\n",
    "    base_cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True, **HF_KW)\n",
    "    base_cfg.num_labels = NUM_LABELS\n",
    "\n",
    "    # 基座\n",
    "    base = AutoModelForSequenceClassification.from_pretrained(\n",
    "        BASE_MODEL,\n",
    "        config=base_cfg,\n",
    "        device_map={\"\": \"cuda:0\"},\n",
    "        quantization_config=quant_cfg,\n",
    "        trust_remote_code=True,\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW,\n",
    "    )\n",
    "\n",
    "    # ========= LoRA：新训或续训 =========\n",
    "    if RESUME_ADAPTER_DIR:\n",
    "        model = PeftModel.from_pretrained(base, RESUME_ADAPTER_DIR, is_trainable=True)\n",
    "    else:\n",
    "        base = prepare_model_for_kbit_training(base)  # 4bit 可训练准备\n",
    "        lora_cfg = LoraConfig(\n",
    "            r=LORA_R,\n",
    "            lora_alpha=LORA_ALPHA,\n",
    "            target_modules=LORA_TARGET_MODULES,\n",
    "            lora_dropout=LORA_DROPOUT,\n",
    "            bias=\"none\",\n",
    "            task_type=\"SEQ_CLS\",\n",
    "        )\n",
    "        model = get_peft_model(base, lora_cfg)\n",
    "\n",
    "    model.config.use_cache = False\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.print_trainable_parameters()\n",
    "\n",
    "    # ========= 数据（分层切分 + 落盘，按 EVAL_ON 选择 A 或 B）=========\n",
    "    rows_A = load_rows(DATA_A)\n",
    "    rows_B = load_rows(DATA_B)\n",
    "\n",
    "    def stratified_split_by_type(rows, ratio=0.8, seed=SEED):\n",
    "        from collections import defaultdict\n",
    "        rng = random.Random(seed)\n",
    "        buckets = defaultdict(list)\n",
    "        for r in rows:\n",
    "            buckets[r[\"type\"]].append(r)\n",
    "        train, test = [], []\n",
    "        for t, lst in buckets.items():\n",
    "            rng.shuffle(lst)\n",
    "            n = len(lst)\n",
    "            if n <= 1:\n",
    "                train.extend(lst)\n",
    "                continue\n",
    "            cut = int(round(n * ratio))\n",
    "            cut = min(max(1, cut), n - 1)  # 保证两边都有样本\n",
    "            train.extend(lst[:cut])\n",
    "            test.extend(lst[cut:])\n",
    "        return train, test\n",
    "\n",
    "    if EVAL_ON.upper() == \"A\":\n",
    "        chosen_all   = rows_A     # 只对 A 做分层切分并用于 eval/test\n",
    "        other_all    = rows_B     # B 整集全部进训练\n",
    "        split_tag    = \"A\"\n",
    "    else:\n",
    "        chosen_all   = rows_B     # 只对 B 做分层切分并用于 eval/test\n",
    "        other_all    = rows_A     # A 整集全部进训练\n",
    "        split_tag    = \"B\"\n",
    "\n",
    "    split_dir = os.path.join(OUTPUT_DIR, f\"splits_stratified_{split_tag}\")\n",
    "    os.makedirs(split_dir, exist_ok=True)\n",
    "    train_path = os.path.join(split_dir, f\"{split_tag}_train_80.json\")\n",
    "    test_path  = os.path.join(split_dir, f\"{split_tag}_test_20.json\")\n",
    "\n",
    "    if os.path.exists(train_path) and os.path.exists(test_path):\n",
    "        chosen_train = json.load(open(train_path, \"r\", encoding=\"utf-8\"))\n",
    "        chosen_test  = json.load(open(test_path,  \"r\", encoding=\"utf-8\"))\n",
    "        print(f\"[Split-{split_tag}] Loaded existing: train={len(chosen_train)}, test={len(chosen_test)}\")\n",
    "    else:\n",
    "        chosen_train, chosen_test = stratified_split_by_type(chosen_all, ratio=0.8, seed=SEED)\n",
    "        json.dump(chosen_train, open(train_path, \"w\", encoding=\"utf-8\"), ensure_ascii=False, indent=2)\n",
    "        json.dump(chosen_test,  open(test_path,  \"w\", encoding=\"utf-8\"), ensure_ascii=False, indent=2)\n",
    "        print(f\"[Split-{split_tag}] Saved: train={len(chosen_train)}, test={len(chosen_test)}\")\n",
    "\n",
    "    # 训练集 = 未选中整集 + 选中集的 80%\n",
    "    train_rows: List[Dict[str, Any]] = other_all + chosen_train\n",
    "    random.Random(SEED).shuffle(train_rows)\n",
    "\n",
    "    # Eval/Test 只用选中的那个集（保持与你预期一致）\n",
    "    eval_rows = chosen_train\n",
    "    test_rows = chosen_test\n",
    "    eval_tag  = f\"{split_tag}_eval_stratified\"\n",
    "\n",
    "    # 构建数据集\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    eval_ds  = MBTIDataset(eval_rows,  tokenizer, max_len=MAX_LEN)\n",
    "    test_ds  = MBTIDataset(test_rows,  tokenizer, max_len=MAX_LEN) if test_rows else None\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # ========= 训练参数 =========\n",
    "    common_kwargs = dict(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BATCH_SIZE_PER_DEVICE_TRAIN,\n",
    "        per_device_eval_batch_size=BATCH_SIZE_PER_DEVICE_EVAL,\n",
    "        gradient_accumulation_steps=GR_ACCUM_STEPS,\n",
    "        learning_rate=LR,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        logging_steps=LOGGING_STEPS,\n",
    "        eval_steps=EVAL_STEPS,\n",
    "        save_steps=SAVE_STEPS,\n",
    "        save_total_limit=2,\n",
    "        lr_scheduler_type=\"cosine\",\n",
    "        report_to=\"none\",\n",
    "        fp16=False, bf16=False,\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_16\",\n",
    "        greater_is_better=True,\n",
    "        logging_strategy=\"steps\",\n",
    "        save_strategy=\"steps\",\n",
    "    )\n",
    "    try:\n",
    "        args = TrainingArguments(eval_strategy=\"steps\", **common_kwargs)\n",
    "    except TypeError:\n",
    "        try:\n",
    "            args = TrainingArguments(evaluation_strategy=\"steps\", **common_kwargs)\n",
    "        except TypeError:\n",
    "            args = TrainingArguments(evaluate_during_training=True, **common_kwargs)\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=eval_ds,\n",
    "        tokenizer=tokenizer,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    # ========= 训练 =========\n",
    "    trainer.train()\n",
    "\n",
    "    # 保存 LoRA（适配器）\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    try:\n",
    "        model.save_pretrained(os.path.join(OUTPUT_DIR, \"lora_adapter\"))\n",
    "    except Exception as e:\n",
    "        print(\"Save adapter failed:\", e)\n",
    "\n",
    "    # ========= Eval（B_train）=========\n",
    "    eval_output = trainer.predict(eval_ds)\n",
    "    logits = eval_output.predictions\n",
    "    if isinstance(logits, (list, tuple)):\n",
    "        logits = logits[0]\n",
    "    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = eval_output.label_ids\n",
    "    plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, tag=f\"{eval_tag}\")\n",
    "\n",
    "    metrics = trainer.evaluate(eval_dataset=eval_ds)\n",
    "    print(\"\\n=== Eval on chosen dataset ===\")\n",
    "    for k, v in metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except Exception:\n",
    "            print(k, v)\n",
    "\n",
    "    # ========= Test（B_test）=========\n",
    "    if test_ds and len(test_ds) > 0:\n",
    "        test_output = trainer.predict(test_ds)\n",
    "        logits = test_output.predictions\n",
    "        if isinstance(logits, (list, tuple)):\n",
    "            logits = logits[0]\n",
    "        probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "        y_true = test_output.label_ids\n",
    "        plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, tag=f\"{eval_tag}_test\")\n",
    "\n",
    "        pred_ids = probs.argmax(-1)\n",
    "        acc = float((pred_ids == y_true).mean())\n",
    "        print(f\"\\n=== Test accuracy on chosen dataset: {acc:.4f}\")\n",
    "\n",
    "    # ========= 示例推理 =========\n",
    "    model.eval()\n",
    "    sample = (rows_B_all[0] if rows_B_all else rows_A[0])\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(\"cuda:0\") for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        out = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(out, dim=-1))\n",
    "        print(\"样例原标签:\", sample[\"type\"], \"| 预测:\", MBTI_16[pred_id])\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "server",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
