{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcb6bc2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ 已保存：train.json（27224 条）\n",
      "✅ 已保存：val.json（3404 条）\n",
      "✅ 已保存：test.json（3404 条）\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "将 mbti_sample_with_all_views.json 分层切分为 8:1:1 的 train/val/test\n",
    "随机种子固定为 42，确保可复现\n",
    "\"\"\"\n",
    "import json\n",
    "from pathlib import Path\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# The 16 MBTI type codes; a code's position in this list is its integer class id.\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "# Reverse lookup: type code -> index in MBTI_16 (also used as the set of valid labels).\n",
    "MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "def load_rows(path: Path):\n",
    "    \"\"\"Read a JSON file of samples and keep only entries with a known MBTI label.\n",
    "\n",
    "    Args:\n",
    "        path: Location of a UTF-8 encoded JSON file whose top level is iterated.\n",
    "\n",
    "    Returns:\n",
    "        List of the dict entries whose \"type\" value is a key of MBTI2ID.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no entry survives the filter.\n",
    "    \"\"\"\n",
    "    with path.open(\"r\", encoding=\"utf-8\") as handle:\n",
    "        payload = json.load(handle)\n",
    "    kept = []\n",
    "    for entry in payload:\n",
    "        # Anything that is not a dict cannot carry a \"type\" field; drop it.\n",
    "        if isinstance(entry, dict) and entry.get(\"type\") in MBTI2ID:\n",
    "            kept.append(entry)\n",
    "    if kept:\n",
    "        return kept\n",
    "    raise ValueError(\"输入文件中没有合法样本（缺少 'type' 或不在 16 类里）。\")\n",
    "\n",
    "def save_json(rows, path: Path):\n",
    "    path.parent.mkdir(parents=True, exist_ok=True)\n",
    "    with path.open(\"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(rows, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "def main(\n",
    "    input_file=\"mbti_sample_with_all_views.json\",\n",
    "    outdir=\".\",\n",
    "    seed=42\n",
    "):\n",
    "    \"\"\"Split the labelled samples into stratified train/val/test files (roughly 8:1:1).\n",
    "\n",
    "    Args:\n",
    "        input_file: JSON file handed to load_rows.\n",
    "        outdir: Directory that receives train.json, val.json and test.json.\n",
    "        seed: random_state used for both train_test_split calls.\n",
    "    \"\"\"\n",
    "    out = Path(outdir)\n",
    "    samples = load_rows(Path(input_file))\n",
    "\n",
    "    def type_of(sample):\n",
    "        return sample[\"type\"]\n",
    "\n",
    "    # Stage 1: hold out 10% of everything as TEST, stratified by MBTI type.\n",
    "    remainder, test_rows = train_test_split(\n",
    "        samples,\n",
    "        test_size=0.10,\n",
    "        random_state=seed,\n",
    "        stratify=list(map(type_of, samples)),\n",
    "    )\n",
    "    # Stage 2: take 0.1 / 0.9 of the remaining 90% as VAL, i.e. about 10% of the total.\n",
    "    train_rows, val_rows = train_test_split(\n",
    "        remainder,\n",
    "        test_size=0.1111111111,\n",
    "        random_state=seed,\n",
    "        stratify=list(map(type_of, remainder)),\n",
    "    )\n",
    "\n",
    "    for file_name, part in ((\"train.json\", train_rows), (\"val.json\", val_rows), (\"test.json\", test_rows)):\n",
    "        save_json(part, out / file_name)\n",
    "\n",
    "    print(f\"✅ 已保存：{out/'train.json'}（{len(train_rows)} 条）\")\n",
    "    print(f\"✅ 已保存：{out/'val.json'}（{len(val_rows)} 条）\")\n",
    "    print(f\"✅ 已保存：{out/'test.json'}（{len(test_rows)} 条）\")\n",
    "\n",
    "# Entry point: run the split with its default arguments when __name__ is \"__main__\".\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2eb74ba4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_2661063/3024490461.py:261: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='13612' max='13612' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [13612/13612 2:56:46, Epoch 4/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Acc 16</th>\n",
       "      <th>Acc Ei</th>\n",
       "      <th>Acc Ns</th>\n",
       "      <th>Acc Tf</th>\n",
       "      <th>Acc Jp</th>\n",
       "      <th>Acc 4d</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.386700</td>\n",
       "      <td>0.383969</td>\n",
       "      <td>0.895417</td>\n",
       "      <td>0.954465</td>\n",
       "      <td>0.977086</td>\n",
       "      <td>0.953878</td>\n",
       "      <td>0.943890</td>\n",
       "      <td>0.895417</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.379200</td>\n",
       "      <td>0.354813</td>\n",
       "      <td>0.896298</td>\n",
       "      <td>0.956816</td>\n",
       "      <td>0.977086</td>\n",
       "      <td>0.953290</td>\n",
       "      <td>0.943302</td>\n",
       "      <td>0.896298</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.147200</td>\n",
       "      <td>0.580361</td>\n",
       "      <td>0.900411</td>\n",
       "      <td>0.962103</td>\n",
       "      <td>0.977086</td>\n",
       "      <td>0.959459</td>\n",
       "      <td>0.944477</td>\n",
       "      <td>0.900411</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.027100</td>\n",
       "      <td>0.882224</td>\n",
       "      <td>0.897767</td>\n",
       "      <td>0.956522</td>\n",
       "      <td>0.972973</td>\n",
       "      <td>0.958578</td>\n",
       "      <td>0.945358</td>\n",
       "      <td>0.897767</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: mbti_lora_qwen1.5b-split_kaggle_ckpt_new/confusion_matrix.png\n",
      "Saved: mbti_lora_qwen1.5b-split_kaggle_ckpt_new/roc_micro_macro.png\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Final Eval (on VAL) ===\n",
      "eval_loss: 0.5804\n",
      "eval_acc_16: 0.9004\n",
      "eval_acc_ei: 0.9621\n",
      "eval_acc_ns: 0.9771\n",
      "eval_acc_tf: 0.9595\n",
      "eval_acc_jp: 0.9445\n",
      "eval_acc_4D: 0.9004\n",
      "eval_runtime: 91.6887\n",
      "eval_samples_per_second: 37.1260\n",
      "eval_steps_per_second: 9.2810\n",
      "epoch: 4.0000\n",
      "Saved: mbti_lora_qwen1.5b-split_kaggle_ckpt_new/confusion_matrix_test.png\n",
      "Saved: mbti_lora_qwen1.5b-split_kaggle_ckpt_new/roc_micro_macro_test.png\n",
      "\n",
      "=== Final Test (held-out TEST) ===\n",
      "acc_16: 0.6589\n",
      "acc_ei: 0.8570  acc_ns: 0.9109  acc_tf: 0.8593  acc_jp: 0.8042  acc_4D: 0.6589\n",
      "\n",
      "✅ LoRA adapter saved to: mbti_lora_qwen1.5b-split_kaggle_ckpt_new\n",
      "\n",
      "[Inference on TEST sample]\n",
      "原标签: INFP  | 预测: INFP\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Qwen2.5-1.5B-Instruct + QLoRA(4bit) + LoRA\n",
    "读取 train/val/test.json 进行训练与评测（VAL & TEST）\n",
    "Transformers==4.55\n",
    "\"\"\"\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from sklearn.metrics import (\n",
    "    confusion_matrix, ConfusionMatrixDisplay,\n",
    "    roc_curve, auc\n",
    ")\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# Environment switches: turn off accelerate mixed precision and the bitsandbytes welcome banner.\n",
    "os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
    "\n",
    "# ============ Configuration ============\n",
    "MODEL_NAME   = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "# Maximum tokenized length per sample, and the per-field token budgets applied by build_input.\n",
    "# NOTE(review): MAX_LEN equals the sum of the BUDGET values (320 + 64 + 32 + 24), so the section\n",
    "# markers and the trailing [TASK] sentence added by build_input may be cut off by truncation\n",
    "# when every field fills its budget - confirm this is intended.\n",
    "MAX_LEN      = 440\n",
    "BUDGET = {\"posts_cleaned\": 320, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "\n",
    "# Training hyper-parameters (passed to set_seed / TrainingArguments in main).\n",
    "SEED         = 42\n",
    "EPOCHS       = 4\n",
    "LR           = 2e-4\n",
    "BSZ_TRN      = 8\n",
    "BSZ_EVAL     = 4\n",
    "GRAD_ACCUM   = 1\n",
    "WARMUP_RATIO = 0.06\n",
    "WEIGHT_DECAY = 0.01\n",
    "# Directory for checkpoints, figures and the saved model.\n",
    "OUTPUT_DIR   = \"mbti_lora_qwen1.5b-split_kaggle_ckpt_new\"\n",
    "\n",
    "# 4-bit quantization switch and LoRA settings.\n",
    "USE_4BIT     = True\n",
    "LORA_R       = 16\n",
    "LORA_ALPHA   = 32\n",
    "LORA_DROPOUT = 0.05\n",
    "TARGET_MODULES = [\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"]\n",
    "\n",
    "# Optional Hugging Face token from the environment; forwarded to from_pretrained only when set.\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# The 16 MBTI type codes; list position is the class id used as the training label.\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "# ============ 基础函数 ============\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load samples from a JSON file, keeping only dict entries with a known MBTI label.\n",
    "\n",
    "    Args:\n",
    "        path: Path of a UTF-8 encoded JSON file whose top level is iterated.\n",
    "\n",
    "    Returns:\n",
    "        List of the entries whose \"type\" value is a key of MBTI2ID.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no entry passes the filter.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        parsed = json.load(fh)\n",
    "    valid = []\n",
    "    for rec in parsed:\n",
    "        if isinstance(rec, dict) and rec.get(\"type\") in MBTI2ID:\n",
    "            valid.append(rec)\n",
    "    if valid:\n",
    "        return valid\n",
    "    raise ValueError(f\"{path} 中没有合法样本。\")\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Cut text down to its first `budget` tokens and return the decoded string.\n",
    "\n",
    "    Args:\n",
    "        tok: Tokenizer; called without special tokens and then used to decode.\n",
    "        text: Input text; a falsy value is treated as the empty string.\n",
    "        budget: Maximum number of token ids to keep.\n",
    "\n",
    "    Returns:\n",
    "        The text obtained by decoding the kept token ids.\n",
    "    \"\"\"\n",
    "    token_ids = tok(text if text else \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Assemble the classification prompt for one sample.\n",
    "\n",
    "    Each field is truncated to its BUDGET entry before being placed under its\n",
    "    section marker; the prompt ends with a task sentence listing MBTI_16.\n",
    "\n",
    "    Args:\n",
    "        item: Sample dict; \"posts_cleaned\" falls back to \"posts\" when the key is absent.\n",
    "        tok: Tokenizer forwarded to truncate_to_budget.\n",
    "\n",
    "    Returns:\n",
    "        The prompt string.\n",
    "    \"\"\"\n",
    "    def clip(key: str, default: str = \"\") -> str:\n",
    "        # Missing or falsy values become an empty section.\n",
    "        return truncate_to_budget(tok, item.get(key, default) or \"\", BUDGET[key])\n",
    "\n",
    "    p = clip(\"posts_cleaned\", item.get(\"posts\", \"\"))\n",
    "    sem = clip(\"semantic_view\")\n",
    "    sen = clip(\"sentiment_view\")\n",
    "    lin = clip(\"linguistic_view\")\n",
    "    prompt = (\n",
    "        f\"[POSTS]\\n{p}\\n[SEMANTIC]\\n{sem}\\n[SENTIMENT]\\n{sen}\\n[LINGUISTIC]\\n{lin}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "    return prompt\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok  = tokenizer\n",
    "        self.max_len = max_len\n",
    "    def __len__(self): return len(self.rows)\n",
    "    def __getitem__(self, idx):\n",
    "        it  = self.rows[idx]\n",
    "        text= build_input(it, self.tok)\n",
    "        y   = MBTI2ID[it[\"type\"]]\n",
    "        enc = self.tok(text, truncation=True, max_length=self.max_len)\n",
    "        return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"], \"labels\": y}\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute 16-way accuracy and per-dichotomy accuracies from predictions and labels.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: Either a (predictions, labels) tuple or an object exposing\n",
    "            .predictions and .label_ids. Predictions wrapped in a list/tuple are\n",
    "            unwrapped by taking the first element.\n",
    "\n",
    "    Returns:\n",
    "        Dict with acc_16, acc_ei, acc_ns, acc_tf, acc_jp and acc_4D, where acc_4D\n",
    "        counts samples whose four dichotomies are all predicted correctly.\n",
    "    \"\"\"\n",
    "    if isinstance(eval_pred, tuple):\n",
    "        scores, gold = eval_pred\n",
    "    else:\n",
    "        scores, gold = eval_pred.predictions, eval_pred.label_ids\n",
    "    if isinstance(scores, (list, tuple)):\n",
    "        scores = scores[0]\n",
    "    scores = scores if isinstance(scores, np.ndarray) else np.asarray(scores)\n",
    "    gold = gold if isinstance(gold, np.ndarray) else np.asarray(gold)\n",
    "\n",
    "    guess = scores.argmax(-1)\n",
    "    acc16 = float((guess == gold).mean())\n",
    "\n",
    "    # Translate class ids to type codes first, then compare letter by letter.\n",
    "    guess_types = [MBTI_16[i] for i in guess]\n",
    "    gold_types = [MBTI_16[i] for i in gold]\n",
    "    hits = [0, 0, 0, 0]  # correct counts for E/I, N/S, T/F, J/P\n",
    "    exact = 0\n",
    "    for g_type, t_type in zip(guess_types, gold_types):\n",
    "        same = tuple(a == b for a, b in zip(mbti_to_4d(g_type), mbti_to_4d(t_type)))\n",
    "        for axis, ok in enumerate(same):\n",
    "            hits[axis] += ok\n",
    "        exact += all(same)\n",
    "    n = len(gold)\n",
    "    return {\n",
    "        \"acc_16\": acc16,\n",
    "        \"acc_ei\": hits[0] / n,\n",
    "        \"acc_ns\": hits[1] / n,\n",
    "        \"acc_tf\": hits[2] / n,\n",
    "        \"acc_jp\": hits[3] / n,\n",
    "        \"acc_4D\": exact / n,\n",
    "    }\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=\"\"):\n",
    "    \"\"\"Save a confusion-matrix figure and a micro/macro-average ROC figure as PNG files.\n",
    "\n",
    "    Args:\n",
    "        y_true: Integer class ids of the reference labels.\n",
    "        y_prob: Per-class scores, indexed as y_prob[:, class_id].\n",
    "        class_names: Display names; their count defines the number of classes.\n",
    "        out_dir: Output directory (created when missing).\n",
    "        suffix: Text appended to both figure titles and both file names.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    n_cls = len(class_names)\n",
    "    class_ids = list(range(n_cls))\n",
    "\n",
    "    # ---- Confusion matrix of arg-max predictions ----\n",
    "    predicted = np.argmax(y_prob, axis=-1)\n",
    "    matrix = confusion_matrix(y_true, predicted, labels=class_ids)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=class_names).plot(\n",
    "        ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False\n",
    "    )\n",
    "    ax_cm.set_title(f\"Confusion Matrix{suffix}\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"confusion_matrix{suffix}.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ---- One-vs-rest ROC curves ----\n",
    "    onehot = label_binarize(y_true, classes=class_ids)\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    for cid in class_ids:\n",
    "        fpr[cid], tpr[cid], _ = roc_curve(onehot[:, cid], y_prob[:, cid])\n",
    "        roc_auc[cid] = auc(fpr[cid], tpr[cid])\n",
    "\n",
    "    # Micro average: pool every (sample, class) decision into one curve.\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(onehot.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    # Macro average: interpolate each class curve on the union of FPR points, then average.\n",
    "    grid = np.unique(np.concatenate([fpr[cid] for cid in class_ids]))\n",
    "    averaged = np.zeros_like(grid)\n",
    "    for cid in class_ids:\n",
    "        averaged += np.interp(grid, fpr[cid], tpr[cid])\n",
    "    averaged /= n_cls\n",
    "    fpr[\"macro\"], tpr[\"macro\"] = grid, averaged\n",
    "    roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)  # chance diagonal\n",
    "    ax_roc.set_xlim([0.0, 1.0])\n",
    "    ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\")\n",
    "    ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(f\"Multiclass ROC (micro & macro){suffix}\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, f\"roc_micro_macro{suffix}.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "def main():\n",
    "    \"\"\"Fine-tune the sequence classifier with LoRA and report metrics on VAL and TEST.\n",
    "\n",
    "    Steps: load the split files, build the (optionally 4-bit quantized) base model with\n",
    "    LoRA adapters on cuda:0, train with the Hugging Face Trainer, write confusion-matrix\n",
    "    and ROC figures for both evaluation splits into OUTPUT_DIR, save the model via\n",
    "    trainer.save_model, and finally classify one TEST sample as a smoke test.\n",
    "    \"\"\"\n",
    "    torch.cuda.set_device(0)\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    # Load the pre-split data files.\n",
    "    # NOTE(review): TEST is read from a different file than the test.json written by the\n",
    "    # split cell - confirm this is the intended held-out set.\n",
    "    train_rows = load_rows(\"train.json\")\n",
    "    val_rows   = load_rows(\"val.json\")\n",
    "    test_rows  = load_rows(\"test对应的原始数据.json\")\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        # Padding is required by the collator; fall back to EOS when no pad token exists.\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    model_kwargs = dict(\n",
    "        num_labels=16,\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": \"cuda:0\"},\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW,\n",
    "    )\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, **model_kwargs)\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.config.use_cache = False\n",
    "\n",
    "    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training\n",
    "    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)\n",
    "    try:\n",
    "        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n",
    "    except Exception as exc:\n",
    "        # Best effort: keep going without checkpointing, but report it instead of hiding it.\n",
    "        print(f\"[warn] gradient checkpointing not enabled: {exc!r}\")\n",
    "    peft_cfg = LoraConfig(\n",
    "        task_type=TaskType.SEQ_CLS,\n",
    "        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,\n",
    "        target_modules=TARGET_MODULES, bias=\"none\"\n",
    "    )\n",
    "    model = get_peft_model(model, peft_cfg)\n",
    "    model = model.to(\"cuda:0\")\n",
    "\n",
    "    # After the model is on cuda:0, replace .to() with a no-op so later callers cannot move it.\n",
    "    def _noop_to(self, *args, **kwargs): return self\n",
    "    model.to = _noop_to.__get__(model, type(model))\n",
    "\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)\n",
    "    test_ds  = MBTIDataset(test_rows,  tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BSZ_TRN,\n",
    "        per_device_eval_batch_size=BSZ_EVAL,\n",
    "        gradient_accumulation_steps=GRAD_ACCUM,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        learning_rate=LR,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        eval_strategy=\"epoch\",\n",
    "        save_strategy=\"epoch\",\n",
    "        save_total_limit=2,\n",
    "        logging_steps=50,\n",
    "        bf16=False, fp16=False,\n",
    "        report_to=\"none\",\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_4D\",\n",
    "        greater_is_better=True,\n",
    "        optim=\"paged_adamw_8bit\",\n",
    "        eval_accumulation_steps=12,\n",
    "        gradient_checkpointing=False,\n",
    "    )\n",
    "\n",
    "    # `processing_class` replaces the `tokenizer` argument, which is deprecated in Trainer.\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=val_ds,\n",
    "        processing_class=tokenizer,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    # Train.\n",
    "    trainer.train()\n",
    "\n",
    "    # Validation predictions and figures.\n",
    "    val_output = trainer.predict(val_ds)\n",
    "    val_logits = val_output.predictions[0] if isinstance(val_output.predictions, (list, tuple)) else val_output.predictions\n",
    "    val_probs  = F.softmax(torch.tensor(val_logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    val_y_true = val_output.label_ids\n",
    "    plot_confusion_and_roc(val_y_true, val_probs, MBTI_16, OUTPUT_DIR, suffix=\"\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro.png')}\")\n",
    "\n",
    "    # Validation metrics.\n",
    "    eval_metrics = trainer.evaluate(eval_dataset=val_ds)\n",
    "    print(\"\\n=== Final Eval (on VAL) ===\")\n",
    "    for k, v in eval_metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except (TypeError, ValueError):\n",
    "            # Non-numeric metric values are printed unformatted.\n",
    "            print(k, v)\n",
    "\n",
    "    # Test-set predictions and figures (report only).\n",
    "    test_output = trainer.predict(test_ds)\n",
    "    test_logits = test_output.predictions[0] if isinstance(test_output.predictions, (list, tuple)) else test_output.predictions\n",
    "    test_probs  = F.softmax(torch.tensor(test_logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    test_y_true = test_output.label_ids\n",
    "    plot_confusion_and_roc(test_y_true, test_probs, MBTI_16, OUTPUT_DIR, suffix=\"_test\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix_test.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro_test.png')}\")\n",
    "\n",
    "    # Test-set metrics: reuse compute_metrics so VAL and TEST are scored by the same code.\n",
    "    test_metrics = compute_metrics((test_logits, test_y_true))\n",
    "    print(\"\\n=== Final Test (held-out TEST) ===\")\n",
    "    print(f\"acc_16: {test_metrics['acc_16']:.4f}\")\n",
    "    print(\n",
    "        f\"acc_ei: {test_metrics['acc_ei']:.4f}  acc_ns: {test_metrics['acc_ns']:.4f}  \"\n",
    "        f\"acc_tf: {test_metrics['acc_tf']:.4f}  acc_jp: {test_metrics['acc_jp']:.4f}  \"\n",
    "        f\"acc_4D: {test_metrics['acc_4D']:.4f}\"\n",
    "    )\n",
    "\n",
    "    # Save the trained model to OUTPUT_DIR.\n",
    "    trainer.save_model(OUTPUT_DIR)\n",
    "    print(f\"\\n✅ LoRA adapter saved to: {OUTPUT_DIR}\")\n",
    "\n",
    "    # Inference example on the first TEST sample.\n",
    "    model.eval()\n",
    "    sample = test_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(\"cuda:0\") for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        logits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(logits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n[Inference on TEST sample]\")\n",
    "    print(\"原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "# Entry point: start training and evaluation when __name__ is \"__main__\".\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9eaefeca",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_2670172/393536711.py:296: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='15316' max='15316' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [15316/15316 2:21:40, Epoch 4/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Acc 16</th>\n",
       "      <th>Acc Ei</th>\n",
       "      <th>Acc Ns</th>\n",
       "      <th>Acc Tf</th>\n",
       "      <th>Acc Jp</th>\n",
       "      <th>Acc 4d</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.380500</td>\n",
       "      <td>0.423923</td>\n",
       "      <td>0.871328</td>\n",
       "      <td>0.952409</td>\n",
       "      <td>0.969154</td>\n",
       "      <td>0.945358</td>\n",
       "      <td>0.920094</td>\n",
       "      <td>0.871328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.356500</td>\n",
       "      <td>0.384398</td>\n",
       "      <td>0.888954</td>\n",
       "      <td>0.960047</td>\n",
       "      <td>0.975029</td>\n",
       "      <td>0.952409</td>\n",
       "      <td>0.933901</td>\n",
       "      <td>0.888954</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.120300</td>\n",
       "      <td>0.584234</td>\n",
       "      <td>0.882491</td>\n",
       "      <td>0.952115</td>\n",
       "      <td>0.970623</td>\n",
       "      <td>0.948884</td>\n",
       "      <td>0.932726</td>\n",
       "      <td>0.882491</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.015000</td>\n",
       "      <td>0.841067</td>\n",
       "      <td>0.883960</td>\n",
       "      <td>0.952996</td>\n",
       "      <td>0.970623</td>\n",
       "      <td>0.950646</td>\n",
       "      <td>0.932726</td>\n",
       "      <td>0.883960</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: mbti_lora_deepseek-1b_ckpt/confusion_matrix.png\n",
      "Saved: mbti_lora_deepseek-1b_ckpt/roc_micro_macro.png\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='851' max='851' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [851/851 01:05]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Final Eval ===\n",
      "eval_loss: 0.3844\n",
      "eval_acc_16: 0.8890\n",
      "eval_acc_ei: 0.9600\n",
      "eval_acc_ns: 0.9750\n",
      "eval_acc_tf: 0.9524\n",
      "eval_acc_jp: 0.9339\n",
      "eval_acc_4D: 0.8890\n",
      "eval_runtime: 65.9047\n",
      "eval_samples_per_second: 51.6500\n",
      "eval_steps_per_second: 12.9130\n",
      "epoch: 4.0000\n",
      "\n",
      "✅ LoRA adapter saved to: mbti_lora_deepseek-1b_ckpt\n",
      "\n",
      "原标签: ENTP  | 预测: ENTP\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "DeepSeek-R1-Distill-Qwen-1.5B + QLoRA(4bit) + LoRA\n",
    "MBTI 16类分类（含4D严格准确率）\n",
    "Transformers==4.55\n",
    "GPU-only（绝不使用CPU/磁盘offload）\n",
    "\"\"\"\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")  # 服务器/无显示环境\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.metrics import (\n",
    "    confusion_matrix, ConfusionMatrixDisplay,\n",
    "    roc_curve, auc\n",
    ")\n",
    "from sklearn.preprocessing import label_binarize\n",
    "import torch.nn.functional as F\n",
    "import os\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "# 禁用 accelerate 混精，关 bnb 欢迎语\n",
    "os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# 建议先关 flash-attn（环境确认后再打开）\n",
    "# try:\n",
    "#     from transformers import set_attn_implementation\n",
    "#     set_attn_implementation(\"flash_attention_2\")\n",
    "# except Exception:\n",
    "#     pass\n",
    "\n",
    "# ============ Configuration ============\n",
    "MODEL_NAME   = \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
    "FILE_PATH    = \"mbti_sample_with_all_views.json\"\n",
    "\n",
    "# Maximum tokenized length per sample and the per-field token budgets used by build_input.\n",
    "# NOTE(review): an earlier note here quoted ~3840 tokens/step for a 24 GB GPU; with the values\n",
    "# below MAX_LEN * BSZ_TRN * GRAD_ACCUM = 320 * 8 * 1 = 2560 - confirm the memory target.\n",
    "# NOTE(review): the BUDGET values sum to 312, leaving 8 tokens of MAX_LEN for the section\n",
    "# markers and [TASK] sentence added by build_input - confirm truncating them is acceptable.\n",
    "MAX_LEN      = 320\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "SEED         = 42\n",
    "EPOCHS       = 4\n",
    "LR           = 2e-4\n",
    "BSZ_TRN      = 8           # per-device\n",
    "BSZ_EVAL     = 4\n",
    "GRAD_ACCUM   = 1            # MAX_LEN * BSZ_TRN * GRAD_ACCUM = 320 * 8 * 1 = 2560 tokens per optimizer step at most\n",
    "WARMUP_RATIO = 0.06\n",
    "WEIGHT_DECAY = 0.01\n",
    "OUTPUT_DIR   = \"mbti_lora_deepseek-1b_ckpt\"\n",
    "\n",
    "# QLoRA & LoRA\n",
    "USE_4BIT     = True\n",
    "LORA_R       = 16\n",
    "LORA_ALPHA   = 32\n",
    "LORA_DROPOUT = 0.05\n",
    "TARGET_MODULES = [\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"]\n",
    "\n",
    "# Optional Hugging Face token read from the environment; HF_KW is empty when it is unset.\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# ============ MBTI helpers ============\n",
    "# The 16 MBTI type codes; list position is the integer class id.\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load labelled samples from a JSON file.\n",
    "\n",
    "    Args:\n",
    "        path: Path of a UTF-8 encoded JSON file whose top level is iterated.\n",
    "\n",
    "    Returns:\n",
    "        List of the dict entries whose \"type\" value is a key of MBTI2ID.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the file contains no usable sample.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    # Skip non-dict entries instead of crashing on r.get().\n",
    "    rows = [r for r in rows if isinstance(r, dict) and r.get(\"type\") in MBTI2ID]\n",
    "    if not rows:\n",
    "        # Fail here with a clear message rather than later inside splitting/training.\n",
    "        raise ValueError(f\"{path} 中没有合法样本。\")\n",
    "    return rows\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Keep only the first `budget` tokens of text and decode them back to a string.\n",
    "\n",
    "    Args:\n",
    "        tok: Tokenizer; invoked with add_special_tokens=False, then used for decoding.\n",
    "        text: Text to shorten; falsy values are treated as \"\".\n",
    "        budget: Upper bound on the number of token ids kept.\n",
    "\n",
    "    Returns:\n",
    "        Decoded text of the kept token ids.\n",
    "    \"\"\"\n",
    "    clipped_ids = tok(text or \"\", add_special_tokens=False)[\"input_ids\"][:budget]\n",
    "    return tok.decode(clipped_ids)\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Assemble the tagged prompt text for one sample.\n",
    "\n",
    "    Each view is truncated to its token budget from BUDGET; the posts section\n",
    "    reads \"posts_cleaned\" and uses \"posts\" only when that key is absent.\n",
    "    Missing or falsy fields become empty strings.\n",
    "    \"\"\"\n",
    "    fields = (\n",
    "        (\"posts_cleaned\", item.get(\"posts_cleaned\", item.get(\"posts\", \"\"))),\n",
    "        (\"semantic_view\", item.get(\"semantic_view\", \"\")),\n",
    "        (\"sentiment_view\", item.get(\"sentiment_view\", \"\")),\n",
    "        (\"linguistic_view\", item.get(\"linguistic_view\", \"\")),\n",
    "    )\n",
    "    posts, semantic, sentiment, linguistic = (\n",
    "        truncate_to_budget(tok, raw or \"\", BUDGET[name]) for name, raw in fields\n",
    "    )\n",
    "    return (\n",
    "        f\"[POSTS]\\n{posts}\\n[SEMANTIC]\\n{semantic}\\n[SENTIMENT]\\n{sentiment}\\n[LINGUISTIC]\\n{linguistic}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "\n",
    "# ============ Dataset ============\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok  = tokenizer\n",
    "        self.max_len = max_len\n",
    "    def __len__(self): return len(self.rows)\n",
    "    def __getitem__(self, idx):\n",
    "        it  = self.rows[idx]\n",
    "        text= build_input(it, self.tok)\n",
    "        y   = MBTI2ID[it[\"type\"]]\n",
    "        enc = self.tok(text, truncation=True, max_length=self.max_len)\n",
    "        return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"], \"labels\": y}\n",
    "\n",
    "# ============ 指标 ============\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute 16-way accuracy plus per-axis and strict 4-axis accuracy.\n",
    "\n",
    "    Accepts either a ``(predictions, labels)`` tuple or an object exposing\n",
    "    ``predictions`` and ``label_ids``. When the predictions are a list/tuple,\n",
    "    their first element is used as the logits.\n",
    "\n",
    "    Returns:\n",
    "        Dict with keys \"acc_16\", \"acc_ei\", \"acc_ns\", \"acc_tf\", \"acc_jp\" and\n",
    "        \"acc_4D\" (all four axes correct for the same sample).\n",
    "    \"\"\"\n",
    "    if isinstance(eval_pred, tuple):\n",
    "        preds, labels = eval_pred\n",
    "    else:\n",
    "        preds = eval_pred.predictions\n",
    "        labels = eval_pred.label_ids\n",
    "    if isinstance(preds, (list, tuple)):\n",
    "        preds = preds[0]\n",
    "    preds = preds if isinstance(preds, np.ndarray) else np.asarray(preds)\n",
    "    labels = labels if isinstance(labels, np.ndarray) else np.asarray(labels)\n",
    "\n",
    "    pred_ids = preds.argmax(-1)\n",
    "    acc16 = float((pred_ids == labels).mean())\n",
    "\n",
    "    # Decode every id to its type string first so a bad id fails before counting.\n",
    "    pred_types = [MBTI_16[k] for k in pred_ids]\n",
    "    true_types = [MBTI_16[k] for k in labels]\n",
    "\n",
    "    axis_hits = [0, 0, 0, 0]  # E/I, N/S, T/F, J/P\n",
    "    strict_hits = 0\n",
    "    for guess, truth in zip(pred_types, true_types):\n",
    "        same = [g == t for g, t in zip(mbti_to_4d(guess), mbti_to_4d(truth))]\n",
    "        for axis, ok in enumerate(same):\n",
    "            axis_hits[axis] += ok\n",
    "        strict_hits += all(same)\n",
    "\n",
    "    total = len(labels)\n",
    "    metrics = {\"acc_16\": acc16}\n",
    "    for key, hits in zip((\"acc_ei\", \"acc_ns\", \"acc_tf\", \"acc_jp\"), axis_hits):\n",
    "        metrics[key] = hits / total\n",
    "    metrics[\"acc_4D\"] = strict_hits / total\n",
    "    return metrics\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir):\n",
    "    \"\"\"Save a confusion-matrix figure and a micro/macro ROC figure in ``out_dir``.\n",
    "\n",
    "    Args:\n",
    "        y_true: Integer class ids, shape (N,).\n",
    "        y_prob: Per-class probabilities (after softmax), shape (N, C).\n",
    "        class_names: List of the C class names; index i names class id i.\n",
    "        out_dir: Directory that receives \"confusion_matrix.png\" and\n",
    "            \"roc_micro_macro.png\"; it is created when missing.\n",
    "\n",
    "    A class with no positive or no negative sample in ``y_true`` has no\n",
    "    defined one-vs-rest ROC curve (roc_curve yields NaN rates for it), so such\n",
    "    classes are left out of the macro average instead of turning it into NaN.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    # Accept lists as well as arrays; the column indexing below needs ndarrays.\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_prob = np.asarray(y_prob)\n",
    "    class_ids = list(range(len(class_names)))\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    # ===== Confusion matrix =====\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=class_ids)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(\"Confusion Matrix\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, \"confusion_matrix.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ===== Multi-class ROC (micro / macro) =====\n",
    "    # One-vs-rest binarization of y_true.\n",
    "    Y_true_bin = label_binarize(y_true, classes=class_ids)  # (N, C)\n",
    "    n_rows = Y_true_bin.shape[0]\n",
    "    fpr = dict(); tpr = dict(); roc_auc = dict()\n",
    "    scored_classes = []  # classes whose one-vs-rest ROC curve is defined\n",
    "    for i in class_ids:\n",
    "        positives = int(Y_true_bin[:, i].sum())\n",
    "        if positives == 0 or positives == n_rows:\n",
    "            continue\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "        scored_classes.append(i)\n",
    "\n",
    "    # micro: pool every (sample, class) decision into one binary problem\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    # macro: average the per-class ROC curves on a shared FPR grid, then take\n",
    "    # the AUC of that averaged curve (not the mean of the per-class AUC values)\n",
    "    if scored_classes:\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in scored_classes]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in scored_classes:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(scored_classes)\n",
    "        fpr[\"macro\"] = all_fpr\n",
    "        tpr[\"macro\"] = mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    # Only the micro/macro curves are drawn to keep the figure readable.\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    if \"macro\" in roc_auc:\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                    label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(\"Multiclass ROC (micro & macro)\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, \"roc_micro_macro.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ============ 主流程 ============\n",
    "def main():\n",
    "    \"\"\"Fine-tune MODEL_NAME with LoRA as a 16-way MBTI classifier and evaluate it.\n",
    "\n",
    "    Steps: stratified 90/10 train/validation split of FILE_PATH, load the\n",
    "    sequence-classification model on cuda:0 (4-bit quantized when USE_4BIT),\n",
    "    attach LoRA adapters, train with the Hugging Face Trainer, save the\n",
    "    confusion-matrix / ROC figures and the adapter under OUTPUT_DIR, then print\n",
    "    the validation metrics and one sample prediction.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If no CUDA device is available (the pipeline is GPU-only).\n",
    "    \"\"\"\n",
    "    if not torch.cuda.is_available():\n",
    "        raise RuntimeError(\"CUDA GPU is required: this pipeline is GPU-only (no CPU/disk offload).\")\n",
    "    torch.cuda.set_device(0)  # explicitly select the GPU\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    rows = load_rows(FILE_PATH)\n",
    "    from sklearn.model_selection import train_test_split\n",
    "    train_rows, val_rows = train_test_split(\n",
    "        rows, test_size=0.1, random_state=SEED, stratify=[r[\"type\"] for r in rows]\n",
    "    )\n",
    "\n",
    "    # tokenizer\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # QLoRA 4-bit quantization config (GPU-only)\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    # Load the model onto the GPU (no automatic offload)\n",
    "    model_kwargs = dict(\n",
    "        num_labels=16,\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": \"cuda:0\"},    # force the whole model onto one GPU\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW,\n",
    "    )\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, **model_kwargs)\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.config.use_cache = False\n",
    "\n",
    "    # Do not resize the vocabulary (avoids a large unquantized embedding and\n",
    "    # device misplacement)\n",
    "    # model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "    # LoRA: prepare for k-bit training first, then move everything to the GPU\n",
    "    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training\n",
    "    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)\n",
    "    try:\n",
    "        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n",
    "    except Exception as exc:\n",
    "        # Best effort: training can proceed without checkpointing, but report it\n",
    "        # instead of hiding it because it changes the memory footprint.\n",
    "        print(f\"[warn] gradient checkpointing not enabled: {exc!r}\")\n",
    "    peft_cfg = LoraConfig(\n",
    "        task_type=TaskType.SEQ_CLS,\n",
    "        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,\n",
    "        target_modules=TARGET_MODULES, bias=\"none\"\n",
    "    )\n",
    "    model = get_peft_model(model, peft_cfg)\n",
    "\n",
    "    # Move to the GPU (device only, dtype untouched)\n",
    "    model = model.to(\"cuda:0\")\n",
    "\n",
    "    # (optional) From here on make .to a no-op so a later call cannot cast the dtype\n",
    "    def _noop_to(self, *args, **kwargs): return self\n",
    "    model.to = _noop_to.__get__(model, type(model))\n",
    "\n",
    "    # Data & collator\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # Training arguments (GPU-only friendly)\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BSZ_TRN,\n",
    "        per_device_eval_batch_size=BSZ_EVAL,\n",
    "        gradient_accumulation_steps=GRAD_ACCUM,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        learning_rate=LR,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        eval_strategy=\"epoch\",\n",
    "        save_strategy=\"epoch\",\n",
    "        save_total_limit=2,\n",
    "        logging_steps=50,\n",
    "        bf16=False, fp16=False,\n",
    "        report_to=\"none\",\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_4D\",\n",
    "        greater_is_better=True,\n",
    "\n",
    "        optim=\"paged_adamw_8bit\",\n",
    "        eval_accumulation_steps=12,\n",
    "        gradient_checkpointing=False,\n",
    "    )\n",
    "\n",
    "    # Trainer renamed its `tokenizer` argument to `processing_class` (the old\n",
    "    # name emits a FutureWarning and is removed in 5.0.0); pass whichever name\n",
    "    # the installed version accepts.\n",
    "    import inspect\n",
    "    tok_kwarg = (\n",
    "        \"processing_class\"\n",
    "        if \"processing_class\" in inspect.signature(Trainer.__init__).parameters\n",
    "        else \"tokenizer\"\n",
    "    )\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=val_ds,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "        **{tok_kwarg: tokenizer},\n",
    "    )\n",
    "\n",
    "    trainer.train()\n",
    "    # After training, get the logits predicted on the validation set\n",
    "    pred_output = trainer.predict(val_ds)\n",
    "    # pred_output.predictions is normally of shape (N, C)\n",
    "    logits = pred_output.predictions\n",
    "    if isinstance(logits, (list, tuple)):\n",
    "        logits = logits[0]\n",
    "    # Probabilities via softmax\n",
    "    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = pred_output.label_ids\n",
    "\n",
    "    # Draw and save the figures\n",
    "    plot_confusion_and_roc(\n",
    "        y_true=y_true,\n",
    "        y_prob=probs,\n",
    "        class_names=MBTI_16,\n",
    "        out_dir=OUTPUT_DIR\n",
    "    )\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro.png')}\")\n",
    "\n",
    "    eval_metrics = trainer.evaluate()\n",
    "    print(\"\\n=== Final Eval ===\")\n",
    "    for k, v in eval_metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except Exception:\n",
    "            print(k, v)\n",
    "\n",
    "    trainer.save_model(OUTPUT_DIR)\n",
    "    print(f\"\\n✅ LoRA adapter saved to: {OUTPUT_DIR}\")\n",
    "\n",
    "    # Inference example (inputs placed on the same device as the model)\n",
    "    model.eval()\n",
    "    sample = val_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(\"cuda:0\") for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        logits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(logits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "46a148ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_2833968/2215430006.py:303: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='21600' max='21600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [21600/21600 3:20:18, Epoch 4/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Acc 16</th>\n",
       "      <th>Acc Ei</th>\n",
       "      <th>Acc Ns</th>\n",
       "      <th>Acc Tf</th>\n",
       "      <th>Acc Jp</th>\n",
       "      <th>Acc 4d</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.135700</td>\n",
       "      <td>2.125176</td>\n",
       "      <td>0.325208</td>\n",
       "      <td>0.659375</td>\n",
       "      <td>0.670833</td>\n",
       "      <td>0.689583</td>\n",
       "      <td>0.653125</td>\n",
       "      <td>0.325208</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>1.702100</td>\n",
       "      <td>1.925803</td>\n",
       "      <td>0.410000</td>\n",
       "      <td>0.703125</td>\n",
       "      <td>0.721250</td>\n",
       "      <td>0.733542</td>\n",
       "      <td>0.680417</td>\n",
       "      <td>0.410000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.966200</td>\n",
       "      <td>2.273827</td>\n",
       "      <td>0.423750</td>\n",
       "      <td>0.720417</td>\n",
       "      <td>0.717292</td>\n",
       "      <td>0.742708</td>\n",
       "      <td>0.704375</td>\n",
       "      <td>0.423750</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.097400</td>\n",
       "      <td>5.073913</td>\n",
       "      <td>0.415833</td>\n",
       "      <td>0.713125</td>\n",
       "      <td>0.717708</td>\n",
       "      <td>0.737292</td>\n",
       "      <td>0.686458</td>\n",
       "      <td>0.415833</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: mbti_lora_qwen2.5-1.5b_pandora_ckpt/confusion_matrix.png\n",
      "Saved: mbti_lora_qwen2.5-1.5b_pandora_ckpt/roc_micro_macro.png\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1200' max='1200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1200/1200 01:36]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Final Eval ===\n",
      "eval_loss: 2.2738\n",
      "eval_acc_16: 0.4238\n",
      "eval_acc_ei: 0.7204\n",
      "eval_acc_ns: 0.7173\n",
      "eval_acc_tf: 0.7427\n",
      "eval_acc_jp: 0.7044\n",
      "eval_acc_4D: 0.4238\n",
      "eval_runtime: 96.8746\n",
      "eval_samples_per_second: 49.5490\n",
      "eval_steps_per_second: 12.3870\n",
      "epoch: 4.0000\n",
      "\n",
      "✅ LoRA adapter saved to: mbti_lora_qwen2.5-1.5b_pandora_ckpt\n",
      "\n",
      "原标签: INTJ  | 预测: INTJ\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*- \n",
    "\"\"\"\n",
    "Qwen-1.5B-Instruct（Qwen2.5） + QLoRA(4bit) + LoRA\n",
    "MBTI 16类分类（含4D严格准确率）\n",
    "Transformers==4.55\n",
    "GPU-only（绝不使用CPU/磁盘offload）\n",
    "\"\"\"\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")  # 服务器/无显示环境\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.metrics import (\n",
    "    confusion_matrix, ConfusionMatrixDisplay,\n",
    "    roc_curve, auc\n",
    ")\n",
    "from sklearn.preprocessing import label_binarize\n",
    "import torch.nn.functional as F\n",
    "import os\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "# 禁用 accelerate 混精，关 bnb 欢迎语\n",
    "os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# 建议先关 flash-attn（环境确认后再打开）\n",
    "# try:\n",
    "#     from transformers import set_attn_implementation\n",
    "#     set_attn_implementation(\"flash_attention_2\")\n",
    "# except Exception:\n",
    "#     pass\n",
    "\n",
    "# ============ Configuration ============\n",
    "# Base model for this run: Qwen 1.5B Instruct (Qwen2.5)\n",
    "MODEL_NAME   = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "# Dataset file for this run\n",
    "FILE_PATH    = \"mbti_sample_with_all_views_pandora.json\"\n",
    "\n",
    "# Sized for a 24 GB GPU (tokens per optimizer step <= MAX_LEN * BSZ_TRN * GRAD_ACCUM = 2560)\n",
    "# NOTE(review): the four budgets below sum to 312 tokens; together with the\n",
    "# section tags and the [TASK] line the prompt can exceed MAX_LEN, in which\n",
    "# case tokenizer truncation presumably drops the tail - confirm this is intended.\n",
    "MAX_LEN      = 320\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "SEED         = 42\n",
    "EPOCHS       = 4\n",
    "LR           = 2e-4\n",
    "BSZ_TRN      = 8           # per-device\n",
    "BSZ_EVAL     = 4\n",
    "GRAD_ACCUM   = 1           # at most 320 * 8 * 1 = 2560 tokens per optimizer step\n",
    "WARMUP_RATIO = 0.06\n",
    "WEIGHT_DECAY = 0.01\n",
    "# Output directory for checkpoints, figures and the saved adapter of this run\n",
    "OUTPUT_DIR   = \"mbti_lora_qwen2.5-1.5b_pandora_ckpt\"\n",
    "\n",
    "# QLoRA & LoRA\n",
    "USE_4BIT     = True\n",
    "LORA_R       = 16\n",
    "LORA_ALPHA   = 32\n",
    "LORA_DROPOUT = 0.05\n",
    "TARGET_MODULES = [\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"]\n",
    "\n",
    "# Optional Hugging Face token read from the environment; HF_KW stays empty when unset\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# ============ MBTI 工具 ============\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load the JSON samples stored at ``path`` and keep only the usable rows.\n",
    "\n",
    "    A row is kept when it is a dict whose \"type\" value is a key of MBTI2ID.\n",
    "    Entries that are not dicts are dropped instead of raising AttributeError.\n",
    "\n",
    "    Args:\n",
    "        path: Path of a UTF-8 encoded JSON file; the top level is iterated as\n",
    "            a sequence of rows.\n",
    "\n",
    "    Returns:\n",
    "        List of the retained row dicts, in file order.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    return [r for r in rows if isinstance(r, dict) and r.get(\"type\") in MBTI2ID]\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Return ``text`` cut down to at most ``budget`` tokens of ``tok``.\n",
    "\n",
    "    The text (empty string when falsy) is tokenized without special tokens,\n",
    "    the first ``budget`` ids are kept, and those ids are decoded to a string.\n",
    "    \"\"\"\n",
    "    token_ids = tok(text if text else \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Assemble the tagged prompt text for one sample.\n",
    "\n",
    "    Each view is truncated to its token budget from BUDGET; the posts section\n",
    "    reads \"posts_cleaned\" and uses \"posts\" only when that key is absent.\n",
    "    Missing or falsy fields become empty strings.\n",
    "    \"\"\"\n",
    "    fields = (\n",
    "        (\"posts_cleaned\", item.get(\"posts_cleaned\", item.get(\"posts\", \"\"))),\n",
    "        (\"semantic_view\", item.get(\"semantic_view\", \"\")),\n",
    "        (\"sentiment_view\", item.get(\"sentiment_view\", \"\")),\n",
    "        (\"linguistic_view\", item.get(\"linguistic_view\", \"\")),\n",
    "    )\n",
    "    posts, semantic, sentiment, linguistic = (\n",
    "        truncate_to_budget(tok, raw or \"\", BUDGET[name]) for name, raw in fields\n",
    "    )\n",
    "    return (\n",
    "        f\"[POSTS]\\n{posts}\\n[SEMANTIC]\\n{semantic}\\n[SENTIMENT]\\n{sentiment}\\n[LINGUISTIC]\\n{linguistic}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "\n",
    "# ============ Dataset ============\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok  = tokenizer\n",
    "        self.max_len = max_len\n",
    "    def __len__(self): return len(self.rows)\n",
    "    def __getitem__(self, idx):\n",
    "        it  = self.rows[idx]\n",
    "        text= build_input(it, self.tok)\n",
    "        y   = MBTI2ID[it[\"type\"]]\n",
    "        enc = self.tok(text, truncation=True, max_length=self.max_len)\n",
    "        return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"], \"labels\": y}\n",
    "\n",
    "# ============ 指标 ============\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute 16-way accuracy plus per-axis and strict 4-axis accuracy.\n",
    "\n",
    "    Accepts either a ``(predictions, labels)`` tuple or an object exposing\n",
    "    ``predictions`` and ``label_ids``. When the predictions are a list/tuple,\n",
    "    their first element is used as the logits.\n",
    "\n",
    "    Returns:\n",
    "        Dict with keys \"acc_16\", \"acc_ei\", \"acc_ns\", \"acc_tf\", \"acc_jp\" and\n",
    "        \"acc_4D\" (all four axes correct for the same sample).\n",
    "    \"\"\"\n",
    "    if isinstance(eval_pred, tuple):\n",
    "        preds, labels = eval_pred\n",
    "    else:\n",
    "        preds = eval_pred.predictions\n",
    "        labels = eval_pred.label_ids\n",
    "    if isinstance(preds, (list, tuple)):\n",
    "        preds = preds[0]\n",
    "    preds = preds if isinstance(preds, np.ndarray) else np.asarray(preds)\n",
    "    labels = labels if isinstance(labels, np.ndarray) else np.asarray(labels)\n",
    "\n",
    "    pred_ids = preds.argmax(-1)\n",
    "    acc16 = float((pred_ids == labels).mean())\n",
    "\n",
    "    # Decode every id to its type string first so a bad id fails before counting.\n",
    "    pred_types = [MBTI_16[k] for k in pred_ids]\n",
    "    true_types = [MBTI_16[k] for k in labels]\n",
    "\n",
    "    axis_hits = [0, 0, 0, 0]  # E/I, N/S, T/F, J/P\n",
    "    strict_hits = 0\n",
    "    for guess, truth in zip(pred_types, true_types):\n",
    "        same = [g == t for g, t in zip(mbti_to_4d(guess), mbti_to_4d(truth))]\n",
    "        for axis, ok in enumerate(same):\n",
    "            axis_hits[axis] += ok\n",
    "        strict_hits += all(same)\n",
    "\n",
    "    total = len(labels)\n",
    "    metrics = {\"acc_16\": acc16}\n",
    "    for key, hits in zip((\"acc_ei\", \"acc_ns\", \"acc_tf\", \"acc_jp\"), axis_hits):\n",
    "        metrics[key] = hits / total\n",
    "    metrics[\"acc_4D\"] = strict_hits / total\n",
    "    return metrics\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir):\n",
    "    \"\"\"Save a confusion-matrix figure and a micro/macro ROC figure in ``out_dir``.\n",
    "\n",
    "    Args:\n",
    "        y_true: Integer class ids, shape (N,).\n",
    "        y_prob: Per-class probabilities (after softmax), shape (N, C).\n",
    "        class_names: List of the C class names; index i names class id i.\n",
    "        out_dir: Directory that receives \"confusion_matrix.png\" and\n",
    "            \"roc_micro_macro.png\"; it is created when missing.\n",
    "\n",
    "    A class with no positive or no negative sample in ``y_true`` has no\n",
    "    defined one-vs-rest ROC curve (roc_curve yields NaN rates for it), so such\n",
    "    classes are left out of the macro average instead of turning it into NaN.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    # Accept lists as well as arrays; the column indexing below needs ndarrays.\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_prob = np.asarray(y_prob)\n",
    "    class_ids = list(range(len(class_names)))\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    # ===== Confusion matrix =====\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=class_ids)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(\"Confusion Matrix\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, \"confusion_matrix.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ===== Multi-class ROC (micro / macro) =====\n",
    "    # One-vs-rest binarization of y_true.\n",
    "    Y_true_bin = label_binarize(y_true, classes=class_ids)  # (N, C)\n",
    "    n_rows = Y_true_bin.shape[0]\n",
    "    fpr = dict(); tpr = dict(); roc_auc = dict()\n",
    "    scored_classes = []  # classes whose one-vs-rest ROC curve is defined\n",
    "    for i in class_ids:\n",
    "        positives = int(Y_true_bin[:, i].sum())\n",
    "        if positives == 0 or positives == n_rows:\n",
    "            continue\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "        scored_classes.append(i)\n",
    "\n",
    "    # micro: pool every (sample, class) decision into one binary problem\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    # macro: average the per-class ROC curves on a shared FPR grid, then take\n",
    "    # the AUC of that averaged curve (not the mean of the per-class AUC values)\n",
    "    if scored_classes:\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in scored_classes]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in scored_classes:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(scored_classes)\n",
    "        fpr[\"macro\"] = all_fpr\n",
    "        tpr[\"macro\"] = mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    # Only the micro/macro curves are drawn to keep the figure readable.\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    if \"macro\" in roc_auc:\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                    label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(\"Multiclass ROC (micro & macro)\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, \"roc_micro_macro.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ============ 主流程 ============\n",
    "def main():\n",
    "    \"\"\"Fine-tune MODEL_NAME with LoRA as a 16-way MBTI classifier and evaluate it.\n",
    "\n",
    "    Steps: stratified 90/10 train/validation split of FILE_PATH, load the\n",
    "    sequence-classification model on cuda:0 (4-bit quantized when USE_4BIT),\n",
    "    attach LoRA adapters, train with the Hugging Face Trainer, save the\n",
    "    confusion-matrix / ROC figures and the adapter under OUTPUT_DIR, then print\n",
    "    the validation metrics and one sample prediction.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If no CUDA device is available (the pipeline is GPU-only).\n",
    "    \"\"\"\n",
    "    if not torch.cuda.is_available():\n",
    "        raise RuntimeError(\"CUDA GPU is required: this pipeline is GPU-only (no CPU/disk offload).\")\n",
    "    torch.cuda.set_device(0)  # explicitly select the GPU\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    rows = load_rows(FILE_PATH)\n",
    "    # train_test_split is imported at the top of this cell.\n",
    "    train_rows, val_rows = train_test_split(\n",
    "        rows, test_size=0.1, random_state=SEED, stratify=[r[\"type\"] for r in rows]\n",
    "    )\n",
    "\n",
    "    # tokenizer (loaded with trust_remote_code=True, as recommended for Qwen by the author)\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\n",
    "        MODEL_NAME, use_fast=True, trust_remote_code=True, **HF_KW\n",
    "    )\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # QLoRA 4-bit quantization config (GPU-only)\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    # Load the model onto the GPU (no automatic offload)\n",
    "    model_kwargs = dict(\n",
    "        num_labels=16,\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": \"cuda:0\"},    # force the whole model onto one GPU\n",
    "        low_cpu_mem_usage=True,\n",
    "        trust_remote_code=True,       # recommended for Qwen by the author\n",
    "        **HF_KW,\n",
    "    )\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, **model_kwargs)\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.config.use_cache = False\n",
    "\n",
    "    # Do not resize the vocabulary (avoids a large unquantized embedding and\n",
    "    # device misplacement)\n",
    "    # model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "    # LoRA: prepare for k-bit training first, then move everything to the GPU\n",
    "    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training\n",
    "    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)\n",
    "    try:\n",
    "        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n",
    "    except Exception as exc:\n",
    "        # Best effort: training can proceed without checkpointing, but report it\n",
    "        # instead of hiding it because it changes the memory footprint.\n",
    "        print(f\"[warn] gradient checkpointing not enabled: {exc!r}\")\n",
    "    peft_cfg = LoraConfig(\n",
    "        task_type=TaskType.SEQ_CLS,\n",
    "        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,\n",
    "        target_modules=TARGET_MODULES, bias=\"none\"\n",
    "    )\n",
    "    model = get_peft_model(model, peft_cfg)\n",
    "\n",
    "    # Move to the GPU (device only, dtype untouched)\n",
    "    model = model.to(\"cuda:0\")\n",
    "\n",
    "    # (optional) From here on make .to a no-op so a later call cannot cast the dtype\n",
    "    def _noop_to(self, *args, **kwargs): return self\n",
    "    model.to = _noop_to.__get__(model, type(model))\n",
    "\n",
    "    # Data & collator\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # Training arguments (GPU-only friendly)\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BSZ_TRN,\n",
    "        per_device_eval_batch_size=BSZ_EVAL,\n",
    "        gradient_accumulation_steps=GRAD_ACCUM,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        learning_rate=LR,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        eval_strategy=\"epoch\",\n",
    "        save_strategy=\"epoch\",\n",
    "        save_total_limit=2,\n",
    "        logging_steps=50,\n",
    "        bf16=False, fp16=False,\n",
    "        report_to=\"none\",\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_4D\",\n",
    "        greater_is_better=True,\n",
    "\n",
    "        optim=\"paged_adamw_8bit\",\n",
    "        eval_accumulation_steps=12,\n",
    "        gradient_checkpointing=False,\n",
    "    )\n",
    "\n",
    "    # Trainer renamed its `tokenizer` argument to `processing_class` (the old\n",
    "    # name emits a FutureWarning and is removed in 5.0.0); pass whichever name\n",
    "    # the installed version accepts.\n",
    "    import inspect\n",
    "    tok_kwarg = (\n",
    "        \"processing_class\"\n",
    "        if \"processing_class\" in inspect.signature(Trainer.__init__).parameters\n",
    "        else \"tokenizer\"\n",
    "    )\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=val_ds,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "        **{tok_kwarg: tokenizer},\n",
    "    )\n",
    "\n",
    "    trainer.train()\n",
    "    # After training, get the logits predicted on the validation set\n",
    "    pred_output = trainer.predict(val_ds)\n",
    "    logits = pred_output.predictions\n",
    "    if isinstance(logits, (list, tuple)):\n",
    "        logits = logits[0]\n",
    "    # Probabilities via softmax\n",
    "    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = pred_output.label_ids\n",
    "\n",
    "    # Draw and save the figures\n",
    "    plot_confusion_and_roc(\n",
    "        y_true=y_true,\n",
    "        y_prob=probs,\n",
    "        class_names=MBTI_16,\n",
    "        out_dir=OUTPUT_DIR\n",
    "    )\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro.png')}\")\n",
    "\n",
    "    eval_metrics = trainer.evaluate()\n",
    "    print(\"\\n=== Final Eval ===\")\n",
    "    for k, v in eval_metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except Exception:\n",
    "            print(k, v)\n",
    "\n",
    "    trainer.save_model(OUTPUT_DIR)\n",
    "    print(f\"\\n✅ LoRA adapter saved to: {OUTPUT_DIR}\")\n",
    "\n",
    "    # Inference example (inputs placed on the same device as the model)\n",
    "    model.eval()\n",
    "    sample = val_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(\"cuda:0\") for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        logits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(logits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9987aff9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 18,489,344 || all params: 1,561,811,968 || trainable%: 1.1838\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2833968/2406488015.py:257: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='21600' max='21600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [21600/21600 3:26:45, Epoch 4/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Acc 16</th>\n",
       "      <th>Acc Ei</th>\n",
       "      <th>Acc Ns</th>\n",
       "      <th>Acc Tf</th>\n",
       "      <th>Acc Jp</th>\n",
       "      <th>Acc 4d</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>3.986800</td>\n",
       "      <td>1.984458</td>\n",
       "      <td>0.374167</td>\n",
       "      <td>0.678125</td>\n",
       "      <td>0.698750</td>\n",
       "      <td>0.712292</td>\n",
       "      <td>0.685000</td>\n",
       "      <td>0.374167</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>3.029900</td>\n",
       "      <td>1.729987</td>\n",
       "      <td>0.463958</td>\n",
       "      <td>0.737917</td>\n",
       "      <td>0.750417</td>\n",
       "      <td>0.753750</td>\n",
       "      <td>0.717917</td>\n",
       "      <td>0.463958</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.471900</td>\n",
       "      <td>2.178578</td>\n",
       "      <td>0.479583</td>\n",
       "      <td>0.740000</td>\n",
       "      <td>0.752083</td>\n",
       "      <td>0.767292</td>\n",
       "      <td>0.734167</td>\n",
       "      <td>0.479583</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.185700</td>\n",
       "      <td>3.985021</td>\n",
       "      <td>0.483958</td>\n",
       "      <td>0.747292</td>\n",
       "      <td>0.760000</td>\n",
       "      <td>0.760833</td>\n",
       "      <td>0.736042</td>\n",
       "      <td>0.483958</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
      "  warnings.warn(\n",
      "/home/hli962/miniconda3/lib/python3.12/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
      "  warnings.warn(\n",
      "/home/hli962/miniconda3/lib/python3.12/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
      "  warnings.warn(\n",
      "/home/hli962/miniconda3/lib/python3.12/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: mbti_lora_qwen2.5-1.5b_pandora_new_ckpt/confusion_matrix.png\n",
      "Saved: mbti_lora_qwen2.5-1.5b_pandora_new_ckpt/roc_micro_macro.png\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='600' max='600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [600/600 01:39]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Final Eval ===\n",
      "eval_loss: 3.9850\n",
      "eval_acc_16: 0.4840\n",
      "eval_acc_ei: 0.7473\n",
      "eval_acc_ns: 0.7600\n",
      "eval_acc_tf: 0.7608\n",
      "eval_acc_jp: 0.7360\n",
      "eval_acc_4D: 0.4840\n",
      "eval_runtime: 100.1548\n",
      "eval_samples_per_second: 47.9260\n",
      "eval_steps_per_second: 5.9910\n",
      "epoch: 4.0000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "✅ LoRA adapter saved to: mbti_lora_qwen2.5-1.5b_pandora_new_ckpt\n",
      "\n",
      "原标签: INTJ | 预测: ISTJ\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Qwen-2.5-1.5B + LoRA 训练 MBTI 16类，同时统计4D严格准确率\n",
    "（不做量化，避免 .to() / bitsandbytes 的兼容问题；适配 transformers 4.55+）\n",
    "并按指定格式输出混淆矩阵和ROC两张图：\n",
    "- confusion_matrix.png\n",
    "- roc_micro_macro.png\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any, List\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# ========= 新增：绘图相关依赖（保持你要的输出格式）=========\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")  # 服务器/无显示环境\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc\n",
    "from sklearn.preprocessing import label_binarize\n",
    "import torch.nn.functional as F\n",
    "# ======================================================\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# ---------------- Basic configuration ----------------\n",
    "MODEL_NAME = \"Qwen/Qwen2.5-1.5B\"              # \"Qwen/Qwen2.5-1.5B-Instruct\" can be used instead\n",
    "FILE_PATH  = \"mbti_sample_with_all_views_pandora.json\"\n",
    "\n",
    "# MAX_LEN caps the tokenized prompt; BUDGET caps each field (in tokens) before the prompt is assembled.\n",
    "# NOTE(review): the four budgets already sum to 768 (= MAX_LEN), so the section markers and the\n",
    "# trailing [TASK] line are presumably cut off by the final truncation when every field fills its\n",
    "# budget -- confirm against the tokenizer's truncation side.\n",
    "MAX_LEN = 768\n",
    "BUDGET = {\n",
    "    \"posts_cleaned\": 384,\n",
    "    \"semantic_view\": 128,\n",
    "    \"sentiment_view\": 128,\n",
    "    \"linguistic_view\": 128,\n",
    "}\n",
    "\n",
    "SEED         = 42\n",
    "EPOCHS       = 4\n",
    "LR           = 2e-4\n",
    "BSZ_TRN      = 4           # lower this further if GPU memory is tight\n",
    "BSZ_EVAL     = 8\n",
    "GRAD_ACCUM   = 2\n",
    "WARMUP_RATIO = 0.06\n",
    "WEIGHT_DECAY = 0.01\n",
    "OUTPUT_DIR   = \"mbti_lora_qwen2.5-1.5b_pandora_new_ckpt\"   # renamed as requested\n",
    "\n",
    "# ---------------- MBTI utilities ----------------\n",
    "# The class id of a type is its index in this list (see MBTI2ID).\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\n",
    "    \"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\n",
    "    \"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\",\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    \"\"\"Encode a four-letter MBTI code as a tuple of four 0/1 flags.\n",
    "\n",
    "    Position by position, the letters I, S, F and P map to 0 and any other\n",
    "    character maps to 1. The comparison is case-sensitive.\n",
    "    \"\"\"\n",
    "    zero_letters = \"ISFP\"\n",
    "    return tuple(int(m[pos] != letter) for pos, letter in enumerate(zero_letters))\n",
    "\n",
    "def truncate_to_budget(tok: \"AutoTokenizer\", text: str, budget: int) -> str:\n",
    "    \"\"\"Return `text` cut down to at most `budget` tokens.\n",
    "\n",
    "    The text (an empty string when falsy) is tokenized without special tokens,\n",
    "    the first `budget` ids are kept, and those ids are decoded back to a string.\n",
    "    \"\"\"\n",
    "    source = text if text else \"\"\n",
    "    token_ids = tok(source, add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: \"AutoTokenizer\") -> str:\n",
    "    \"\"\"Assemble the classification prompt for one sample.\n",
    "\n",
    "    Each field is truncated to its BUDGET entry and placed under its own\n",
    "    section marker; the prompt ends with a [TASK] line listing all MBTI types.\n",
    "    The posts section uses `posts_cleaned` and falls back to `posts` only when\n",
    "    the `posts_cleaned` key is absent.\n",
    "    \"\"\"\n",
    "    def clip(raw, field):\n",
    "        # Missing / None values are treated as empty text.\n",
    "        return truncate_to_budget(tok, raw or \"\", BUDGET[field])\n",
    "\n",
    "    posts_text = clip(item.get(\"posts_cleaned\", item.get(\"posts\",\"\")), \"posts_cleaned\")\n",
    "    semantic_text = clip(item.get(\"semantic_view\",\"\"), \"semantic_view\")\n",
    "    sentiment_text = clip(item.get(\"sentiment_view\",\"\"), \"sentiment_view\")\n",
    "    linguistic_text = clip(item.get(\"linguistic_view\",\"\"), \"linguistic_view\")\n",
    "\n",
    "    prompt_lines = [\n",
    "        \"[POSTS]\", posts_text,\n",
    "        \"[SEMANTIC]\", semantic_text,\n",
    "        \"[SENTIMENT]\", sentiment_text,\n",
    "        \"[LINGUISTIC]\", linguistic_text,\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\",\n",
    "    ]\n",
    "    return \"\\n\".join(prompt_lines)\n",
    "\n",
    "def load_rows(path: str) -> List[Dict[str, Any]]:\n",
    "    \"\"\"Load the JSON sample file and keep only rows with a valid MBTI label.\n",
    "\n",
    "    Args:\n",
    "        path: path to a UTF-8 JSON file holding a list of sample objects.\n",
    "\n",
    "    Returns:\n",
    "        The rows that are dicts and whose \"type\" value is a key of MBTI2ID.\n",
    "        Entries that are not dicts are skipped instead of raising AttributeError.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    return [r for r in rows if isinstance(r, dict) and r.get(\"type\") in MBTI2ID]\n",
    "\n",
    "# ---------------- Dataset ----------------\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset that builds and tokenizes one prompt per lookup.\n",
    "\n",
    "    Every item is a dict with `input_ids`, `attention_mask` and an integer\n",
    "    `labels` entry (the row's \"type\" looked up in MBTI2ID). No padding is\n",
    "    applied here.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows, self.tok, self.max_len = rows, tokenizer, max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.rows[idx]\n",
    "        prompt = build_input(row, self.tok)\n",
    "        label = MBTI2ID[row[\"type\"]]\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        return {\n",
    "            \"input_ids\": encoded[\"input_ids\"],\n",
    "            \"attention_mask\": encoded[\"attention_mask\"],\n",
    "            \"labels\": label,\n",
    "        }\n",
    "\n",
    "# ---------------- 指标（沿用旧逻辑） ----------------\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute 16-way accuracy plus per-axis and strict four-axis accuracy.\n",
    "\n",
    "    Accepts either an object exposing `predictions` / `label_ids` or a\n",
    "    `(predictions, labels)` pair; a tuple of predictions is reduced to its\n",
    "    first element and torch tensors are converted to numpy arrays.\n",
    "\n",
    "    Returns a dict with `acc_16`, `acc_ei`, `acc_ns`, `acc_tf`, `acc_jp` and\n",
    "    `acc_4D`. Because the four flags identify each of the 16 types uniquely,\n",
    "    `acc_4D` equals `acc_16`.\n",
    "    \"\"\"\n",
    "    if hasattr(eval_pred, \"predictions\"):\n",
    "        logits, labels = eval_pred.predictions, eval_pred.label_ids\n",
    "    else:\n",
    "        logits, labels = eval_pred\n",
    "    if isinstance(logits, tuple):\n",
    "        logits = logits[0]\n",
    "    if isinstance(logits, torch.Tensor):\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "    if isinstance(labels, torch.Tensor):\n",
    "        labels = labels.detach().cpu().numpy()\n",
    "\n",
    "    pred_ids = logits.argmax(-1)\n",
    "    acc16 = float((pred_ids == labels).mean())\n",
    "\n",
    "    axis_hits = [0, 0, 0, 0]   # correct counts for E/I, N/S, T/F, J/P\n",
    "    exact_hits = 0             # samples with all four axes correct\n",
    "    for pred_id, true_id in zip(pred_ids, labels):\n",
    "        pred_bits = mbti_to_4d(MBTI_16[pred_id])\n",
    "        true_bits = mbti_to_4d(MBTI_16[true_id])\n",
    "        for axis in range(4):\n",
    "            axis_hits[axis] += (pred_bits[axis] == true_bits[axis])\n",
    "        exact_hits += (pred_bits == true_bits)\n",
    "\n",
    "    total = len(labels)\n",
    "    return {\n",
    "        \"acc_16\": acc16,\n",
    "        \"acc_ei\": axis_hits[0]/total,\n",
    "        \"acc_ns\": axis_hits[1]/total,\n",
    "        \"acc_tf\": axis_hits[2]/total,\n",
    "        \"acc_jp\": axis_hits[3]/total,\n",
    "        \"acc_4D\": exact_hits/total,\n",
    "    }\n",
    "\n",
    "# ---------------- 绘图函数（保持你要的格式/文件名） ----------------\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir):\n",
    "    \"\"\"Save a confusion-matrix figure and a micro/macro-average ROC figure.\n",
    "\n",
    "    Args:\n",
    "        y_true: true class ids, one per sample.\n",
    "        y_prob: per-class scores with one column per entry of `class_names`;\n",
    "            the predicted class is the argmax over the last axis.\n",
    "        class_names: display names; class id i is shown as class_names[i].\n",
    "        out_dir: output directory, created when missing.\n",
    "\n",
    "    Writes `confusion_matrix.png` and, when at least two classes have positive\n",
    "    samples in `y_true`, `roc_micro_macro.png`. Classes without any positive\n",
    "    sample are left out of the ROC averages because their true-positive rate\n",
    "    is undefined and would turn the averages into NaN.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_prob = np.asarray(y_prob)\n",
    "    class_ids = list(range(len(class_names)))\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    # Confusion matrix\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=class_ids)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(\"Confusion Matrix\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, \"confusion_matrix.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ROC (micro / macro), restricted to classes that occur in y_true\n",
    "    Y_true_bin = label_binarize(y_true, classes=class_ids)\n",
    "    fpr = {}; tpr = {}; roc_auc = {}\n",
    "    valid = []\n",
    "    for i in class_ids:\n",
    "        if Y_true_bin[:, i].sum() == 0:\n",
    "            continue\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "        valid.append(i)\n",
    "\n",
    "    if len(valid) < 2:\n",
    "        print(\"Skipping ROC plot: fewer than two classes have positive samples.\")\n",
    "        return\n",
    "\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(\n",
    "        Y_true_bin[:, valid].ravel(), y_prob[:, valid].ravel()\n",
    "    )\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    # Macro average: interpolate every per-class curve onto the union of FPR points.\n",
    "    all_fpr = np.unique(np.concatenate([fpr[i] for i in valid]))\n",
    "    mean_tpr = np.zeros_like(all_fpr)\n",
    "    for i in valid:\n",
    "        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "    mean_tpr /= len(valid)\n",
    "    fpr[\"macro\"] = all_fpr\n",
    "    tpr[\"macro\"] = mean_tpr\n",
    "    roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(\"Multiclass ROC (micro & macro)\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, \"roc_micro_macro.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ---------------- 训练主流程 ----------------\n",
    "def main():\n",
    "    \"\"\"Fine-tune the 16-way MBTI classifier with LoRA, then evaluate and save it.\n",
    "\n",
    "    Steps: load the rows and make a stratified 90/10 train/validation split,\n",
    "    build the tokenizer and the un-quantized classification model, wrap it with\n",
    "    LoRA, train with per-epoch evaluation and checkpointing, write the\n",
    "    confusion-matrix and ROC figures for the validation split, print the final\n",
    "    validation metrics, save the adapter to OUTPUT_DIR and run one inference demo.\n",
    "\n",
    "    NOTE(review): the validation split is used both to pick the best checkpoint\n",
    "    and for the reported final metrics; no separate test split is evaluated here.\n",
    "    \"\"\"\n",
    "    import inspect  # local import: only needed to pick the Trainer keyword below\n",
    "\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    rows = load_rows(FILE_PATH)\n",
    "    train_rows, val_rows = train_test_split(\n",
    "        rows, test_size=0.1, random_state=SEED, stratify=[r[\"type\"] for r in rows]\n",
    "    )\n",
    "\n",
    "    # Tokenizer; reuse EOS as the pad token when the checkpoint defines none.\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, trust_remote_code=True)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # Un-quantized base model with a classification head.\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        MODEL_NAME,\n",
    "        num_labels=16,\n",
    "        torch_dtype=(torch.bfloat16 if torch.cuda.is_available() else None),\n",
    "        device_map=\"auto\",\n",
    "        trust_remote_code=True,\n",
    "    )\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.config.use_cache = False\n",
    "    # Resize only when the tokenizer has more entries than the embedding matrix.\n",
    "    # An unconditional resize marks the embeddings as modified, which makes PEFT\n",
    "    # store the full embedding layer in every saved adapter checkpoint.\n",
    "    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:\n",
    "        model.resize_token_embeddings(len(tokenizer))\n",
    "    if hasattr(model, \"gradient_checkpointing_enable\"):\n",
    "        model.gradient_checkpointing_enable()\n",
    "\n",
    "    # LoRA only (no k-bit preparation).\n",
    "    from peft import LoraConfig, TaskType, get_peft_model\n",
    "    peft_cfg = LoraConfig(\n",
    "        task_type=TaskType.SEQ_CLS,\n",
    "        r=16, lora_alpha=32, lora_dropout=0.05,\n",
    "        target_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"],\n",
    "        bias=\"none\"\n",
    "    )\n",
    "    model = get_peft_model(model, peft_cfg)\n",
    "    model.print_trainable_parameters()\n",
    "\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # Recent transformers releases name this argument `eval_strategy`.\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BSZ_TRN,\n",
    "        per_device_eval_batch_size=BSZ_EVAL,\n",
    "        gradient_accumulation_steps=GRAD_ACCUM,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        learning_rate=LR,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        logging_steps=50,\n",
    "        save_total_limit=2,\n",
    "        report_to=\"none\",\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_4D\",\n",
    "        greater_is_better=True,\n",
    "        eval_strategy=\"epoch\",\n",
    "        save_strategy=\"epoch\",\n",
    "        bf16=(torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8),\n",
    "        fp16=False,\n",
    "    )\n",
    "\n",
    "    # `tokenizer=` is deprecated in Trainer.__init__ in favour of `processing_class=`;\n",
    "    # pass whichever keyword the installed transformers version accepts.\n",
    "    trainer_params = inspect.signature(Trainer.__init__).parameters\n",
    "    tokenizer_kw = \"processing_class\" if \"processing_class\" in trainer_params else \"tokenizer\"\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=val_ds,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "        **{tokenizer_kw: tokenizer},\n",
    "    )\n",
    "\n",
    "    trainer.train()\n",
    "\n",
    "    # ====== Build and save the figures from validation-set predictions ======\n",
    "    pred_output = trainer.predict(val_ds)\n",
    "    logits = pred_output.predictions\n",
    "    if isinstance(logits, (list, tuple)):\n",
    "        logits = logits[0]\n",
    "    # Softmax in float32 turns the logits into per-class probabilities.\n",
    "    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = pred_output.label_ids\n",
    "\n",
    "    plot_confusion_and_roc(\n",
    "        y_true=y_true,\n",
    "        y_prob=probs,\n",
    "        class_names=MBTI_16,\n",
    "        out_dir=OUTPUT_DIR\n",
    "    )\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro.png')}\")\n",
    "\n",
    "    eval_metrics = trainer.evaluate()\n",
    "    print(\"\\n=== Final Eval ===\")\n",
    "    for k, v in eval_metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except Exception:\n",
    "            # Non-numeric metric values are printed as-is.\n",
    "            print(k, v)\n",
    "\n",
    "    trainer.save_model(OUTPUT_DIR)\n",
    "    print(f\"\\n✅ LoRA adapter saved to: {OUTPUT_DIR}\")\n",
    "\n",
    "    # Minimal inference example on the first validation row.\n",
    "    model.eval()\n",
    "    sample = val_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(model.device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        logits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(logits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n原标签:\", sample[\"type\"], \"| 预测:\", pred_mbti)\n",
    "\n",
    "# In a notebook kernel __name__ is \"__main__\", so executing this cell starts the full training run.\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "77939842",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "base_model_name_or_path = Qwen/Qwen2.5-1.5B-Instruct\n",
      "task_type = SEQ_CLS\n"
     ]
    }
   ],
   "source": [
    "import json, os\n",
    "ADAPTER_DIR = \"mbti_lora_qwen1.5b_ckpt\"   # LoRA checkpoint directory\n",
    "# Read the adapter config with a context manager so the file handle is closed\n",
    "# deterministically, and decode it explicitly as UTF-8.\n",
    "with open(os.path.join(ADAPTER_DIR, \"adapter_config.json\"), \"r\", encoding=\"utf-8\") as f:\n",
    "    cfg = json.load(f)\n",
    "# Show which base model and task type the adapter was trained for.\n",
    "print(\"base_model_name_or_path =\", cfg.get(\"base_model_name_or_path\"))\n",
    "print(\"task_type =\", cfg.get(\"task_type\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6b04b6fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hidden_size = 1536\n",
      "score.weight shape = (16, 1536)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2833968/3076674525.py:248: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6000' max='6000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6000/6000 13:27]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== Pandora Eval ===\n",
      "eval_loss: 3.4766\n",
      "eval_model_preparation_time: 0.0030\n",
      "eval_acc_16: 0.1066\n",
      "eval_acc_ei: 0.5264\n",
      "eval_acc_ns: 0.5206\n",
      "eval_acc_tf: 0.6357\n",
      "eval_acc_jp: 0.5229\n",
      "eval_acc_4D: 0.1066\n",
      "eval_runtime: 807.8859\n",
      "eval_samples_per_second: 59.4140\n",
      "eval_steps_per_second: 7.4270\n",
      "样例原标签: INTP | 预测: INFJ\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "评测脚本（最终版，无 AutoPeft）\n",
    "- 基座: Qwen/Qwen2.5-1.5B-Instruct\n",
    "- LoRA: mbti_lora_qwen1.5b_ckpt\n",
    "- 数据: mbti_sample_with_all_views_pandora.json\n",
    "- 输出: 指标 + 混淆矩阵 + ROC(micro/macro)\n",
    "- 纯评测（不训练），GPU-only，4bit 量化\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoConfig,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "from peft import PeftModel  # 只用 PeftModel，不用 AutoPeft\n",
    "\n",
    "# ================== Configuration ==================\n",
    "BASE_MODEL   = \"Qwen/Qwen2.5-1.5B-Instruct\"                # must match base_model_name_or_path in adapter_config.json\n",
    "ADAPTER_DIR  = \"mbti_lora_qwen1.5b_ckpt\"                   # LoRA adapter directory\n",
    "FILE_PATH    = \"mbti_sample_with_all_views_pandora.json\"   # new dataset (Pandora)\n",
    "OUTPUT_DIR   = \"eval_on_pandora_outputs\"                   # where evaluation artifacts are written\n",
    "\n",
    "MAX_LEN      = 320\n",
    "USE_4BIT     = True\n",
    "SEED         = 42\n",
    "NUM_LABELS   = 16\n",
    "\n",
    "# The class id of a type is its index in this list (see MBTI2ID).\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "# Per-field token budgets for the multi-view prompt.\n",
    "# NOTE(review): intended to match the budgets used when the adapter was trained -- confirm\n",
    "# against the training run that produced ADAPTER_DIR.\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "# Optional Hugging Face access token taken from the environment; omitted from the kwargs when unset.\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# ================== 工具函数 ==================\n",
    "def mbti_to_4d(m: str):\n",
    "    \"\"\"Encode a four-letter MBTI code as a tuple of four 0/1 flags.\n",
    "\n",
    "    The code is upper-cased first. Position by position, the letters I, S, F\n",
    "    and P map to 0 and any other character maps to 1.\n",
    "    \"\"\"\n",
    "    code = m.upper()\n",
    "    flags = []\n",
    "    for pos, zero_letter in enumerate(\"ISFP\"):\n",
    "        flags.append(0 if code[pos] == zero_letter else 1)\n",
    "    return tuple(flags)\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Return `text` cut down to at most `budget` tokens.\n",
    "\n",
    "    The text (an empty string when falsy) is tokenized without special tokens,\n",
    "    the first `budget` ids are kept, and those ids are decoded back to a string.\n",
    "    \"\"\"\n",
    "    source = text if text else \"\"\n",
    "    kept_ids = tok(source, add_special_tokens=False)[\"input_ids\"][:budget]\n",
    "    return tok.decode(kept_ids)\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Assemble the classification prompt for one sample.\n",
    "\n",
    "    Rows that share the original dataset's fields are used as-is. For minimal\n",
    "    rows of the form {\"text\": \"...\", \"label\": \"...\"}, the `text` field is used\n",
    "    when `posts_cleaned` / `posts` yield nothing. Each field is truncated to\n",
    "    its BUDGET entry before being placed under its section marker.\n",
    "    \"\"\"\n",
    "    posts_raw = item.get(\"posts_cleaned\", item.get(\"posts\",\"\")) or item.get(\"text\",\"\") or \"\"\n",
    "    raw_by_field = [\n",
    "        (\"posts_cleaned\", posts_raw),\n",
    "        (\"semantic_view\", item.get(\"semantic_view\",\"\") or \"\"),\n",
    "        (\"sentiment_view\", item.get(\"sentiment_view\",\"\") or \"\"),\n",
    "        (\"linguistic_view\", item.get(\"linguistic_view\",\"\") or \"\"),\n",
    "    ]\n",
    "    # Truncate in the same field order the prompt uses.\n",
    "    posts_text, semantic_text, sentiment_text, linguistic_text = [\n",
    "        truncate_to_budget(tok, raw, BUDGET[field]) for field, raw in raw_by_field\n",
    "    ]\n",
    "\n",
    "    prompt_lines = [\n",
    "        \"[POSTS]\", posts_text,\n",
    "        \"[SEMANTIC]\", semantic_text,\n",
    "        \"[SENTIMENT]\", sentiment_text,\n",
    "        \"[LINGUISTIC]\", linguistic_text,\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\",\n",
    "    ]\n",
    "    return \"\\n\".join(prompt_lines)\n",
    "\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load the JSON sample file and keep rows carrying a valid MBTI label.\n",
    "\n",
    "    The label is read from \"type\", falling back to \"label\" when \"type\" is\n",
    "    missing or falsy; it is upper-cased and stripped, and rows whose label is\n",
    "    a key of MBTI2ID are kept with the normalized value written back to\n",
    "    row[\"type\"]. Entries that are not dicts, or whose label is not a string,\n",
    "    are skipped instead of raising AttributeError.\n",
    "\n",
    "    Args:\n",
    "        path: path to a UTF-8 JSON file holding a list of sample objects.\n",
    "\n",
    "    Returns:\n",
    "        The list of accepted row dicts.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    clean = []\n",
    "    for r in rows:\n",
    "        if not isinstance(r, dict):\n",
    "            continue\n",
    "        raw_label = r.get(\"type\") or r.get(\"label\") or \"\"\n",
    "        if not isinstance(raw_label, str):\n",
    "            continue\n",
    "        t = raw_label.upper().strip()\n",
    "        if t in MBTI2ID:\n",
    "            r[\"type\"] = t\n",
    "            clean.append(r)\n",
    "    return clean\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset that builds and tokenizes one prompt per lookup.\n",
    "\n",
    "    Every item is a dict with `input_ids`, `attention_mask` and an integer\n",
    "    `labels` entry (the row's \"type\" looked up in MBTI2ID). No padding is\n",
    "    applied here.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok = tokenizer\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.rows[idx]\n",
    "        prompt = build_input(sample, self.tok)\n",
    "        target = MBTI2ID[sample[\"type\"]]\n",
    "        features = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        item = {name: features[name] for name in (\"input_ids\", \"attention_mask\")}\n",
    "        item[\"labels\"] = target\n",
    "        return item\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute 16-way accuracy plus per-axis and strict four-axis accuracy.\n",
    "\n",
    "    Accepts a `(predictions, labels)` tuple or an object exposing\n",
    "    `predictions` / `label_ids`. A list or tuple of predictions is reduced to\n",
    "    its first element before the argmax over the last axis.\n",
    "\n",
    "    Returns a dict with `acc_16`, `acc_ei`, `acc_ns`, `acc_tf`, `acc_jp` and\n",
    "    `acc_4D`. Because the four flags identify each of the 16 types uniquely,\n",
    "    `acc_4D` equals `acc_16`.\n",
    "    \"\"\"\n",
    "    if isinstance(eval_pred, tuple):\n",
    "        preds, labels = eval_pred\n",
    "    else:\n",
    "        preds, labels = eval_pred.predictions, eval_pred.label_ids\n",
    "    if isinstance(preds, (list, tuple)):\n",
    "        preds = preds[0]\n",
    "    preds = np.asarray(preds)\n",
    "    labels = np.asarray(labels)\n",
    "    pred_ids = preds.argmax(-1)\n",
    "    acc16 = float((pred_ids == labels).mean())\n",
    "\n",
    "    axis_keys = (\"acc_ei\", \"acc_ns\", \"acc_tf\", \"acc_jp\")\n",
    "    counts = {\"acc_ei\": 0, \"acc_ns\": 0, \"acc_tf\": 0, \"acc_jp\": 0, \"acc_4D\": 0}\n",
    "    for pred_id, true_id in zip(pred_ids, labels):\n",
    "        pred_bits = mbti_to_4d(MBTI_16[pred_id])\n",
    "        true_bits = mbti_to_4d(MBTI_16[true_id])\n",
    "        for key, p_bit, t_bit in zip(axis_keys, pred_bits, true_bits):\n",
    "            counts[key] += (p_bit == t_bit)\n",
    "        counts[\"acc_4D\"] += (pred_bits == true_bits)\n",
    "\n",
    "    n = len(labels)\n",
    "    metrics = {\"acc_16\": acc16}\n",
    "    for key, hits in counts.items():\n",
    "        metrics[key] = hits/n\n",
    "    return metrics\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir):\n",
    "    \"\"\"Save the Pandora confusion-matrix figure and, when possible, the ROC figure.\n",
    "\n",
    "    Args:\n",
    "        y_true: true class ids, one per sample.\n",
    "        y_prob: per-class scores with one column per entry of `class_names`;\n",
    "            the predicted class is the argmax over the last axis.\n",
    "        class_names: display names; class id i is shown as class_names[i].\n",
    "        out_dir: output directory, created when missing.\n",
    "\n",
    "    Always writes `pandora_confusion_matrix.png`. Writes\n",
    "    `pandora_roc_micro_macro.png` only when at least two classes have positive\n",
    "    samples in `y_true`; otherwise the ROC figure is skipped without a message.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "    # Fixing `labels` keeps the matrix at len(class_names) x len(class_names) even if some classes are absent.\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(\"Confusion Matrix (Pandora)\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, \"pandora_confusion_matrix.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ROC: skip classes that have no positive samples in the evaluation set\n",
    "    Y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    valid = []  # class ids that contributed a per-class ROC curve\n",
    "    for i in range(len(class_names)):\n",
    "        if Y_true_bin[:, i].sum() == 0:\n",
    "            continue\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "        valid.append(i)\n",
    "    if len(valid) >= 2:\n",
    "        # Micro average: pool the (label, score) pairs of all valid classes into one binary problem.\n",
    "        fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(\n",
    "            Y_true_bin[:, valid].ravel(), y_prob[:, valid].ravel()\n",
    "        )\n",
    "        roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "        # Macro average: interpolate each per-class curve onto the union of FPR points, then average.\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in valid]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in valid:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(valid)\n",
    "        fpr[\"macro\"] = all_fpr; tpr[\"macro\"] = mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "        fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "        ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                    label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                    label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "        # Dashed diagonal drawn as a reference line.\n",
    "        ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "        ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "        ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "        ax_roc.set_title(\"Multiclass ROC (Pandora)\")\n",
    "        ax_roc.legend(loc=\"lower right\")\n",
    "        fig_roc.tight_layout()\n",
    "        fig_roc.savefig(os.path.join(out_dir, \"pandora_roc_micro_macro.png\"))\n",
    "        plt.close(fig_roc)\n",
    "\n",
    "# ================== 主流程 ==================\n",
    "def main():\n",
    "    \"\"\"Evaluate a saved LoRA adapter on the samples in FILE_PATH.\n",
    "\n",
    "    Loads the tokenizer and the (optionally 4-bit) base classifier, attaches\n",
    "    the adapter from ADAPTER_DIR with is_trainable=False, runs one prediction\n",
    "    pass, writes the confusion-matrix / ROC figures to OUTPUT_DIR, prints the\n",
    "    metrics and finally classifies the first sample as a smoke test.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if load_rows(FILE_PATH) yields no samples.\n",
    "    \"\"\"\n",
    "    # GPU-only environment\n",
    "    os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "    os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"  # mind the spelling\n",
    "    torch.cuda.set_device(0)\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    # Tokenizer must match the base model\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\n",
    "        BASE_MODEL, use_fast=True, trust_remote_code=True, **HF_KW\n",
    "    )\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # Quantization config (None when USE_4BIT is off)\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,  # older GPUs may need torch.float16\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    # Key step: build the base classifier with num_labels=NUM_LABELS first,\n",
    "    # then attach the LoRA adapter on top of it.\n",
    "    base_cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True, **HF_KW)\n",
    "    base_cfg.num_labels = NUM_LABELS\n",
    "\n",
    "    base = AutoModelForSequenceClassification.from_pretrained(\n",
    "        BASE_MODEL,\n",
    "        config=base_cfg,                         # config carrying num_labels\n",
    "        device_map={\"\": \"cuda:0\"},\n",
    "        quantization_config=quant_cfg,\n",
    "        trust_remote_code=True,\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW,\n",
    "    )\n",
    "\n",
    "    model = PeftModel.from_pretrained(base, ADAPTER_DIR, is_trainable=False)\n",
    "    model.config.use_cache = False\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "\n",
    "    # Optional shape check: best-effort diagnostics, never fatal\n",
    "    try:\n",
    "        print(\"hidden_size =\", model.base_model.model.config.hidden_size)\n",
    "        print(\"score.weight shape =\", tuple(model.base_model.model.score.weight.shape))\n",
    "    except Exception as e:\n",
    "        print(\"shape check skipped:\", e)\n",
    "\n",
    "    # Data\n",
    "    rows = load_rows(FILE_PATH)\n",
    "    if not rows:\n",
    "        # Fail early with a clear message instead of an IndexError further down.\n",
    "        raise ValueError(f\"No valid MBTI samples found in {FILE_PATH!r}\")\n",
    "    eval_ds = MBTIDataset(rows, tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # Evaluation only\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_eval_batch_size=8,\n",
    "        dataloader_drop_last=False,\n",
    "        report_to=\"none\",\n",
    "        fp16=False, bf16=False,\n",
    "    )\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        eval_dataset=eval_ds,\n",
    "        tokenizer=tokenizer,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    # Predict -> probabilities -> figures.\n",
    "    # One inference pass is enough: predict() returns the logits and, since\n",
    "    # compute_metrics is set, the metrics too. The \"eval\" prefix keeps the key\n",
    "    # names that a separate trainer.evaluate() call would report.\n",
    "    pred_output = trainer.predict(eval_ds, metric_key_prefix=\"eval\")\n",
    "    logits = pred_output.predictions\n",
    "    if isinstance(logits, (list, tuple)):\n",
    "        logits = logits[0]\n",
    "    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = pred_output.label_ids\n",
    "\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR)\n",
    "\n",
    "    # Metrics (per-axis and overall)\n",
    "    metrics = pred_output.metrics or {}\n",
    "    print(\"\\n=== Pandora Eval ===\")\n",
    "    for k, v in metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except Exception:\n",
    "            print(k, v)\n",
    "\n",
    "    # Quick single-sample check\n",
    "    model.eval()\n",
    "    sample = rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(\"cuda:0\") for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        out = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(out, dim=-1))\n",
    "        print(\"样例原标签:\", sample[\"type\"], \"| 预测:\", MBTI_16[pred_id])\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "deef70fd",
   "metadata": {},
   "source": [
    "# 融合训练，单数据集评测（Kaggle / Pandora）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e48b1dd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 18,489,344 || all params: 1,562,228,224 || trainable%: 1.1835\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3024660/529486632.py:326: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n",
      "/home/hli962/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:838: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='501' max='30762' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [  501/30762 04:25 < 4:28:07, 1.88 it/s, Epoch 0.05/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4615' max='4800' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4615/4800 12:30 < 00:30, 6.15 it/s]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "训练 + 评测（LoRA / 4bit / 单卡）\n",
    "- 基座: Qwen/Qwen2.5-1.5B-Instruct\n",
    "- 训练集: 两数据集合并 (mbti_sample_with_all_views.json + mbti_sample_with_all_views_pandora.json)\n",
    "- Eval/Test: 只在指定的数据集上评测 (默认 Pandora)\n",
    "- 输出: 指标 + 混淆矩阵 + ROC(micro/macro) + LoRA 适配器权重\n",
    "\"\"\"\n",
    "import os, json, random\n",
    "from typing import Dict, Any, List\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoConfig,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "from peft import (\n",
    "    LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel\n",
    ")\n",
    "\n",
    "# ================== Configuration ==================\n",
    "BASE_MODEL   = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "DATA_A       = \"mbti_sample_with_all_views.json\"          # older dataset\n",
    "DATA_B       = \"mbti_sample_with_all_views_pandora.json\"  # Pandora\n",
    "EVAL_ON      = \"B\"  # which dataset eval/test runs on: \"A\" or \"B\"\n",
    "OUTPUT_DIR   = \"qwen-test-on-pandora\"                      # output directory (also holds the LoRA adapter)\n",
    "RESUME_ADAPTER_DIR = None  # directory of an existing LoRA checkpoint to resume from; otherwise None\n",
    "\n",
    "MAX_LEN      = 320\n",
    "USE_4BIT     = True\n",
    "SEED         = 42\n",
    "NUM_LABELS   = 16\n",
    "\n",
    "# LoRA hyper-parameters (tune as needed)\n",
    "LORA_R       = 16\n",
    "LORA_ALPHA   = 32\n",
    "LORA_DROPOUT = 0.05\n",
    "# Commonly used target modules for Qwen2.5\n",
    "LORA_TARGET_MODULES = [\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"]\n",
    "\n",
    "# Training hyper-parameters (adjust to the available GPU memory)\n",
    "BATCH_SIZE_PER_DEVICE_TRAIN = 8\n",
    "BATCH_SIZE_PER_DEVICE_EVAL  = 8\n",
    "GR_ACCUM_STEPS              = 1\n",
    "EPOCHS                      = 3\n",
    "LR                          = 2e-4\n",
    "WARMUP_RATIO                = 0.05\n",
    "LOGGING_STEPS               = 20\n",
    "SAVE_STEPS                  = 500\n",
    "EVAL_STEPS                  = 500\n",
    "\n",
    "# The 16 class names; a name's list position is its integer class id (see MBTI2ID)\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "# Per-view token budgets used by build_input (kept consistent with training)\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "# Optional Hugging Face token read from the environment; HF_KW stays empty when unset\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# ================== 工具函数 ==================\n",
    "def mbti_to_4d(m: str):\n",
    "    \"\"\"Convert a 4-letter MBTI code into a tuple of four 0/1 axis flags.\n",
    "\n",
    "    The code is upper-cased first. Position k yields 0 when its letter is\n",
    "    \"I\", \"S\", \"F\", \"P\" (for k = 0..3 respectively) and 1 for any other letter.\n",
    "    \"\"\"\n",
    "    code = m.upper()\n",
    "    zero_letters = (\"I\", \"S\", \"F\", \"P\")\n",
    "    return tuple(int(code[pos] != zero) for pos, zero in enumerate(zero_letters))\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Clip text to at most `budget` tokens and decode the kept ids back to a string.\n",
    "\n",
    "    Tokenization is done without special tokens; a falsy text is treated as \"\".\n",
    "    \"\"\"\n",
    "    token_ids = tok(text or \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Assemble the multi-view prompt for one sample.\n",
    "\n",
    "    The post text comes from \"posts_cleaned\" (falling back to \"posts\", then\n",
    "    \"text\"); each view is clipped to its BUDGET entry and emitted under its\n",
    "    own bracketed header, followed by the task instruction line.\n",
    "    \"\"\"\n",
    "    raw_views = {\n",
    "        \"posts_cleaned\": item.get(\"posts_cleaned\", item.get(\"posts\",\"\")) or item.get(\"text\",\"\") or \"\",\n",
    "        \"semantic_view\": item.get(\"semantic_view\",\"\") or \"\",\n",
    "        \"sentiment_view\": item.get(\"sentiment_view\",\"\") or \"\",\n",
    "        \"linguistic_view\": item.get(\"linguistic_view\",\"\") or \"\",\n",
    "    }\n",
    "    # dicts keep insertion order, so the views are clipped posts-first\n",
    "    clipped = {\n",
    "        view: truncate_to_budget(tok, content, BUDGET[view])\n",
    "        for view, content in raw_views.items()\n",
    "    }\n",
    "    layout = (\n",
    "        (\"[POSTS]\", \"posts_cleaned\"),\n",
    "        (\"[SEMANTIC]\", \"semantic_view\"),\n",
    "        (\"[SENTIMENT]\", \"sentiment_view\"),\n",
    "        (\"[LINGUISTIC]\", \"linguistic_view\"),\n",
    "    )\n",
    "    sections = \"\".join(f\"{header}\\n{clipped[view]}\\n\" for header, view in layout)\n",
    "    return sections + f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load samples from a JSON file, keeping those with a valid MBTI label.\n",
    "\n",
    "    The label is read from \"type\" (falling back to \"label\"), upper-cased and\n",
    "    stripped; kept rows get the normalized value written back to row[\"type\"].\n",
    "    Entries that are not dicts, or whose label is not a string, are skipped\n",
    "    rather than raising.\n",
    "\n",
    "    Args:\n",
    "        path: UTF-8 JSON file, expected to contain a list of sample dicts.\n",
    "\n",
    "    Returns:\n",
    "        The retained row dicts, in file order.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    clean = []\n",
    "    for r in rows:\n",
    "        if not isinstance(r, dict):\n",
    "            continue  # malformed entry: nothing to read a label from\n",
    "        raw_label = r.get(\"type\") or r.get(\"label\") or \"\"\n",
    "        if not isinstance(raw_label, str):\n",
    "            continue  # e.g. numeric label: cannot be an MBTI code\n",
    "        t = raw_label.upper().strip()\n",
    "        if t in MBTI2ID:\n",
    "            r[\"type\"] = t\n",
    "            clean.append(r)\n",
    "    return clean\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Dataset that turns sample rows into tokenized classification examples.\n",
    "\n",
    "    Each item is a dict with \"input_ids\", \"attention_mask\" and the integer\n",
    "    class id under \"labels\"; no padding is applied here.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows, self.tok, self.max_len = rows, tokenizer, max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.rows[idx]\n",
    "        prompt = build_input(row, self.tok)\n",
    "        label_id = MBTI2ID[row[\"type\"]]\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        return {\n",
    "            \"input_ids\": encoded[\"input_ids\"],\n",
    "            \"attention_mask\": encoded[\"attention_mask\"],\n",
    "            \"labels\": label_id,\n",
    "        }\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute 16-way accuracy plus per-axis MBTI accuracies for the Trainer.\n",
    "\n",
    "    Accepts either a (predictions, label_ids) tuple or an object exposing\n",
    "    .predictions / .label_ids. When predictions is a list/tuple, its first\n",
    "    element is used as the logits.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys acc_16, acc_ei, acc_ns, acc_tf, acc_jp and acc_4D\n",
    "        (acc_4D counts samples on which all four axes match).\n",
    "    \"\"\"\n",
    "    if isinstance(eval_pred, tuple):\n",
    "        preds, labels = eval_pred\n",
    "    else:\n",
    "        preds, labels = eval_pred.predictions, eval_pred.label_ids\n",
    "    if isinstance(preds, (list, tuple)):\n",
    "        preds = preds[0]\n",
    "    preds = np.asarray(preds)\n",
    "    labels = np.asarray(labels)\n",
    "    pred_ids = preds.argmax(-1)\n",
    "    acc16 = float((pred_ids == labels).mean())\n",
    "\n",
    "    pred_axes = [mbti_to_4d(MBTI_16[i]) for i in pred_ids]\n",
    "    true_axes = [mbti_to_4d(MBTI_16[i]) for i in labels]\n",
    "    axis_hits = [0, 0, 0, 0]  # one counter per tuple position of mbti_to_4d\n",
    "    exact_hits = 0\n",
    "    for got, want in zip(pred_axes, true_axes):\n",
    "        agree = [g == w for g, w in zip(got, want)]\n",
    "        for axis, ok in enumerate(agree):\n",
    "            axis_hits[axis] += ok\n",
    "        exact_hits += all(agree)\n",
    "\n",
    "    total = len(labels)\n",
    "    axis_keys = (\"acc_ei\", \"acc_ns\", \"acc_tf\", \"acc_jp\")\n",
    "    result = {\"acc_16\": acc16}\n",
    "    result.update({key: hits / total for key, hits in zip(axis_keys, axis_hits)})\n",
    "    result[\"acc_4D\"] = exact_hits / total\n",
    "    return result\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, tag=\"eval\"):\n",
    "    \"\"\"Save a confusion-matrix PNG and a micro/macro-averaged ROC PNG.\n",
    "\n",
    "    Args:\n",
    "        y_true: integer class ids of the evaluated samples.\n",
    "        y_prob: per-class scores, indexed as y_prob[:, class_id].\n",
    "        class_names: display names; their count defines the number of classes.\n",
    "        out_dir: directory for the figures (created if missing).\n",
    "        tag: prefix for the file names and suffix for the plot titles.\n",
    "\n",
    "    The ROC figure is written only when at least two classes have positive\n",
    "    samples in y_true.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    class_ids = list(range(len(class_names)))\n",
    "\n",
    "    # ---- confusion matrix ----\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=class_ids)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names).plot(\n",
    "        ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False\n",
    "    )\n",
    "    ax_cm.set_title(f\"Confusion Matrix ({tag})\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"{tag}_confusion_matrix.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ---- ROC: classes without positive samples are left out ----\n",
    "    onehot = label_binarize(y_true, classes=class_ids)\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    present = []\n",
    "    for cid in class_ids:\n",
    "        if onehot[:, cid].sum() != 0:\n",
    "            fpr[cid], tpr[cid], _ = roc_curve(onehot[:, cid], y_prob[:, cid])\n",
    "            roc_auc[cid] = auc(fpr[cid], tpr[cid])\n",
    "            present.append(cid)\n",
    "    if len(present) < 2:\n",
    "        return\n",
    "\n",
    "    # micro average: pool every (sample, class) decision of the present classes\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(\n",
    "        onehot[:, present].ravel(), y_prob[:, present].ravel()\n",
    "    )\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    # macro average: mean of the per-class curves on a shared FPR grid\n",
    "    grid = np.unique(np.concatenate([fpr[cid] for cid in present]))\n",
    "    tpr_sum = np.zeros_like(grid)\n",
    "    for cid in present:\n",
    "        tpr_sum += np.interp(grid, fpr[cid], tpr[cid])\n",
    "    tpr_sum /= len(present)\n",
    "    fpr[\"macro\"] = grid\n",
    "    tpr[\"macro\"] = tpr_sum\n",
    "    roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0])\n",
    "    ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\")\n",
    "    ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(f\"Multiclass ROC ({tag})\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, f\"{tag}_roc_micro_macro.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ================== 主流程 ==================\n",
    "def main():\n",
    "    \"\"\"Train a LoRA classifier on DATA_A + DATA_B and evaluate on one of them.\n",
    "\n",
    "    Steps: set up the environment and seed, load tokenizer and the (optionally\n",
    "    4-bit) base classifier, attach a fresh or resumed LoRA adapter, train with\n",
    "    the Hugging Face Trainer, save the adapter under OUTPUT_DIR/lora_adapter,\n",
    "    then report metrics and figures for the eval slice and the test slice of\n",
    "    the dataset selected by EVAL_ON, and finally classify one sample.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the dataset selected by EVAL_ON has no valid samples.\n",
    "    \"\"\"\n",
    "    # Environment & seed\n",
    "    os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "    os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
    "    torch.cuda.set_device(0)\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    # Tokenizer\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\n",
    "        BASE_MODEL, use_fast=True, trust_remote_code=True, **HF_KW\n",
    "    )\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # Quantization (None when USE_4BIT is off)\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    # Classification head size\n",
    "    base_cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True, **HF_KW)\n",
    "    base_cfg.num_labels = NUM_LABELS\n",
    "\n",
    "    # Base model\n",
    "    base = AutoModelForSequenceClassification.from_pretrained(\n",
    "        BASE_MODEL,\n",
    "        config=base_cfg,\n",
    "        device_map={\"\": \"cuda:0\"},\n",
    "        quantization_config=quant_cfg,\n",
    "        trust_remote_code=True,\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW,\n",
    "    )\n",
    "\n",
    "    # ========= LoRA: fresh or resumed =========\n",
    "    # A resumed adapter needs the same training preparation of the base model\n",
    "    # as a fresh one, so prepare it once before choosing the branch.\n",
    "    base = prepare_model_for_kbit_training(base)\n",
    "    if RESUME_ADAPTER_DIR:\n",
    "        # Continue training an existing adapter\n",
    "        model = PeftModel.from_pretrained(base, RESUME_ADAPTER_DIR, is_trainable=True)\n",
    "    else:\n",
    "        # Create a new adapter\n",
    "        lora_cfg = LoraConfig(\n",
    "            r=LORA_R,\n",
    "            lora_alpha=LORA_ALPHA,\n",
    "            target_modules=LORA_TARGET_MODULES,\n",
    "            lora_dropout=LORA_DROPOUT,\n",
    "            bias=\"none\",\n",
    "            task_type=\"SEQ_CLS\",\n",
    "        )\n",
    "        model = get_peft_model(base, lora_cfg)\n",
    "\n",
    "    model.config.use_cache = False\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.print_trainable_parameters()\n",
    "\n",
    "    # ========= Data =========\n",
    "    rows_A = load_rows(DATA_A)\n",
    "    rows_B = load_rows(DATA_B)\n",
    "\n",
    "    # Training set = A + B. The concatenation is a new list, so shuffling it\n",
    "    # leaves rows_A / rows_B in file order.\n",
    "    train_rows: List[Dict[str, Any]] = rows_A + rows_B\n",
    "    random.Random(SEED).shuffle(train_rows)\n",
    "\n",
    "    # eval/test use a single dataset (anything other than \"A\" selects B)\n",
    "    if EVAL_ON.upper() == \"A\":\n",
    "        eval_rows = rows_A\n",
    "        eval_tag  = \"A_eval\"\n",
    "    else:\n",
    "        eval_rows = rows_B\n",
    "        eval_tag  = \"B_eval\"\n",
    "    if not eval_rows:\n",
    "        # Fail early instead of handing the Trainer an empty eval set.\n",
    "        raise ValueError(f\"No valid samples in the dataset selected by EVAL_ON={EVAL_ON!r}\")\n",
    "\n",
    "    # NOTE(review): eval_rows/test_rows are slices of a dataset that is also\n",
    "    # fully contained in train_rows, so the numbers below are measured on data\n",
    "    # the model was trained on. The slices are taken in file order (no shuffle\n",
    "    # or stratification). Hold them out of train_rows if unbiased metrics are\n",
    "    # needed.\n",
    "    # Ordered 80/20 cut into eval / test; tiny datasets (<= 5 rows) get no test part.\n",
    "    cut = int(0.8 * len(eval_rows)) if len(eval_rows) > 5 else len(eval_rows)\n",
    "    test_rows = eval_rows[cut:]\n",
    "    eval_rows = eval_rows[:cut] if cut > 0 else eval_rows\n",
    "\n",
    "    # Build the datasets\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    eval_ds  = MBTIDataset(eval_rows,  tokenizer, max_len=MAX_LEN)\n",
    "    test_ds  = MBTIDataset(test_rows,  tokenizer, max_len=MAX_LEN) if test_rows else None\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # ========= Training arguments =========\n",
    "    common_kwargs = dict(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BATCH_SIZE_PER_DEVICE_TRAIN,\n",
    "        per_device_eval_batch_size=BATCH_SIZE_PER_DEVICE_EVAL,\n",
    "        gradient_accumulation_steps=GR_ACCUM_STEPS,\n",
    "        learning_rate=LR,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        logging_steps=LOGGING_STEPS,\n",
    "        eval_steps=EVAL_STEPS,\n",
    "        save_steps=SAVE_STEPS,\n",
    "        save_total_limit=2,\n",
    "        lr_scheduler_type=\"cosine\",\n",
    "        report_to=\"none\",\n",
    "        fp16=False, bf16=False,\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_16\",\n",
    "        greater_is_better=True,\n",
    "        # Spelled out explicitly so the intent does not depend on defaults\n",
    "        logging_strategy=\"steps\",\n",
    "        save_strategy=\"steps\",\n",
    "    )\n",
    "\n",
    "    # The evaluation-schedule argument was renamed across transformers\n",
    "    # releases; try the newest spelling first and fall back to older ones.\n",
    "    try:\n",
    "        args = TrainingArguments(eval_strategy=\"steps\", **common_kwargs)\n",
    "    except TypeError:\n",
    "        try:\n",
    "            args = TrainingArguments(evaluation_strategy=\"steps\", **common_kwargs)\n",
    "        except TypeError:\n",
    "            # Last-resort spelling for very old (3.x) releases\n",
    "            args = TrainingArguments(evaluate_during_training=True, **common_kwargs)\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=eval_ds,\n",
    "        tokenizer=tokenizer,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    # ========= Train =========\n",
    "    trainer.train()\n",
    "\n",
    "    # Save the LoRA adapter; a failed save is reported but does not abort the evaluation\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    try:\n",
    "        model.save_pretrained(os.path.join(OUTPUT_DIR, \"lora_adapter\"))\n",
    "    except Exception as e:\n",
    "        print(\"Save adapter failed:\", e)\n",
    "\n",
    "    # ========= Eval (on the chosen dataset) =========\n",
    "    # One inference pass: predict() returns the logits and, since\n",
    "    # compute_metrics is set, the metrics as well. The \"eval\" prefix keeps the\n",
    "    # key names a separate trainer.evaluate() call would report.\n",
    "    eval_output = trainer.predict(eval_ds, metric_key_prefix=\"eval\")\n",
    "    logits = eval_output.predictions\n",
    "    if isinstance(logits, (list, tuple)):\n",
    "        logits = logits[0]\n",
    "    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = eval_output.label_ids\n",
    "    plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, tag=f\"{eval_tag}\")\n",
    "\n",
    "    metrics = eval_output.metrics or {}\n",
    "    print(\"\\n=== Eval on chosen dataset ===\")\n",
    "    for k, v in metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except Exception:\n",
    "            print(k, v)\n",
    "\n",
    "    # ========= Test (remaining 20% slice of the same dataset) =========\n",
    "    if test_ds and len(test_ds) > 0:\n",
    "        test_output = trainer.predict(test_ds)\n",
    "        logits = test_output.predictions\n",
    "        if isinstance(logits, (list, tuple)):\n",
    "            logits = logits[0]\n",
    "        probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "        y_true = test_output.label_ids\n",
    "        plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, tag=f\"{eval_tag}_test\")\n",
    "\n",
    "        # Plain overall accuracy\n",
    "        pred_ids = probs.argmax(-1)\n",
    "        acc = float((pred_ids == y_true).mean())\n",
    "        print(f\"\\n=== Test accuracy on chosen dataset: {acc:.4f}\")\n",
    "\n",
    "    # ========= Sample inference =========\n",
    "    model.eval()\n",
    "    # First row of the evaluated dataset, so the sample always matches EVAL_ON\n",
    "    sample = eval_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(\"cuda:0\") for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        out = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(out, dim=-1))\n",
    "        print(\"样例原标签:\", sample[\"type\"], \"| 预测:\", MBTI_16[pred_id])\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aee742d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved: train_used_in_training=82032, pandora_eval_80=38400, pandora_test_20=9600\n"
     ]
    }
   ],
   "source": [
    "import os, json, random\n",
    "\n",
    "# Seed used below to reproduce the shuffle permutation\n",
    "SEED = 42\n",
    "# Input files and the directory the split files are written to\n",
    "DATA_A = \"mbti_sample_with_all_views.json\"\n",
    "DATA_B = \"mbti_sample_with_all_views_pandora.json\"\n",
    "OUT = \"qwen-test-on-pandora/splits\"\n",
    "\n",
    "# NOTE: a set here (used only for membership tests in load_rows), not an\n",
    "# ordered list, so it cannot be indexed by class id.\n",
    "MBTI_16 = {\"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "           \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"}\n",
    "\n",
    "def load_rows(path):\n",
    "    \"\"\"Load samples from a JSON file, keeping those with a valid MBTI label.\n",
    "\n",
    "    The label is read from \"type\" (falling back to \"label\"), upper-cased and\n",
    "    stripped; kept rows get the normalized value written back to row[\"type\"].\n",
    "    Entries that are not dicts, or whose label is not a string, are skipped.\n",
    "    Rows are neither shuffled nor de-duplicated, so file order is preserved.\n",
    "    \"\"\"\n",
    "    # Context manager so the handle is closed deterministically\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    clean = []\n",
    "    for r in rows:\n",
    "        if not isinstance(r, dict):\n",
    "            continue\n",
    "        raw_label = r.get(\"type\") or r.get(\"label\") or \"\"\n",
    "        if not isinstance(raw_label, str):\n",
    "            continue\n",
    "        t = raw_label.upper().strip()\n",
    "        if t in MBTI_16:\n",
    "            r[\"type\"] = t\n",
    "            clean.append(r)\n",
    "    return clean\n",
    "\n",
    "def _dump_json(obj, filename, **dump_kwargs):\n",
    "    \"\"\"Write obj as JSON to OUT/filename, closing the file when done.\"\"\"\n",
    "    with open(os.path.join(OUT, filename), \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(obj, f, **dump_kwargs)\n",
    "\n",
    "os.makedirs(OUT, exist_ok=True)\n",
    "\n",
    "rows_A = load_rows(DATA_A)\n",
    "rows_B = load_rows(DATA_B)\n",
    "\n",
    "# The set of samples actually used for training: every row of A and B\n",
    "train_rows_all = rows_A + rows_B\n",
    "_dump_json(train_rows_all, \"train_used_in_training.json\", ensure_ascii=False, indent=2)\n",
    "\n",
    "# Record the permutation produced by the seeded shuffle, so the shuffled\n",
    "# order of the training list can be rebuilt from train_rows_all.\n",
    "# NOTE(review): the Trainer's data loader may shuffle again on top of this;\n",
    "# confirm before treating it as the exact batch order.\n",
    "idx = list(range(len(train_rows_all)))\n",
    "random.Random(SEED).shuffle(idx)\n",
    "_dump_json(idx, \"train_shuffle_index_seed42.json\")\n",
    "\n",
    "# Pandora eval/test: no shuffling, ordered 80/20 cut\n",
    "cut = int(0.8 * len(rows_B)) if len(rows_B) > 5 else len(rows_B)\n",
    "pandora_eval = rows_B[:cut]\n",
    "pandora_test = rows_B[cut:]\n",
    "\n",
    "_dump_json(pandora_eval, \"pandora_eval_80.json\", ensure_ascii=False, indent=2)\n",
    "_dump_json(pandora_test, \"pandora_test_20.json\", ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"Saved: train_used_in_training={len(train_rows_all)}, \"\n",
    "      f\"pandora_eval_80={len(pandora_eval)}, pandora_test_20={len(pandora_test)}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
