{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7d375368",
   "metadata": {},
   "source": [
    "# 消融实验 带解释的原数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96b1346c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "# Input/output locations used by main().\n",
    "# NOTE(review): ORIG_PATH is an absolute, machine-specific path; point it at your own copy before running.\n",
    "ORIG_PATH = \"/home/hli962/Chunhou_Project/CL/dataset_mbti/mbti_dataset.json\"\n",
    "# Augmented dataset that main() slices, and the file the slice is written to.\n",
    "AUG_PATH  = \"mbti_sample_with_all_views.json\"\n",
    "OUT_PATH  = \"mbti_sample_with_all_views_TOPN.json\"\n",
    "\n",
    "def load_list(path):\n",
    "    \"\"\"Read a JSON file and return its records as a list.\n",
    "\n",
    "    The file may hold a top-level list, or an object whose \"data\" entry is a\n",
    "    list (in which case that inner list is returned). Any other shape raises\n",
    "    ValueError.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        payload = json.load(fh)\n",
    "\n",
    "    if isinstance(payload, dict):\n",
    "        inner = payload.get(\"data\")\n",
    "        if isinstance(inner, list):\n",
    "            return inner\n",
    "    elif isinstance(payload, list):\n",
    "        return payload\n",
    "    raise ValueError(f\"{path} 应该是列表，或包含 data 为列表的 JSON\")\n",
    "\n",
    "def quick_match_rate(orig_list, cand_list, id_field=\"id\", pair_fields=(\"type\",\"post\")):\n",
    "    \"\"\"Optional helper: count how many candidate records also occur in the original list.\n",
    "\n",
    "    A record is identified by its id_field value when present, otherwise by the\n",
    "    pair of pair_fields values; records offering neither get no key.\n",
    "    Returns (True, hits, number_of_candidates), or (None, 0, 0) when no original\n",
    "    record yields a key.\n",
    "    \"\"\"\n",
    "    def record_key(rec):\n",
    "        is_mapping = isinstance(rec, dict)\n",
    "        if id_field and is_mapping and rec.get(id_field) is not None:\n",
    "            return (\"ID\", str(rec[id_field]))\n",
    "        first, second = pair_fields\n",
    "        if not is_mapping:\n",
    "            return None\n",
    "        if rec.get(first) is None or rec.get(second) is None:\n",
    "            return None\n",
    "        return (\"PAIR\", f\"{rec[first]}||{rec[second]}\")\n",
    "\n",
    "    # Keys of the original records; records without a usable key are ignored.\n",
    "    known = set()\n",
    "    for rec in orig_list:\n",
    "        k = record_key(rec)\n",
    "        if k is not None:\n",
    "            known.add(k)\n",
    "    if not known:\n",
    "        return None, 0, 0\n",
    "\n",
    "    # Candidates without a key still count towards the total, never towards the hits.\n",
    "    candidates = [record_key(rec) for rec in cand_list]\n",
    "    matched = sum(k in known for k in candidates)\n",
    "    return True, matched, len(candidates)\n",
    "\n",
    "def main():\n",
    "    \"\"\"Keep the first N augmented records (N = size of the original set) and write them to OUT_PATH.\"\"\"\n",
    "    originals = load_list(ORIG_PATH)\n",
    "    augmented = load_list(AUG_PATH)\n",
    "\n",
    "    N = len(originals)\n",
    "    topn = augmented[:N]\n",
    "\n",
    "    # Persist the slice as a bare JSON list (wrap it in {\"data\": ...} yourself if needed).\n",
    "    with open(OUT_PATH, \"w\", encoding=\"utf-8\") as sink:\n",
    "        json.dump(topn, sink, ensure_ascii=False, indent=2)\n",
    "\n",
    "    print(f\"原始数据条数 N = {N}\")\n",
    "    print(f\"已从增强集保留前 N 条，写入：{OUT_PATH}（共 {len(topn)} 条）\")\n",
    "\n",
    "    # Optional consistency check; nothing is reported when no usable key exists.\n",
    "    ok, hit, tot = quick_match_rate(originals, topn)\n",
    "    if not ok:\n",
    "        return\n",
    "    rate = hit / tot if tot else 0\n",
    "    print(f\"快速匹配命中：{hit}/{tot}（≈ {rate:.1%}）。\"\n",
    "          f\"若命中率很低，说明增强集的前N条未必对应原始样本顺序。\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d087005d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import Counter\n",
    "\n",
    "# ========== Utility functions ==========\n",
    "def norm_text(s: str) -> str:\n",
    "    \"\"\"Normalise a posts_cleaned string: collapse each whitespace run to one space, trim, lowercase.\"\"\"\n",
    "    unified = s.replace(\"\\r\\n\", \"\\n\").replace(\"\\r\", \"\\n\")\n",
    "    collapsed = re.compile(r\"\\s+\").sub(\" \", unified)\n",
    "    trimmed = collapsed.strip()\n",
    "    return trimmed.lower()\n",
    "\n",
    "def value_counts(data, key=\"type\"):\n",
    "    \"\"\"Tally records by d.get(key) and render each tally as 'count (share%)'.\"\"\"\n",
    "    tally = Counter(d.get(key, None) for d in data)\n",
    "    denom = sum(tally.values())\n",
    "    if not denom:\n",
    "        denom = 1  # empty input: avoid dividing by zero\n",
    "    summary = {}\n",
    "    for k, v in tally.items():\n",
    "        summary[k] = f\"{v} ({v/denom:.2%})\"\n",
    "    return summary\n",
    "\n",
    "# ========== 1. Load data ==========\n",
    "with open(\"mbti_sample_with_all_views_TOPN.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    full_data = json.load(f)\n",
    "\n",
    "with open(\"test对应的原始数据.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    test_data = json.load(f)\n",
    "\n",
    "print(\"完整数据:\", len(full_data))\n",
    "print(\"测试集:\", len(test_data))\n",
    "\n",
    "# ========== 2. Drop every record whose normalised posts_cleaned occurs in the test set ==========\n",
    "test_keys = {norm_text(d[\"posts_cleaned\"]) for d in test_data}\n",
    "dev_data = [d for d in full_data if norm_text(d[\"posts_cleaned\"]) not in test_keys]\n",
    "\n",
    "print(\"去掉测试集后，剩余 dev:\", len(dev_data))\n",
    "\n",
    "# ========== 3. Split the remaining dev records into train/val ==========\n",
    "X = dev_data\n",
    "y = [d[\"type\"] for d in dev_data]\n",
    "\n",
    "try:\n",
    "    train_data, val_data = train_test_split(\n",
    "        X,\n",
    "        test_size=0.2,       # validation share; adjust as needed\n",
    "        random_state=42,\n",
    "        stratify=y           # stratify by MBTI type\n",
    "    )\n",
    "except ValueError as e:\n",
    "    # Stratification raised ValueError: fall back to a plain random split.\n",
    "    print(\"⚠️ 分层失败，改为随机拆分:\", e)\n",
    "    train_data, val_data = train_test_split(\n",
    "        X,\n",
    "        test_size=0.2,\n",
    "        random_state=42,\n",
    "        stratify=None\n",
    "    )\n",
    "\n",
    "# ========== 4. Print split statistics ==========\n",
    "print(\"\\n样本数分布：\")\n",
    "print(\"Train:\", len(train_data), value_counts(train_data))\n",
    "print(\"Val  :\", len(val_data), value_counts(val_data))\n",
    "print(\"Test :\", len(test_data), value_counts(test_data))\n",
    "\n",
    "# ========== 5. Save the three splits ==========\n",
    "with open(\"train消融.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(train_data, f, ensure_ascii=False, indent=2)\n",
    "with open(\"val消融.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(val_data, f, ensure_ascii=False, indent=2)\n",
    "with open(\"test消融.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(test_data, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "# Report the file names that were actually written above.\n",
    "print(\"\\n已保存 train消融.json / val消融.json / test消融.json\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c7e1a3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Stratified 8:1:1 split of mbti_sample_with_all_views_TOPN.json (the default input_file) into train/val/test.\n",
    "The random seed is fixed at 42 so the split is reproducible.\n",
    "\"\"\"\n",
    "import json\n",
    "from pathlib import Path\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# The 16 MBTI type codes accepted as labels.\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "# Type code -> its position in MBTI_16; also used as the membership test in load_rows().\n",
    "MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "def load_rows(path: Path):\n",
    "    \"\"\"Load a JSON array and keep only dict records whose \"type\" is a key of MBTI2ID.\n",
    "\n",
    "    Raises ValueError when no record survives the filter.\n",
    "    \"\"\"\n",
    "    with path.open(\"r\", encoding=\"utf-8\") as fh:\n",
    "        loaded = json.load(fh)\n",
    "\n",
    "    kept = []\n",
    "    for rec in loaded:\n",
    "        if not isinstance(rec, dict):\n",
    "            continue\n",
    "        if rec.get(\"type\") in MBTI2ID:\n",
    "            kept.append(rec)\n",
    "\n",
    "    if len(kept) == 0:\n",
    "        raise ValueError(\"输入文件中没有合法样本（缺少 'type' 或不在 16 类里）。\")\n",
    "    return kept\n",
    "\n",
    "def save_json(rows, path: Path):\n",
    "    \"\"\"Write rows to path as indented UTF-8 JSON, creating missing parent directories first.\"\"\"\n",
    "    target_dir = path.parent\n",
    "    target_dir.mkdir(parents=True, exist_ok=True)\n",
    "    with path.open(\"w\", encoding=\"utf-8\") as sink:\n",
    "        json.dump(rows, sink, ensure_ascii=False, indent=2)\n",
    "\n",
    "def main(\n",
    "    input_file=\"mbti_sample_with_all_views_TOPN.json\",\n",
    "    outdir=\".\",\n",
    "    seed=42\n",
    "):\n",
    "    \"\"\"Split input_file into stratified train/val/test files (about 8:1:1) under outdir.\n",
    "\n",
    "    Args:\n",
    "        input_file: JSON file read through load_rows().\n",
    "        outdir: directory that receives train_topn.json, val_topn.json and test_topn.json.\n",
    "        seed: random_state passed to both train_test_split calls.\n",
    "    \"\"\"\n",
    "    inp = Path(input_file)\n",
    "    out = Path(outdir)\n",
    "\n",
    "    rows = load_rows(inp)\n",
    "    y = [r[\"type\"] for r in rows]\n",
    "\n",
    "    # Hold out 10% as TEST (stratified by type).\n",
    "    trainval_rows, test_rows = train_test_split(\n",
    "        rows, test_size=0.10, random_state=seed, stratify=y\n",
    "    )\n",
    "    # Take VAL from the remaining 90%: 0.1 of the total is 0.1 / 0.9 of what is left.\n",
    "    y_trainval = [r[\"type\"] for r in trainval_rows]\n",
    "    train_rows, val_rows = train_test_split(\n",
    "        trainval_rows, test_size=0.1111111111, random_state=seed, stratify=y_trainval\n",
    "    )\n",
    "\n",
    "    # Save each split and report the path that was really written, so the\n",
    "    # message cannot drift away from the file name again.\n",
    "    outputs = (\n",
    "        (\"train_topn.json\", train_rows),\n",
    "        (\"val_topn.json\", val_rows),\n",
    "        (\"test_topn.json\", test_rows),\n",
    "    )\n",
    "    for fname, split_rows in outputs:\n",
    "        target = out / fname\n",
    "        save_json(split_rows, target)\n",
    "        print(f\"✅ 已保存：{target}（{len(split_rows)} 条）\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6b5d6ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
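    "Qwen2.5-1.5B-Instruct (see MODEL_NAME) + QLoRA (4-bit) + LoRA sequence classifier.\n",
    "Trains on train消融.json and evaluates on val消融.json / test消融.json (VAL & TEST).\n",
    "NOTE: this header previously named DeepSeek-R1-Distill-Qwen-1.5B and the *_topn.json splits; the code below does not use them.\n",
    "Written against Transformers==4.55 (install peft / bitsandbytes as needed).\n",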
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from sklearn.metrics import (\n",
    "    confusion_matrix, ConfusionMatrixDisplay,\n",
    "    roc_curve, auc,\n",
    "    precision_recall_fscore_support, classification_report, balanced_accuracy_score\n",
    ")\n",
    "from sklearn.preprocessing import label_binarize\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# Environment switches, set before the model is built.\n",
    "# Presumably these disable Accelerate mixed precision and silence the bitsandbytes banner - confirm.\n",
    "os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
    "\n",
    "# ============ Configuration ============\n",
    "MODEL_NAME   = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "# MAX_LEN caps the tokenised prompt; BUDGET caps each field inside build_input().\n",
    "# NOTE(review): the four budgets already sum to 440 == MAX_LEN, and build_input() also adds\n",
    "# section tags and a [TASK] sentence, so a fully filled prompt presumably gets its tail truncated - confirm.\n",
    "MAX_LEN      = 440\n",
    "BUDGET = {\"posts_cleaned\": 320, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "# Training hyper-parameters passed to TrainingArguments in main().\n",
    "SEED         = 42\n",
    "EPOCHS       = 4\n",
    "LR           = 2e-4\n",
    "BSZ_TRN      = 8\n",
    "BSZ_EVAL     = 4\n",
    "GRAD_ACCUM   = 1\n",
    "WARMUP_RATIO = 0.06\n",
    "WEIGHT_DECAY = 0.01\n",
    "# Directory for checkpoints, plots, reports and the saved adapter.\n",
    "OUTPUT_DIR   = \"mbti_lora_qwen1.5b-kaggle_xiaorongExplain_ckpt_new\"\n",
    "\n",
    "# Quantisation switch and LoRA settings passed to BitsAndBytesConfig / LoraConfig in main().\n",
    "USE_4BIT     = True\n",
    "LORA_R       = 16\n",
    "LORA_ALPHA   = 32\n",
    "LORA_DROPOUT = 0.05\n",
    "TARGET_MODULES = [\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"]\n",
    "\n",
    "# Optional Hugging Face token taken from the environment; HF_KW stays empty when it is unset.\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# The 16 class names; a name's index in this list is its class id.\n",
    "# NOTE(review): MBTI_16 / MBTI2ID (and load_rows below) repeat definitions from the split cell above.\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "# ============ Basic helpers ============\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load a JSON array from path, keeping only dict records whose \"type\" is a key of MBTI2ID.\n",
    "\n",
    "    Raises ValueError when nothing is left after filtering.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        loaded = json.load(fh)\n",
    "\n",
    "    kept = []\n",
    "    for rec in loaded:\n",
    "        if not isinstance(rec, dict):\n",
    "            continue\n",
    "        if rec.get(\"type\") in MBTI2ID:\n",
    "            kept.append(rec)\n",
    "\n",
    "    if len(kept) == 0:\n",
    "        raise ValueError(f\"{path} 中没有合法样本。\")\n",
    "    return kept\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    \"\"\"Encode the first four letters of an MBTI code as a 4-tuple of 0/1 flags.\n",
    "\n",
    "    Position by position, the letters I, S, F and P map to 0; any other letter maps to 1.\n",
    "    \"\"\"\n",
    "    zero_letters = (\"I\", \"S\", \"F\", \"P\")\n",
    "    return tuple(0 if m[pos] == zero else 1 for pos, zero in enumerate(zero_letters))\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Tokenise text without special tokens, keep input_ids[:budget], and decode them back to a string.\n",
    "\n",
    "    A falsy text (None or empty) is treated as the empty string.\n",
    "    \"\"\"\n",
    "    source = text or \"\"\n",
    "    token_ids = tok(source, add_special_tokens=False)[\"input_ids\"]\n",
    "    kept_ids = token_ids[: budget]\n",
    "    return tok.decode(kept_ids)\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Render one record as the tagged prompt fed to the classifier.\n",
    "\n",
    "    Each field is cut to its BUDGET entry first. The posts text comes from\n",
    "    \"posts_cleaned\", and \"posts\" is consulted only when that key is absent.\n",
    "    \"\"\"\n",
    "    def clipped(field, raw):\n",
    "        # Falsy values (None, \"\") are tokenised as the empty string.\n",
    "        return truncate_to_budget(tok, raw or \"\", BUDGET[field])\n",
    "\n",
    "    posts_raw = item.get(\"posts_cleaned\", item.get(\"posts\",\"\"))\n",
    "    p   = clipped(\"posts_cleaned\", posts_raw)\n",
    "    sem = clipped(\"semantic_view\", item.get(\"semantic_view\",\"\"))\n",
    "    sen = clipped(\"sentiment_view\", item.get(\"sentiment_view\",\"\"))\n",
    "    lin = clipped(\"linguistic_view\", item.get(\"linguistic_view\",\"\"))\n",
    "    prompt = (\n",
    "        f\"[POSTS]\\n{p}\\n[SEMANTIC]\\n{sem}\\n[SENTIMENT]\\n{sen}\\n[LINGUISTIC]\\n{lin}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "    return prompt\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset that renders a row to a prompt and tokenises it on access.\"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows, self.tok, self.max_len = rows, tokenizer, max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return input_ids, attention_mask and the integer class id (\"labels\") for row idx.\"\"\"\n",
    "        record = self.rows[idx]\n",
    "        prompt = build_input(record, self.tok)\n",
    "        label_id = MBTI2ID[record[\"type\"]]\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        sample = {\"input_ids\": encoded[\"input_ids\"], \"attention_mask\": encoded[\"attention_mask\"]}\n",
    "        sample[\"labels\"] = label_id\n",
    "        return sample\n",
    "\n",
    "# ============ Evaluation metrics ============\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute 16-class and per-dimension accuracy metrics from logits and labels.\n",
    "\n",
    "    eval_pred is either a (predictions, labels) tuple or an object exposing\n",
    "    .predictions and .label_ids; when predictions is a list/tuple its first\n",
    "    element is used. Returns a flat dict of metric name -> value.\n",
    "    \"\"\"\n",
    "    if isinstance(eval_pred, tuple):\n",
    "        raw_preds, raw_labels = eval_pred\n",
    "    else:\n",
    "        raw_preds, raw_labels = eval_pred.predictions, eval_pred.label_ids\n",
    "    if isinstance(raw_preds, (list, tuple)):\n",
    "        raw_preds = raw_preds[0]\n",
    "    scores = np.asarray(raw_preds)\n",
    "    labels = np.asarray(raw_labels)\n",
    "    pred_ids = scores.argmax(-1)\n",
    "\n",
    "    acc16 = float((pred_ids == labels).mean())\n",
    "    bal_acc16 = balanced_accuracy_score(labels, pred_ids)\n",
    "\n",
    "    # Precision / recall / F1 over the 16 classes, one call per averaging mode.\n",
    "    def prf(mode):\n",
    "        return precision_recall_fscore_support(labels, pred_ids, average=mode, zero_division=0)[:3]\n",
    "\n",
    "    p_micro, r_micro, f1_micro = prf(\"micro\")\n",
    "    p_macro, r_macro, f1_macro = prf(\"macro\")\n",
    "    p_weighted, r_weighted, f1_weighted = prf(\"weighted\")\n",
    "\n",
    "    # Per-dimension agreement (the four mbti_to_4d flags) plus all-four-match.\n",
    "    dim_hits = [0, 0, 0, 0]\n",
    "    exact = 0\n",
    "    for p_id, t_id in zip(pred_ids, labels):\n",
    "        agree = [a == b for a, b in zip(mbti_to_4d(MBTI_16[p_id]), mbti_to_4d(MBTI_16[t_id]))]\n",
    "        for d, same in enumerate(agree):\n",
    "            dim_hits[d] += same\n",
    "        exact += all(agree)\n",
    "    n = len(labels)\n",
    "    c_ei, c_ns, c_tf, c_jp = dim_hits\n",
    "\n",
    "    return {\n",
    "        \"acc_16\": acc16,\n",
    "        \"bal_acc_16\": bal_acc16,\n",
    "        \"p_micro\": p_micro, \"r_micro\": r_micro, \"f1_micro\": f1_micro,\n",
    "        \"p_macro\": p_macro, \"r_macro\": r_macro, \"f1_macro\": f1_macro,\n",
    "        \"p_weighted\": p_weighted, \"r_weighted\": r_weighted, \"f1_weighted\": f1_weighted,\n",
    "        \"acc_ei\": c_ei/n, \"acc_ns\": c_ns/n, \"acc_tf\": c_tf/n, \"acc_jp\": c_jp/n, \"acc_4D\": exact/n\n",
    "    }\n",
    "\n",
    "# ============ Plots: confusion matrix & ROC ============\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=\"\"):\n",
    "    \"\"\"Save a confusion-matrix PNG and a micro/macro-average ROC PNG into out_dir.\n",
    "\n",
    "    Args:\n",
    "        y_true: integer class ids, one per sample.\n",
    "        y_prob: per-class scores, indexed as [sample, class].\n",
    "        class_names: display names; their count defines the number of classes.\n",
    "        out_dir: output directory (created when missing).\n",
    "        suffix: appended to both file names and plot titles.\n",
    "\n",
    "    Classes that do not have both positive and negative samples in y_true are\n",
    "    left out of the macro average: roc_curve returns NaN rates for them, which\n",
    "    would otherwise turn the whole macro curve and its AUC into NaN.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_prob = np.asarray(y_prob)\n",
    "    class_ids = list(range(len(class_names)))\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=class_ids)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix{suffix}\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"confusion_matrix{suffix}.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    Y_true_bin = label_binarize(y_true, classes=class_ids)\n",
    "    # A one-vs-rest ROC curve is only defined when the class has positives and negatives.\n",
    "    usable = [i for i in class_ids if Y_true_bin[:, i].any() and not Y_true_bin[:, i].all()]\n",
    "    fpr = {}; tpr = {}; roc_auc = {}\n",
    "    for i in usable:\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    if usable:\n",
    "        # Macro average: interpolate every usable class curve on the union of FPR points.\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in usable]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in usable:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(usable)\n",
    "        fpr[\"macro\"] = all_fpr\n",
    "        tpr[\"macro\"] = mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    if \"macro\" in roc_auc:\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                    label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(f\"Multiclass ROC (micro & macro){suffix}\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, f\"roc_micro_macro{suffix}.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ============ Embedding extraction & t-SNE ============\n",
    "def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):\n",
    "    \"\"\"Average hidden states over dim 1, counting only positions where attention_mask is non-zero.\n",
    "\n",
    "    Indexed as in the caller: last_hidden_state [B, T, H], attention_mask [B, T]; returns [B, H].\n",
    "    \"\"\"\n",
    "    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [B, T, 1]\n",
    "    token_count = weights.sum(dim=1).clamp(min=1e-6)                   # [B, 1], never zero\n",
    "    pooled_sum = (last_hidden_state * weights).sum(dim=1)              # [B, H]\n",
    "    return pooled_sum / token_count\n",
    "\n",
    "@torch.no_grad()\n",
    "def extract_embeddings(model, tokenizer, dataset, device=\"cuda:0\", batch_size=4):\n",
    "    \"\"\"Run the model over dataset and return mean-pooled last-layer embeddings with their labels.\n",
    "\n",
    "    Args:\n",
    "        model: called with input_ids / attention_mask and output_hidden_states=True.\n",
    "        tokenizer: used by the padding collator.\n",
    "        dataset: yields dicts with input_ids, attention_mask and labels.\n",
    "        device: device the batches are moved to.\n",
    "        batch_size: DataLoader batch size.\n",
    "\n",
    "    Returns:\n",
    "        (embs, ys): float32 array of pooled vectors and array of labels, in dataset order.\n",
    "    \"\"\"\n",
    "    from torch.utils.data import DataLoader\n",
    "    dl = DataLoader(dataset, batch_size=batch_size, shuffle=False,\n",
    "                    collate_fn=DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8))\n",
    "    embs = []\n",
    "    ys   = []\n",
    "    model.eval()\n",
    "    for batch in dl:\n",
    "        input_ids = batch[\"input_ids\"].to(device)\n",
    "        attention_mask = batch[\"attention_mask\"].to(device)\n",
    "        labels = batch[\"labels\"].cpu().numpy()\n",
    "\n",
    "        out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)\n",
    "        last_h = out.hidden_states[-1]  # [B,T,H]\n",
    "        vec = mean_pool(last_h, attention_mask)  # [B,H]\n",
    "        # NumPy has no bfloat16 dtype, so Tensor.numpy() raises on bf16 tensors: upcast first.\n",
    "        embs.append(vec.float().cpu().numpy())\n",
    "        ys.append(labels)\n",
    "    embs = np.concatenate(embs, axis=0)\n",
    "    ys   = np.concatenate(ys, axis=0)\n",
    "    return embs, ys\n",
    "\n",
    "def plot_tsne(embs, labels, class_names, out_png, title=\"t-SNE (PCA init)\"):\n",
    "    \"\"\"Reduce embeddings with PCA, project them to 2-D with t-SNE and save a per-class scatter plot.\n",
    "\n",
    "    Args:\n",
    "        embs: 2-D array indexed as [sample, feature].\n",
    "        labels: integer class id per sample.\n",
    "        class_names: legend names; a name's index is matched against labels.\n",
    "        out_png: path of the image to write.\n",
    "        title: plot title.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: with fewer than 2 samples, for which t-SNE has no valid perplexity.\n",
    "    \"\"\"\n",
    "    embs = np.asarray(embs)\n",
    "    labels = np.asarray(labels)\n",
    "    n = len(labels)\n",
    "    if n < 2:\n",
    "        raise ValueError(\"plot_tsne needs at least 2 samples\")\n",
    "\n",
    "    # PCA cannot keep more components than there are samples or features.\n",
    "    k = min(50, embs.shape[1], n)\n",
    "    embs50 = PCA(n_components=k, random_state=SEED).fit_transform(embs)\n",
    "    # Rule of thumb n // 20 within [5, 50]; t-SNE also requires perplexity < n_samples.\n",
    "    perplexity = min(max(5, min(50, n // 20)), n - 1)\n",
    "    tsne2 = TSNE(n_components=2, init=\"pca\", random_state=SEED, perplexity=perplexity, learning_rate=\"auto\")\n",
    "    X2 = tsne2.fit_transform(embs50)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 7), dpi=150)\n",
    "    for cid, cname in enumerate(class_names):\n",
    "        idx = (labels == cid)\n",
    "        if idx.sum() == 0:\n",
    "            continue  # class absent from this split: no legend entry\n",
    "        ax.scatter(X2[idx, 0], X2[idx, 1], s=10, alpha=0.7, label=cname)\n",
    "    ax.set_title(title)\n",
    "    ax.set_xticks([]); ax.set_yticks([])\n",
    "    ax.legend(markerscale=2, fontsize=8, ncol=2, frameon=False)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(out_png)\n",
    "    plt.close(fig)\n",
    "\n",
    "# ============ Main pipeline ============\n",
    "def main():\n",
    "    \"\"\"Fine-tune the LoRA classifier, then evaluate and visualise it on VAL and TEST.\n",
    "\n",
    "    Steps, in order: load the three split files, build the tokenizer and the\n",
    "    (optionally 4-bit) base model, wrap it with LoRA, train with the HF Trainer,\n",
    "    write confusion-matrix / ROC plots and classification reports for VAL and\n",
    "    TEST, print per-dimension accuracies for TEST, draw t-SNE plots, save the\n",
    "    adapter and run one sample prediction. Everything is pinned to cuda:0.\n",
    "    \"\"\"\n",
    "    torch.cuda.set_device(0)\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    # Load the pre-split files.\n",
    "    train_rows = load_rows(\"train消融.json\")\n",
    "    val_rows   = load_rows(\"val消融.json\")\n",
    "    test_rows  = load_rows(\"test消融.json\")\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    model_kwargs = dict(\n",
    "        num_labels=16,\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": \"cuda:0\"},\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW,\n",
    "    )\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, **model_kwargs)\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.config.use_cache = False\n",
    "\n",
    "    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training\n",
    "    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)\n",
    "    # NOTE(review): checkpointing is switched off in the call above and in TrainingArguments\n",
    "    # below, yet enabled here on a best-effort basis - confirm which one is intended.\n",
    "    try:\n",
    "        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n",
    "    except Exception as exc:\n",
    "        # Best effort: continue without checkpointing, but report it instead of failing silently.\n",
    "        print(f\"[warn] gradient_checkpointing_enable failed: {exc}\")\n",
    "    peft_cfg = LoraConfig(\n",
    "        task_type=TaskType.SEQ_CLS,\n",
    "        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,\n",
    "        target_modules=TARGET_MODULES, bias=\"none\"\n",
    "    )\n",
    "    model = get_peft_model(model, peft_cfg)\n",
    "    model = model.to(\"cuda:0\")\n",
    "\n",
    "    # Replace model.to with a no-op so the Trainer cannot move the quantised PEFT model a second time.\n",
    "    def _noop_to(self, *args, **kwargs): return self\n",
    "    model.to = _noop_to.__get__(model, type(model))\n",
    "\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)\n",
    "    test_ds  = MBTIDataset(test_rows,  tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BSZ_TRN,\n",
    "        per_device_eval_batch_size=BSZ_EVAL,\n",
    "        gradient_accumulation_steps=GRAD_ACCUM,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        learning_rate=LR,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        eval_strategy=\"epoch\",\n",
    "        save_strategy=\"epoch\",\n",
    "        save_total_limit=2,\n",
    "        logging_steps=50,\n",
    "        bf16=False, fp16=False,\n",
    "        report_to=\"none\",\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_4D\",\n",
    "        greater_is_better=True,\n",
    "        optim=\"paged_adamw_8bit\",\n",
    "        eval_accumulation_steps=12,\n",
    "        gradient_checkpointing=False,\n",
    "    )\n",
    "\n",
    "    # NOTE(review): recent Transformers releases prefer processing_class over tokenizer here - confirm for the pinned version.\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=val_ds,\n",
    "        tokenizer=tokenizer,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    # Train.\n",
    "    trainer.train()\n",
    "\n",
    "    # ===== VAL predictions & plots =====\n",
    "    val_output = trainer.predict(val_ds)\n",
    "    val_logits = val_output.predictions[0] if isinstance(val_output.predictions, (list, tuple)) else val_output.predictions\n",
    "    val_probs  = F.softmax(torch.tensor(val_logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    val_y_true = val_output.label_ids\n",
    "    plot_confusion_and_roc(val_y_true, val_probs, MBTI_16, OUTPUT_DIR, suffix=\"\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro.png')}\")\n",
    "\n",
    "    # VAL metrics (including P/R/F1).\n",
    "    eval_metrics = trainer.evaluate(eval_dataset=val_ds)\n",
    "    print(\"\\n=== Final Eval (on VAL) ===\")\n",
    "    for k, v in eval_metrics.items():\n",
    "        # Values that float() cannot convert are printed as they are.\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except (TypeError, ValueError):\n",
    "            print(k, v)\n",
    "\n",
    "    # Save the VAL classification report.\n",
    "    val_pred_ids = val_probs.argmax(-1)\n",
    "    val_report = classification_report(val_y_true, val_pred_ids, target_names=MBTI_16, digits=4, zero_division=0)\n",
    "    with open(os.path.join(OUTPUT_DIR, \"classification_report_val.txt\"), \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(val_report)\n",
    "    print(\"\\n=== VAL Classification Report ===\\n\", val_report)\n",
    "\n",
    "    # ===== TEST evaluation & plots =====\n",
    "    test_output = trainer.predict(test_ds)\n",
    "    test_logits = test_output.predictions[0] if isinstance(test_output.predictions, (list, tuple)) else test_output.predictions\n",
    "    test_probs  = F.softmax(torch.tensor(test_logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    test_y_true = test_output.label_ids\n",
    "    plot_confusion_and_roc(test_y_true, test_probs, MBTI_16, OUTPUT_DIR, suffix=\"_test\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix_test.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro_test.png')}\")\n",
    "\n",
    "    # Save the TEST classification report.\n",
    "    test_pred_ids = test_probs.argmax(-1)\n",
    "    test_report = classification_report(test_y_true, test_pred_ids, target_names=MBTI_16, digits=4, zero_division=0)\n",
    "    with open(os.path.join(OUTPUT_DIR, \"classification_report_test.txt\"), \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(test_report)\n",
    "    print(\"\\n=== TEST Classification Report ===\\n\", test_report)\n",
    "\n",
    "    # TEST per-dimension accuracies (same counting as in compute_metrics).\n",
    "    acc16 = float((test_pred_ids == test_y_true).mean())\n",
    "    pred_types = [MBTI_16[i] for i in test_pred_ids]\n",
    "    true_types = [MBTI_16[i] for i in test_y_true]\n",
    "    c_ei=c_ns=c_tf=c_jp=c_all=0\n",
    "    for pt, tt in zip(pred_types, true_types):\n",
    "        pei,pns,ptf,pjp = mbti_to_4d(pt)\n",
    "        tei,tns,ttf,tjp = mbti_to_4d(tt)\n",
    "        c_ei += (pei==tei); c_ns += (pns==tns); c_tf += (ptf==ttf); c_jp += (pjp==tjp)\n",
    "        c_all+= (pei==tei and pns==tns and ptf==ttf and pjp==tjp)\n",
    "    n = len(test_y_true)\n",
    "    print(\"\\n=== Final Test (held-out TEST) ===\")\n",
    "    print(f\"acc_16: {acc16:.4f}\")\n",
    "    print(f\"acc_ei: {c_ei/n:.4f}  acc_ns: {c_ns/n:.4f}  acc_tf: {c_tf/n:.4f}  acc_jp: {c_jp/n:.4f}  acc_4D: {c_all/n:.4f}\")\n",
    "\n",
    "    # ====== Embedding plots for VAL / TEST (t-SNE) ======\n",
    "    val_embs, val_y = extract_embeddings(model, tokenizer, val_ds, device=\"cuda:0\", batch_size=BSZ_EVAL)\n",
    "    plot_tsne(val_embs, val_y, MBTI_16, os.path.join(OUTPUT_DIR, \"tsne_val.png\"), title=\"t-SNE (VAL)\")\n",
    "\n",
    "    test_embs, test_y = extract_embeddings(model, tokenizer, test_ds, device=\"cuda:0\", batch_size=BSZ_EVAL)\n",
    "    plot_tsne(test_embs, test_y, MBTI_16, os.path.join(OUTPUT_DIR, \"tsne_test.png\"), title=\"t-SNE (TEST)\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'tsne_val.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'tsne_test.png')}\")\n",
    "\n",
    "    # Save the LoRA adapter.\n",
    "    trainer.save_model(OUTPUT_DIR)\n",
    "    print(f\"\\n✅ LoRA adapter saved to: {OUTPUT_DIR}\")\n",
    "\n",
    "    # Inference example on the first TEST row.\n",
    "    model.eval()\n",
    "    sample = test_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(\"cuda:0\") for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        logits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(logits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n[Inference on TEST sample]\")\n",
    "    print(\"原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
