{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72ce025e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Qwen2.5-1.5B + QLoRA(4bit) + LoRA\n",
    "对比实验：只用原始文本 vs 原文+三维解释\n",
    "Transformers==4.55\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from sklearn.metrics import (\n",
    "    confusion_matrix, ConfusionMatrixDisplay,\n",
    "    roc_curve, auc,\n",
    "    precision_recall_fscore_support, classification_report, balanced_accuracy_score\n",
    ")\n",
    "from sklearn.preprocessing import label_binarize\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# Environment switches\n",
    "os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
    "\n",
    "# ============ Experiment input switch ============\n",
    "# \"posts_only\": raw post text only; \"all_views\": posts + three explanation views\n",
    "INPUT_MODE   = \"posts_only\"   # <- change to \"all_views\" to feed posts plus the three views\n",
    "MODEL_NAME   = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "\n",
    "# Per-field token budgets. The posts budget (320) is the same in both modes;\n",
    "# \"all_views\" adds 64 + 32 + 24 = 120 tokens for the three views (440 in total).\n",
    "if INPUT_MODE == \"posts_only\":\n",
    "    BUDGET = {\"posts_cleaned\": 320, \"semantic_view\": 0, \"sentiment_view\": 0, \"linguistic_view\": 0}\n",
    "else:\n",
    "    BUDGET = {\"posts_cleaned\": 320, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "# Tokens reserved for what build_input adds around the budgeted fields: the\n",
    "# section tags and the \"[TASK] ...\" line that lists all 16 types.\n",
    "# NOTE(review): 128 is an estimate, not a measured value - confirm with the tokenizer.\n",
    "PROMPT_OVERHEAD = 128\n",
    "\n",
    "# Final encoding max_length, derived from the budgets so that no budgeted field\n",
    "# is cut again at encoding time. With a fixed value of 440 the \"all_views\"\n",
    "# budgets alone (440) already filled the limit, so right-truncation removed the\n",
    "# tail of the views and the [TASK] line.\n",
    "MAX_LEN      = sum(BUDGET.values()) + PROMPT_OVERHEAD\n",
    "\n",
    "# Training hyper-parameters\n",
    "SEED         = 42\n",
    "EPOCHS       = 4\n",
    "LR           = 2e-4\n",
    "BSZ_TRN      = 8\n",
    "BSZ_EVAL     = 4\n",
    "GRAD_ACCUM   = 1\n",
    "WARMUP_RATIO = 0.06\n",
    "WEIGHT_DECAY = 0.01\n",
    "\n",
    "# Output directory name encodes the input mode so runs do not overwrite each other.\n",
    "TAG = \"postsOnly\" if INPUT_MODE==\"posts_only\" else \"allViews\"\n",
    "OUTPUT_DIR   = f\"mbti_lora_qwen1.5b_{TAG}_new\"\n",
    "\n",
    "# Quantization / LoRA settings\n",
    "USE_4BIT     = True\n",
    "LORA_R       = 16\n",
    "LORA_ALPHA   = 32\n",
    "LORA_DROPOUT = 0.05\n",
    "TARGET_MODULES = [\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"]\n",
    "\n",
    "# Optional Hugging Face token, read from the environment (never hard-coded).\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# The 16 class names; a name's list index is its integer class id.\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "# ============ 基础函数 ============\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load samples from a JSON file, keeping only records with a known MBTI label.\n",
    "\n",
    "    Args:\n",
    "        path: Path to a UTF-8 JSON file. The parsed value is iterated and an item\n",
    "            is kept when it is a dict whose \"type\" value is a key of MBTI2ID.\n",
    "\n",
    "    Returns:\n",
    "        The list of retained records, in file order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no record passes the filter.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        records = json.load(fh)\n",
    "\n",
    "    valid = []\n",
    "    for rec in records:\n",
    "        if isinstance(rec, dict) and rec.get(\"type\") in MBTI2ID:\n",
    "            valid.append(rec)\n",
    "\n",
    "    if len(valid) == 0:\n",
    "        raise ValueError(f\"{path} 中没有合法样本。\")\n",
    "    return valid\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    \"\"\"Convert a 4-letter MBTI code into a tuple of four 0/1 flags.\n",
    "\n",
    "    Position by position, the letters \"I\", \"S\", \"F\" and \"P\" map to 0 and any\n",
    "    other letter maps to 1. Only the first four characters of `m` are read.\n",
    "    \"\"\"\n",
    "    zero_letters = \"ISFP\"\n",
    "    return tuple(0 if m[pos] == zero_letters[pos] else 1 for pos in range(4))\n",
    "\n",
    "def truncate_to_budget(tok: \"AutoTokenizer\", text: str, budget: int) -> str:\n",
    "    \"\"\"Cut `text` to at most `budget` tokens and return it as a string.\n",
    "\n",
    "    The text (None is treated as \"\") is tokenized without special tokens, the\n",
    "    first `budget` ids are kept, and those ids are decoded back with `tok`.\n",
    "    A non-positive budget yields an empty string without calling the tokenizer.\n",
    "    \"\"\"\n",
    "    if budget > 0:\n",
    "        token_ids = tok(text or \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "        return tok.decode(token_ids[:budget])\n",
    "    return \"\"\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: \"AutoTokenizer\") -> str:\n",
    "    \"\"\"Assemble the model input text for one sample.\n",
    "\n",
    "    Every field is first cut to its token budget from BUDGET. In \"posts_only\"\n",
    "    mode the result is the [POSTS] section followed by the [TASK] line; in any\n",
    "    other mode the [SEMANTIC], [SENTIMENT] and [LINGUISTIC] sections are\n",
    "    inserted between them.\n",
    "\n",
    "    Args:\n",
    "        item: Sample dict. Reads \"posts_cleaned\" (falling back to \"posts\") and,\n",
    "            outside \"posts_only\" mode, \"semantic_view\", \"sentiment_view\" and\n",
    "            \"linguistic_view\". Missing or None fields are treated as \"\".\n",
    "        tok: Tokenizer passed through to truncate_to_budget.\n",
    "\n",
    "    Returns:\n",
    "        The prompt string.\n",
    "    \"\"\"\n",
    "    # Prefer the cleaned posts, but fall back to the raw posts whenever the\n",
    "    # cleaned field is missing, None or empty (a present-but-empty key used to\n",
    "    # suppress the fallback).\n",
    "    raw_posts = item.get(\"posts_cleaned\") or item.get(\"posts\") or \"\"\n",
    "    posts = truncate_to_budget(tok, raw_posts, BUDGET[\"posts_cleaned\"])\n",
    "    task_line = f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "\n",
    "    if INPUT_MODE == \"posts_only\":\n",
    "        return f\"[POSTS]\\n{posts}\\n{task_line}\"\n",
    "\n",
    "    sem = truncate_to_budget(tok, item.get(\"semantic_view\", \"\") or \"\", BUDGET[\"semantic_view\"])\n",
    "    sen = truncate_to_budget(tok, item.get(\"sentiment_view\", \"\") or \"\", BUDGET[\"sentiment_view\"])\n",
    "    lin = truncate_to_budget(tok, item.get(\"linguistic_view\", \"\") or \"\", BUDGET[\"linguistic_view\"])\n",
    "    return (\n",
    "        f\"[POSTS]\\n{posts}\\n[SEMANTIC]\\n{sem}\\n[SENTIMENT]\\n{sen}\\n[LINGUISTIC]\\n{lin}\\n\"\n",
    "        f\"{task_line}\"\n",
    "    )\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset that tokenizes one sample on each access.\n",
    "\n",
    "    Each item is a dict with \"input_ids\", \"attention_mask\" (unpadded, as\n",
    "    returned by the tokenizer) and \"labels\" (the integer id from MBTI2ID).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        # rows: sample dicts; each must carry a \"type\" key present in MBTI2ID.\n",
    "        self.rows, self.tok, self.max_len = rows, tokenizer, max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        record = self.rows[idx]\n",
    "        prompt = build_input(record, self.tok)\n",
    "        label_id = MBTI2ID[record[\"type\"]]\n",
    "        # Truncate at encoding time as well, so a sample never exceeds max_len.\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        return {\n",
    "            \"input_ids\": encoded[\"input_ids\"],\n",
    "            \"attention_mask\": encoded[\"attention_mask\"],\n",
    "            \"labels\": label_id,\n",
    "        }\n",
    "\n",
    "# ============ 指标 ============\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute 16-way and per-axis MBTI metrics from logits and gold labels.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: Either a `(predictions, labels)` tuple or an object exposing\n",
    "            `.predictions` and `.label_ids`. If predictions is itself a list or\n",
    "            tuple, its first element is used.\n",
    "\n",
    "    Returns:\n",
    "        Dict with 16-way accuracy and balanced accuracy, micro/macro/weighted\n",
    "        precision, recall and F1, the accuracy of each of the four MBTI axes\n",
    "        (acc_ei, acc_ns, acc_tf, acc_jp) and acc_4D (all four axes correct).\n",
    "    \"\"\"\n",
    "    if isinstance(eval_pred, tuple):\n",
    "        logits, gold = eval_pred\n",
    "    else:\n",
    "        logits, gold = eval_pred.predictions, eval_pred.label_ids\n",
    "    if isinstance(logits, (list, tuple)):\n",
    "        logits = logits[0]\n",
    "    logits = np.asarray(logits)\n",
    "    gold = np.asarray(gold)\n",
    "\n",
    "    pred_ids = logits.argmax(-1)\n",
    "    metrics = {\n",
    "        \"acc_16\": float((pred_ids == gold).mean()),\n",
    "        \"bal_acc_16\": balanced_accuracy_score(gold, pred_ids),\n",
    "    }\n",
    "    for avg in (\"micro\", \"macro\", \"weighted\"):\n",
    "        prec, rec, f1, _ = precision_recall_fscore_support(gold, pred_ids, average=avg, zero_division=0)\n",
    "        metrics[f\"p_{avg}\"] = prec\n",
    "        metrics[f\"r_{avg}\"] = rec\n",
    "        metrics[f\"f1_{avg}\"] = f1\n",
    "\n",
    "    # Count, per MBTI axis, how often the predicted letter matches the gold one.\n",
    "    axis_hits = [0, 0, 0, 0]\n",
    "    exact_hits = 0\n",
    "    for p_id, t_id in zip(pred_ids, gold):\n",
    "        p_axes = mbti_to_4d(MBTI_16[p_id])\n",
    "        t_axes = mbti_to_4d(MBTI_16[t_id])\n",
    "        same = [pa == ta for pa, ta in zip(p_axes, t_axes)]\n",
    "        for axis, ok in enumerate(same):\n",
    "            axis_hits[axis] += ok\n",
    "        exact_hits += all(same)\n",
    "    total = len(gold)\n",
    "\n",
    "    for key, hits in zip((\"acc_ei\", \"acc_ns\", \"acc_tf\", \"acc_jp\"), axis_hits):\n",
    "        metrics[key] = hits / total\n",
    "    metrics[\"acc_4D\"] = exact_hits / total\n",
    "    return metrics\n",
    "\n",
    "# ============ 可视化 ============\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=\"\"):\n",
    "    \"\"\"Save a confusion-matrix plot and a micro/macro-average ROC plot.\n",
    "\n",
    "    Args:\n",
    "        y_true: Integer class ids, one per sample.\n",
    "        y_prob: Per-class scores; indexed as y_prob[:, i] for class i, so it is\n",
    "            expected to have one column per entry of `class_names`.\n",
    "        class_names: Display names; class i is class_names[i].\n",
    "        out_dir: Directory for the PNG files (created if missing).\n",
    "        suffix: Text appended to both plot titles and file names.\n",
    "\n",
    "    Writes `confusion_matrix{suffix}.png` and `roc_micro_macro{suffix}.png`.\n",
    "    The macro curve averages only classes that have both positive and\n",
    "    negative samples in `y_true`; it is omitted if no class qualifies.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_prob = np.asarray(y_prob)\n",
    "    class_ids = list(range(len(class_names)))\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=class_ids)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix{suffix}\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"confusion_matrix{suffix}.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # One-vs-rest indicator matrix: column i marks the samples of class i.\n",
    "    Y_true_bin = label_binarize(y_true, classes=class_ids)\n",
    "    n_samples = Y_true_bin.shape[0]\n",
    "    fpr = {}; tpr = {}; roc_auc = {}\n",
    "    valid_ids = []\n",
    "    for i in class_ids:\n",
    "        n_pos = int(Y_true_bin[:, i].sum())\n",
    "        # A one-vs-rest ROC is undefined without both positives and negatives;\n",
    "        # including such a class would turn the macro average into NaN.\n",
    "        if n_pos == 0 or n_pos == n_samples:\n",
    "            continue\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "        valid_ids.append(i)\n",
    "\n",
    "    # Micro average: pool every (sample, class) decision into one binary problem.\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    # Macro average: interpolate each valid class curve on a shared FPR grid.\n",
    "    if valid_ids:\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in valid_ids]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in valid_ids:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(valid_ids)\n",
    "        fpr[\"macro\"] = all_fpr\n",
    "        tpr[\"macro\"] = mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"], label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    if \"macro\" in roc_auc:\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"], label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(f\"Multiclass ROC (micro & macro){suffix}\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, f\"roc_micro_macro{suffix}.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor):\n",
    "    \"\"\"Average hidden states over dimension 1, counting only unmasked positions.\n",
    "\n",
    "    The mask gets a trailing axis and is cast to the hidden-state dtype; the\n",
    "    masked sum over dim 1 is divided by the mask sum, clamped to at least 1e-6\n",
    "    so a fully masked row yields zeros instead of a division by zero.\n",
    "    Assumes last_hidden_state is (batch, seq, dim) and attention_mask is\n",
    "    (batch, seq) - TODO confirm against callers.\n",
    "    \"\"\"\n",
    "    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)\n",
    "    token_count = weights.sum(dim=1).clamp(min=1e-6)\n",
    "    return (last_hidden_state * weights).sum(dim=1) / token_count\n",
    "\n",
    "@torch.no_grad()\n",
    "def extract_embeddings(model, tokenizer, dataset, device=\"cuda:0\", batch_size=4):\n",
    "    \"\"\"Compute one mean-pooled last-layer embedding per sample.\n",
    "\n",
    "    Args:\n",
    "        model: Model called with input_ids, attention_mask and\n",
    "            output_hidden_states=True; its output must expose `.hidden_states`.\n",
    "            It is switched to eval mode and left there.\n",
    "        tokenizer: Tokenizer used by the padding collator.\n",
    "        dataset: Dataset whose items provide \"input_ids\", \"attention_mask\"\n",
    "            and \"labels\".\n",
    "        device: Device the input tensors are moved to.\n",
    "        batch_size: Batch size of the (unshuffled) loader.\n",
    "\n",
    "    Returns:\n",
    "        (embeddings, labels) as NumPy arrays in dataset order; embeddings are\n",
    "        float32.\n",
    "    \"\"\"\n",
    "    from torch.utils.data import DataLoader\n",
    "    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,\n",
    "                        collate_fn=DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8))\n",
    "    embs, ys = [], []\n",
    "    model.eval()\n",
    "    for batch in loader:\n",
    "        input_ids = batch[\"input_ids\"].to(device)\n",
    "        attention_mask = batch[\"attention_mask\"].to(device)\n",
    "        labels = batch[\"labels\"].cpu().numpy()\n",
    "        out = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)\n",
    "        # Pool in float32: Tensor.numpy() does not support bfloat16, and summing\n",
    "        # many positions in half precision loses accuracy.\n",
    "        last_hidden = out.hidden_states[-1].float()\n",
    "        vec = mean_pool(last_hidden, attention_mask)\n",
    "        embs.append(vec.cpu().numpy()); ys.append(labels)\n",
    "    return np.concatenate(embs, 0), np.concatenate(ys, 0)\n",
    "\n",
    "def plot_tsne(embs, labels, class_names, out_png, title=\"t-SNE (PCA init)\"):\n",
    "    \"\"\"Project embeddings to 2-D (PCA, then t-SNE) and save a per-class scatter plot.\n",
    "\n",
    "    Args:\n",
    "        embs: 2-D array, one embedding per row.\n",
    "        labels: Integer class id per row; class i is drawn as class_names[i].\n",
    "        class_names: Legend names; classes with no samples are skipped.\n",
    "        out_png: Output image path (its directory is created if missing).\n",
    "        title: Plot title.\n",
    "\n",
    "    With fewer than 3 samples nothing is plotted and a message is printed,\n",
    "    since t-SNE cannot be fitted meaningfully on such input.\n",
    "    \"\"\"\n",
    "    embs = np.asarray(embs)\n",
    "    labels = np.asarray(labels)\n",
    "    n = len(labels)\n",
    "    if n < 3:\n",
    "        print(f\"[plot_tsne] skipped {out_png}: need at least 3 samples, got {n}\")\n",
    "        return\n",
    "\n",
    "    # PCA cannot return more components than min(n_samples, n_features).\n",
    "    k = min(50, embs.shape[1], n)\n",
    "    reduced = PCA(n_components=k, random_state=SEED).fit_transform(embs)\n",
    "    # Heuristic perplexity in [5, 50]; t-SNE additionally requires it to be\n",
    "    # smaller than the number of samples, so cap it at n - 1 for tiny splits.\n",
    "    perplexity = min(max(5, min(50, n // 20)), n - 1)\n",
    "    tsne2 = TSNE(n_components=2, init=\"pca\", random_state=SEED, perplexity=perplexity, learning_rate=\"auto\")\n",
    "    X2 = tsne2.fit_transform(reduced)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 7), dpi=150)\n",
    "    for cid, cname in enumerate(class_names):\n",
    "        idx = (labels == cid)\n",
    "        if idx.sum() == 0:\n",
    "            continue\n",
    "        ax.scatter(X2[idx, 0], X2[idx, 1], s=10, alpha=0.7, label=cname)\n",
    "    ax.set_title(title)\n",
    "    ax.set_xticks([]); ax.set_yticks([])\n",
    "    ax.legend(markerscale=2, fontsize=8, ncol=2, frameon=False)\n",
    "    fig.tight_layout()\n",
    "    os.makedirs(os.path.dirname(out_png) or \".\", exist_ok=True)\n",
    "    fig.savefig(out_png)\n",
    "    plt.close(fig)\n",
    "\n",
    "# ============ 主流程 ============\n",
    "def main():\n",
    "    \"\"\"Fine-tune the classifier with QLoRA, evaluate on VAL/TEST and save all artifacts.\n",
    "\n",
    "    Steps: load the three JSON splits, build the tokenizer and the (optionally\n",
    "    4-bit) sequence-classification model, wrap it with LoRA, train with\n",
    "    `Trainer`, then write confusion-matrix/ROC plots and classification reports\n",
    "    for VAL and TEST, save the LoRA adapter, draw t-SNE plots, and print one\n",
    "    sample prediction. All outputs go to OUTPUT_DIR.\n",
    "\n",
    "    Requires a CUDA device: the model and inputs are pinned to \"cuda:0\".\n",
    "    \"\"\"\n",
    "    torch.cuda.set_device(0)\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    # Same fixed splits for every input mode, so the comparison stays fair.\n",
    "    train_rows = load_rows(\"train消融.json\")\n",
    "    val_rows   = load_rows(\"val消融.json\")\n",
    "    test_rows  = load_rows(\"test消融.json\")\n",
    "\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    model_kwargs = dict(\n",
    "        num_labels=16,\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": \"cuda:0\"},\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW,\n",
    "    )\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, **model_kwargs)\n",
    "    # Keep the model's pad id in sync with the tokenizer's.\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.config.use_cache = False\n",
    "\n",
    "    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training\n",
    "    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)\n",
    "    try:\n",
    "        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={\"use_reentrant\": False})\n",
    "    except Exception as exc:\n",
    "        # Best effort: training proceeds without checkpointing, but say so\n",
    "        # instead of hiding the failure.\n",
    "        print(f\"[warn] gradient checkpointing not enabled: {exc!r}\")\n",
    "    peft_cfg = LoraConfig(\n",
    "        task_type=TaskType.SEQ_CLS,\n",
    "        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,\n",
    "        target_modules=TARGET_MODULES, bias=\"none\"\n",
    "    )\n",
    "    model = get_peft_model(model, peft_cfg).to(\"cuda:0\")\n",
    "\n",
    "    # Make later .to() calls (e.g. from Trainer) a no-op on the placed model.\n",
    "    def _noop_to(self, *args, **kwargs): return self\n",
    "    model.to = _noop_to.__get__(model, type(model))\n",
    "\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)\n",
    "    test_ds  = MBTIDataset(test_rows,  tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BSZ_TRN,\n",
    "        per_device_eval_batch_size=BSZ_EVAL,\n",
    "        gradient_accumulation_steps=GRAD_ACCUM,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        learning_rate=LR,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        eval_strategy=\"epoch\",\n",
    "        save_strategy=\"epoch\",\n",
    "        save_total_limit=2,\n",
    "        logging_steps=50,\n",
    "        bf16=False, fp16=False,\n",
    "        report_to=\"none\",\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_4D\",\n",
    "        greater_is_better=True,\n",
    "        optim=\"paged_adamw_8bit\",\n",
    "        eval_accumulation_steps=12,\n",
    "        gradient_checkpointing=False,\n",
    "    )\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=val_ds,\n",
    "        # `tokenizer=` is deprecated in favour of `processing_class=`.\n",
    "        processing_class=tokenizer,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    # Train\n",
    "    trainer.train()\n",
    "\n",
    "    # ===== VAL =====\n",
    "    # One inference pass yields both the logits and the \"eval_*\" metrics, so no\n",
    "    # separate trainer.evaluate() pass over the same data is needed.\n",
    "    val_output = trainer.predict(val_ds, metric_key_prefix=\"eval\")\n",
    "    val_logits = val_output.predictions[0] if isinstance(val_output.predictions, (list, tuple)) else val_output.predictions\n",
    "    val_probs  = F.softmax(torch.tensor(val_logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    val_y_true = val_output.label_ids\n",
    "    plot_confusion_and_roc(val_y_true, val_probs, MBTI_16, OUTPUT_DIR, suffix=\"\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro.png')}\")\n",
    "\n",
    "    eval_metrics = val_output.metrics\n",
    "    print(\"\\n=== Final Eval (on VAL) ===\")\n",
    "    for k, v in eval_metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except (TypeError, ValueError):\n",
    "            # Non-numeric metric value: print it unformatted.\n",
    "            print(k, v)\n",
    "\n",
    "    val_pred_ids = val_probs.argmax(-1)\n",
    "    val_report = classification_report(val_y_true, val_pred_ids, target_names=MBTI_16, digits=4, zero_division=0)\n",
    "    with open(os.path.join(OUTPUT_DIR, \"classification_report_val.txt\"), \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(val_report)\n",
    "    print(\"\\n=== VAL Classification Report ===\\n\", val_report)\n",
    "\n",
    "    # ===== TEST =====\n",
    "    test_output = trainer.predict(test_ds)\n",
    "    test_logits = test_output.predictions[0] if isinstance(test_output.predictions, (list, tuple)) else test_output.predictions\n",
    "    test_probs  = F.softmax(torch.tensor(test_logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    test_y_true = test_output.label_ids\n",
    "    plot_confusion_and_roc(test_y_true, test_probs, MBTI_16, OUTPUT_DIR, suffix=\"_test\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'confusion_matrix_test.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'roc_micro_macro_test.png')}\")\n",
    "\n",
    "    test_pred_ids = test_probs.argmax(-1)\n",
    "    test_report = classification_report(test_y_true, test_pred_ids, target_names=MBTI_16, digits=4, zero_division=0)\n",
    "    with open(os.path.join(OUTPUT_DIR, \"classification_report_test.txt\"), \"w\", encoding=\"utf-8\") as f:\n",
    "        f.write(test_report)\n",
    "    print(\"\\n=== TEST Classification Report ===\\n\", test_report)\n",
    "\n",
    "    # 16-way and per-axis accuracy on TEST, computed by the same compute_metrics\n",
    "    # used for validation so both splits are scored identically.\n",
    "    test_metrics = compute_metrics((test_logits, test_y_true))\n",
    "    print(\"\\n=== Final Test (held-out TEST) ===\")\n",
    "    print(f\"acc_16: {test_metrics['acc_16']:.4f}\")\n",
    "    print(\n",
    "        f\"acc_ei: {test_metrics['acc_ei']:.4f}  acc_ns: {test_metrics['acc_ns']:.4f}  \"\n",
    "        f\"acc_tf: {test_metrics['acc_tf']:.4f}  acc_jp: {test_metrics['acc_jp']:.4f}  \"\n",
    "        f\"acc_4D: {test_metrics['acc_4D']:.4f}\"\n",
    "    )\n",
    "\n",
    "    # Save the LoRA adapter before the optional t-SNE step, so a plotting\n",
    "    # failure cannot lose the trained weights.\n",
    "    trainer.save_model(OUTPUT_DIR)\n",
    "    print(f\"\\n✅ LoRA adapter saved to: {OUTPUT_DIR}\")\n",
    "\n",
    "    # ===== t-SNE clustering (optional) =====\n",
    "    val_embs, val_y = extract_embeddings(model, tokenizer, val_ds, device=\"cuda:0\", batch_size=BSZ_EVAL)\n",
    "    plot_tsne(val_embs, val_y, MBTI_16, os.path.join(OUTPUT_DIR, \"tsne_val.png\"), title=f\"t-SNE (VAL) - {TAG}\")\n",
    "\n",
    "    test_embs, test_y = extract_embeddings(model, tokenizer, test_ds, device=\"cuda:0\", batch_size=BSZ_EVAL)\n",
    "    plot_tsne(test_embs, test_y, MBTI_16, os.path.join(OUTPUT_DIR, \"tsne_test.png\"), title=f\"t-SNE (TEST) - {TAG}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'tsne_val.png')}\")\n",
    "    print(f\"Saved: {os.path.join(OUTPUT_DIR, 'tsne_test.png')}\")\n",
    "\n",
    "    # Inference example on the first TEST sample\n",
    "    model.eval()\n",
    "    sample = test_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(\"cuda:0\") for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        logits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(logits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n[Inference on TEST sample]\")\n",
    "    print(\"原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "# Entry point: run the whole experiment when this code executes as the main module.\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
