{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5d5b5e11",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "前8000条数据已保存到 前8000条数据.json\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "# Source dump, and the subset file that the next cell reads back in\n",
    "SRC_PATH = \"mbti_sample_with_all_views.json\"\n",
    "OUT_PATH = \"YS.json\"\n",
    "N_KEEP = 8675  # number of leading records to keep\n",
    "\n",
    "# Read the original data\n",
    "with open(SRC_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# Keep the first N_KEEP records (all of them if the file is shorter)\n",
    "subset = data[:N_KEEP]\n",
    "\n",
    "# Save as a new file\n",
    "with open(OUT_PATH, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(subset, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "# Report the real row count and file name instead of a hard-coded message\n",
    "print(f\"前{len(subset)}条数据已保存到 {OUT_PATH}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d61a571c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "在原始数据中找到了 853 条与test匹配的记录\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "# Read the original data (written by the previous cell)\n",
    "with open(\"YS.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    original = json.load(f)\n",
    "\n",
    "# Read the test data\n",
    "with open(\"test.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    test = json.load(f)\n",
    "\n",
    "# Index the original records by `posts_cleaned` for O(1) lookup.\n",
    "# Records without that key are skipped instead of raising KeyError; when the same text\n",
    "# occurs more than once the last record wins (same as the previous dict comprehension).\n",
    "original_dict = {}\n",
    "n_dup_keys = 0\n",
    "for item in original:\n",
    "    key = item.get(\"posts_cleaned\")\n",
    "    if key is None:\n",
    "        continue\n",
    "    if key in original_dict:\n",
    "        n_dup_keys += 1\n",
    "    original_dict[key] = item\n",
    "\n",
    "# Look up every test record in the original data, preserving test order\n",
    "matched = []\n",
    "n_unmatched = 0\n",
    "for t in test:\n",
    "    key = t.get(\"posts_cleaned\")\n",
    "    if key in original_dict:\n",
    "        matched.append(original_dict[key])\n",
    "    else:\n",
    "        n_unmatched += 1\n",
    "\n",
    "# Save the result\n",
    "with open(\"test对应的原始数据.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(matched, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"在原始数据中找到了 {len(matched)} 条与test匹配的记录\")\n",
    "if n_unmatched or n_dup_keys:\n",
    "    print(f\"[warn] unmatched test records: {n_unmatched}; duplicate posts_cleaned keys in original: {n_dup_keys}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "961b7c74",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-3.2-1B and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_2891950/3197052987.py:193: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== TEST Results ===\n",
      "acc_16: 0.5969\n",
      "acc_ei: 0.8246  acc_ns: 0.8308  acc_tf: 0.8338  acc_jp: 0.7692  acc_4D: 0.5969\n",
      "Saved figs to: mbti_lora_llama-1b_ckpt/eval_test_only_kaggle_final/confusion_matrix_test.png, mbti_lora_llama-1b_ckpt/eval_test_only_kaggle_final/roc_micro_macro_test.png\n",
      "\n",
      "[Inference on TEST sample]\n",
      "原标签: INTJ  | 预测: INTJ\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Evaluate LoRA adapter on TEST ONLY.\n",
    "依赖：transformers==4.55, peft, scikit-learn, matplotlib, torch, bitsandbytes(如用4bit)\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# ======== Paths to confirm before running ========\n",
    "CKPT_DIR = \"mbti_lora_llama-1b_ckpt\"          # directory holding the saved LoRA adapter\n",
    "TEST_JSON = \"picked_balanced_around30.json\"   # the only split evaluated by this cell\n",
    "OUTPUT_DIR = os.path.join(CKPT_DIR, \"eval_test_only_kaggle_final\")\n",
    "\n",
    "# ======== Must stay identical to the training configuration ========\n",
    "# Alternative base models: \"Qwen/Qwen2.5-1.5B-Instruct\", \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\"\n",
    "MODEL_NAME = \"meta-llama/Llama-3.2-1B\"\n",
    "MAX_LEN = 400\n",
    "# Per-view token budgets used by build_input().\n",
    "# NOTE(review): the budgets alone add up to MAX_LEN, so with the section headers and the\n",
    "# [TASK] line the prompt likely exceeds MAX_LEN and its tail is truncated. Kept unchanged\n",
    "# because evaluation must reproduce the training-time prompt.\n",
    "BUDGET = dict(posts_cleaned=280, semantic_view=64, sentiment_view=32, linguistic_view=24)\n",
    "\n",
    "# The list order defines the classifier's label ids (index 0..15).\n",
    "MBTI_16 = [\n",
    "    \"INTJ\", \"INTP\", \"ENTJ\", \"ENTP\",\n",
    "    \"INFJ\", \"INFP\", \"ENFJ\", \"ENFP\",\n",
    "    \"ISTJ\", \"ISFJ\", \"ESTJ\", \"ESFJ\",\n",
    "    \"ISTP\", \"ISFP\", \"ESTP\", \"ESFP\",\n",
    "]\n",
    "MBTI2ID = dict(zip(MBTI_16, range(len(MBTI_16))))\n",
    "\n",
    "USE_4BIT = True\n",
    "SEED = 42\n",
    "\n",
    "# Optional Hugging Face token, forwarded to the tokenizer/model from_pretrained() calls\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = dict(token=HF_TOKEN) if HF_TOKEN else {}\n",
    "\n",
    "# ======== 工具函数 ========\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Read a JSON list of samples, keeping only dicts whose \"type\" is a known MBTI label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if nothing usable is left after filtering.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        raw = json.load(fh)\n",
    "    valid = []\n",
    "    for rec in raw:\n",
    "        if isinstance(rec, dict) and rec.get(\"type\") in MBTI2ID:\n",
    "            valid.append(rec)\n",
    "    if len(valid) == 0:\n",
    "        raise ValueError(f\"{path} 中没有合法样本。\")\n",
    "    return valid\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    \"\"\"Encode an MBTI string as four binary axis codes, in EI, NS, TF, JP order.\n",
    "\n",
    "    Encoding: I=0/E=1, S=0/N=1, F=0/T=1, P=0/J=1 -- any letter other than the\n",
    "    0-letter yields 1. Only the first four characters are read.\n",
    "    \"\"\"\n",
    "    zero_letters = (\"I\", \"S\", \"F\", \"P\")\n",
    "    return tuple(int(m[k] != zero_letters[k]) for k in range(4))\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Keep only the first `budget` tokens of `text` (no special tokens) and decode them.\n",
    "\n",
    "    None or \"\" is treated as empty text.\n",
    "    \"\"\"\n",
    "    token_ids = tok(text or \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Build the classifier prompt for one sample from its four text views.\n",
    "\n",
    "    Every view is clipped to its BUDGET entry via truncate_to_budget(). `posts` is used\n",
    "    only when the `posts_cleaned` key is missing; None values become \"\". Keep this layout\n",
    "    in sync with the training script.\n",
    "    \"\"\"\n",
    "    raw_views = {\n",
    "        \"posts_cleaned\": item.get(\"posts_cleaned\", item.get(\"posts\", \"\")),\n",
    "        \"semantic_view\": item.get(\"semantic_view\", \"\"),\n",
    "        \"sentiment_view\": item.get(\"sentiment_view\", \"\"),\n",
    "        \"linguistic_view\": item.get(\"linguistic_view\", \"\"),\n",
    "    }\n",
    "    clipped = {name: truncate_to_budget(tok, text or \"\", BUDGET[name]) for name, text in raw_views.items()}\n",
    "    label_list = \", \".join(MBTI_16)\n",
    "    sections = (\n",
    "        f\"[POSTS]\\n{clipped['posts_cleaned']}\\n\"\n",
    "        f\"[SEMANTIC]\\n{clipped['semantic_view']}\\n\"\n",
    "        f\"[SENTIMENT]\\n{clipped['sentiment_view']}\\n\"\n",
    "        f\"[LINGUISTIC]\\n{clipped['linguistic_view']}\\n\"\n",
    "    )\n",
    "    return sections + f\"[TASK] Predict MBTI type among {label_list}.\"\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset turning raw MBTI rows into tokenized classification examples.\n",
    "\n",
    "    Each item holds `input_ids` and `attention_mask` from the tokenizer (truncated to\n",
    "    max_len, not padded here; main() pads batches with DataCollatorWithPadding) and the\n",
    "    integer class id `labels` taken from MBTI2ID.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok = tokenizer\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.rows[idx]\n",
    "        prompt = build_input(row, self.tok)\n",
    "        label_id = MBTI2ID[row[\"type\"]]\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        return dict(\n",
    "            input_ids=encoded[\"input_ids\"],\n",
    "            attention_mask=encoded[\"attention_mask\"],\n",
    "            labels=label_id,\n",
    "        )\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=\"\", title_tag=\"V:kaggle T:kaggle\"):\n",
    "    \"\"\"Save a confusion-matrix PNG and a micro/macro-averaged one-vs-rest ROC PNG.\n",
    "\n",
    "    Args:\n",
    "        y_true: integer class ids, one per sample (must be non-empty).\n",
    "        y_prob: per-class probabilities, shape (n_samples, len(class_names)); the argmax\n",
    "            is taken as the predicted class.\n",
    "        class_names: display names indexed by class id.\n",
    "        out_dir: output directory (created if missing).\n",
    "        suffix: appended to both file names, e.g. \"_test\".\n",
    "        title_tag: dataset tag appended to both titles (the default keeps the old titles).\n",
    "\n",
    "    Writes confusion_matrix{suffix}.png and roc_micro_macro{suffix}.png into out_dir.\n",
    "    Classes that never occur in y_true have no defined ROC curve (sklearn returns NaN TPR);\n",
    "    they are left out of the macro average so it does not become NaN.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_prob = np.asarray(y_prob)\n",
    "    class_ids = list(range(len(class_names)))\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    # ---- confusion matrix (every class gets a row/column, even if absent) ----\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=class_ids)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix-{title_tag}\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"confusion_matrix{suffix}.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ---- one-vs-rest ROC, only for classes that have at least one positive sample ----\n",
    "    Y_true_bin = label_binarize(y_true, classes=class_ids)\n",
    "    present = [i for i in class_ids if Y_true_bin[:, i].any()]\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    for i in present:\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "\n",
    "    # micro-average: pool every (sample, class) decision\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    # macro-average: mean of per-class TPR interpolated on a shared FPR grid\n",
    "    all_fpr = np.unique(np.concatenate([fpr[i] for i in present]))\n",
    "    mean_tpr = np.zeros_like(all_fpr)\n",
    "    for i in present:\n",
    "        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "    mean_tpr /= len(present)\n",
    "    fpr[\"macro\"] = all_fpr\n",
    "    tpr[\"macro\"] = mean_tpr\n",
    "    roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"], label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"], label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(f\"Multiclass ROC (micro & macro)-{title_tag}\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, f\"roc_micro_macro{suffix}.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ======== 主流程（仅TEST） ========\n",
    "def main():\n",
    "    \"\"\"Evaluate the LoRA adapter in CKPT_DIR on TEST_JSON only (no training).\n",
    "\n",
    "    Prints 16-way accuracy and per-axis (EI/NS/TF/JP) accuracy, writes the confusion-matrix\n",
    "    and ROC figures to OUTPUT_DIR, then predicts test_rows[0] as a quick sanity check.\n",
    "    \"\"\"\n",
    "    from peft import PeftModel\n",
    "\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    # Evaluation rows (load_rows raises if none are usable, so test_rows[0] below is safe)\n",
    "    test_rows = load_rows(TEST_JSON)\n",
    "\n",
    "    # Tokenizer: fall back to EOS as the pad token and pad on the right\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # Base model + quantization (must match training)\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    base_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        MODEL_NAME,\n",
    "        num_labels=len(MBTI_16),\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": device},\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW\n",
    "    )\n",
    "    # Keep the model's pad id in sync with the tokenizer\n",
    "    base_model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    base_model.config.use_cache = False\n",
    "\n",
    "    # Attach the trained LoRA adapter, frozen for inference\n",
    "    model = PeftModel.from_pretrained(base_model, CKPT_DIR, is_trainable=False)\n",
    "    # NOTE(review): device_map above already placed the weights on `device`; this .to() is\n",
    "    # likely a no-op and is kept only to match the original run.\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    test_ds = MBTIDataset(test_rows, tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # Inference-only arguments (Trainer is used purely for batched prediction)\n",
    "    args = TrainingArguments(\n",
    "        output_dir=os.path.join(OUTPUT_DIR, \"tmp_test\"),\n",
    "        per_device_eval_batch_size=4,\n",
    "        eval_accumulation_steps=12,\n",
    "        report_to=\"none\",\n",
    "    )\n",
    "\n",
    "    # `processing_class` replaces the deprecated `tokenizer=` argument (FutureWarning in\n",
    "    # this cell's output; `tokenizer=` is removed in transformers 5.0).\n",
    "    trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n",
    "                      processing_class=tokenizer, data_collator=collator)\n",
    "\n",
    "    # Predict; predictions may be a tuple whose first element holds the logits\n",
    "    output = trainer.predict(test_ds)\n",
    "    logits = output.predictions[0] if isinstance(output.predictions, (list, tuple)) else output.predictions\n",
    "    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = output.label_ids\n",
    "\n",
    "    # Metrics: 16-way accuracy plus per-axis accuracy\n",
    "    pred_ids = logits.argmax(-1)\n",
    "    acc16 = float((pred_ids == y_true).mean())\n",
    "    pred_axes = np.array([mbti_to_4d(MBTI_16[i]) for i in pred_ids])  # (n, 4)\n",
    "    true_axes = np.array([mbti_to_4d(MBTI_16[i]) for i in y_true])    # (n, 4)\n",
    "    axis_hits = pred_axes == true_axes\n",
    "    acc_ei, acc_ns, acc_tf, acc_jp = axis_hits.mean(axis=0)\n",
    "    # All four axes correct is the same event as the 16-way label being correct, so\n",
    "    # acc_4D always equals acc_16; it is still printed to keep the report format.\n",
    "    acc_4d = axis_hits.all(axis=1).mean()\n",
    "\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, suffix=\"_test\")\n",
    "\n",
    "    print(\"\\n=== TEST Results ===\")\n",
    "    print(f\"acc_16: {acc16:.4f}\")\n",
    "    print(f\"acc_ei: {acc_ei:.4f}  acc_ns: {acc_ns:.4f}  acc_tf: {acc_tf:.4f}  acc_jp: {acc_jp:.4f}  acc_4D: {acc_4d:.4f}\")\n",
    "    print(f\"Saved figs to: {OUTPUT_DIR}/confusion_matrix_test.png, {OUTPUT_DIR}/roc_micro_macro_test.png\")\n",
    "\n",
    "    # Single-sample sanity check outside the Trainer\n",
    "    sample = test_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        plogits = model(**batch).logits\n",
    "    pred_mbti = MBTI_16[int(torch.argmax(plogits, dim=-1))]\n",
    "    print(\"\\n[Inference on TEST sample]\")\n",
    "    print(\"原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bc8d9731",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_2891950/3216917669.py:191: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== TEST Results ===\n",
      "acc_16: 0.8554\n",
      "acc_ei: 0.9354  acc_ns: 0.9354  acc_tf: 0.9569  acc_jp: 0.9323  acc_4D: 0.8554\n",
      "Saved figs to: qwen-test-on-pandora_new/lora_adapter/kaggle测kaggle/confusion_matrix_test.png, qwen-test-on-pandora_new/lora_adapter/kaggle测kaggle/roc_micro_macro_test.png\n",
      "\n",
      "[Inference on TEST sample]\n",
      "原标签: INTJ  | 预测: INTJ\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Evaluate LoRA adapter on TEST ONLY.\n",
    "依赖：transformers==4.55, peft, scikit-learn, matplotlib, torch, bitsandbytes(如用4bit)\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# ======== 需要确认的两处路径 ========\n",
    "CKPT_DIR  = \"qwen-test-on-pandora_new/lora_adapter\"   # 你保存LoRA适配器的目录\n",
    "TEST_JSON = \"picked_balanced_around30.json\"                               # 仅评测测试集\n",
    "\n",
    "# ======== 与训练保持一致的配置 ========\n",
    "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "MAX_LEN      = 320\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "USE_4BIT = True\n",
    "SEED = 42\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "OUTPUT_DIR = os.path.join(CKPT_DIR, \"kaggle测kaggle\")\n",
    "\n",
    "# ======== 工具函数 ========\n",
    "def load_rows(path: str):\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    rows = [r for r in rows if isinstance(r, dict) and r.get(\"type\") in MBTI2ID]\n",
    "    if not rows:\n",
    "        raise ValueError(f\"{path} 中没有合法样本。\")\n",
    "    return rows\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    enc = tok(text or \"\", add_special_tokens=False)\n",
    "    ids = enc[\"input_ids\"][: budget]\n",
    "    return tok.decode(ids)\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    p   = truncate_to_budget(tok, item.get(\"posts_cleaned\", item.get(\"posts\",\"\")) or \"\", BUDGET[\"posts_cleaned\"])\n",
    "    sem = truncate_to_budget(tok, item.get(\"semantic_view\",\"\")  or \"\", BUDGET[\"semantic_view\"])\n",
    "    sen = truncate_to_budget(tok, item.get(\"sentiment_view\",\"\") or \"\", BUDGET[\"sentiment_view\"])\n",
    "    lin = truncate_to_budget(tok, item.get(\"linguistic_view\",\"\") or \"\", BUDGET[\"linguistic_view\"])\n",
    "    return (\n",
    "        f\"[POSTS]\\n{p}\\n[SEMANTIC]\\n{sem}\\n[SENTIMENT]\\n{sen}\\n[LINGUISTIC]\\n{lin}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok  = tokenizer\n",
    "        self.max_len = max_len\n",
    "    def __len__(self): return len(self.rows)\n",
    "    def __getitem__(self, idx):\n",
    "        it  = self.rows[idx]\n",
    "        text= build_input(it, self.tok)\n",
    "        y   = MBTI2ID[it[\"type\"]]\n",
    "        enc = self.tok(text, truncation=True, max_length=self.max_len)\n",
    "        return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"], \"labels\": y}\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=\"\"):\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix-V:Kaggle,T:kaggle\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"confusion_matrix{suffix}.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    Y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))\n",
    "    fpr = {}; tpr = {}; roc_auc = {}\n",
    "    for i in range(len(class_names)):\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(len(class_names))]))\n",
    "    mean_tpr = np.zeros_like(all_fpr)\n",
    "    for i in range(len(class_names)):\n",
    "        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "    mean_tpr /= len(class_names)\n",
    "    fpr[\"macro\"] = all_fpr\n",
    "    tpr[\"macro\"] = mean_tpr\n",
    "    roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"], label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"], label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(f\"Multiclass ROC (micro & macro)-V:Kaggle,T:kaggle\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, f\"roc_micro_macro{suffix}.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ======== 主流程（仅TEST） ========\n",
    "def main():\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    # 读取 test.json\n",
    "    test_rows = load_rows(TEST_JSON)\n",
    "\n",
    "    # tokenizer\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # 基座模型 + 量化（与训练保持一致）\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    base_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        MODEL_NAME,\n",
    "        num_labels=16,\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": device},\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW\n",
    "    )\n",
    "    base_model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    base_model.config.use_cache = False\n",
    "\n",
    "    # 叠加 LoRA 适配器\n",
    "    from peft import PeftModel\n",
    "    model = PeftModel.from_pretrained(base_model, CKPT_DIR, is_trainable=False)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # 构建测试数据集\n",
    "    test_ds = MBTIDataset(test_rows, tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # 仅推理配置（不训练）\n",
    "    args = TrainingArguments(\n",
    "        output_dir=os.path.join(OUTPUT_DIR, \"tmp_test\"),\n",
    "        per_device_eval_batch_size=4,\n",
    "        eval_accumulation_steps=12,\n",
    "        report_to=\"none\",\n",
    "    )\n",
    "\n",
    "    trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n",
    "                      tokenizer=tokenizer, data_collator=collator)\n",
    "\n",
    "    # 预测\n",
    "    output = trainer.predict(test_ds)\n",
    "    logits = output.predictions[0] if isinstance(output.predictions, (list, tuple)) else output.predictions\n",
    "    probs  = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = output.label_ids\n",
    "\n",
    "    # 指标（16类 + 4D）\n",
    "    pred_ids = logits.argmax(-1)\n",
    "    acc16 = float((pred_ids == y_true).mean())\n",
    "    pred_types = [MBTI_16[i] for i in pred_ids]\n",
    "    true_types = [MBTI_16[i] for i in y_true]\n",
    "    c_ei=c_ns=c_tf=c_jp=c_all=0\n",
    "    for pt, tt in zip(pred_types, true_types):\n",
    "        pei,pns,ptf,pjp = mbti_to_4d(pt)\n",
    "        tei,tns,ttf,tjp = mbti_to_4d(tt)\n",
    "        c_ei += (pei==tei); c_ns += (pns==tns); c_tf += (ptf==ttf); c_jp += (pjp==tjp)\n",
    "        c_all+= (pei==tei and pns==tns and ptf==ttf and pjp==tjp)\n",
    "    n = len(y_true)\n",
    "\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, suffix=\"_test\")\n",
    "\n",
    "    print(\"\\n=== TEST Results ===\")\n",
    "    print(f\"acc_16: {acc16:.4f}\")\n",
    "    print(f\"acc_ei: {c_ei/n:.4f}  acc_ns: {c_ns/n:.4f}  acc_tf: {c_tf/n:.4f}  acc_jp: {c_jp/n:.4f}  acc_4D: {c_all/n:.4f}\")\n",
    "    print(f\"Saved figs to: {OUTPUT_DIR}/confusion_matrix_test.png, {OUTPUT_DIR}/roc_micro_macro_test.png\")\n",
    "\n",
    "    # 推理示例\n",
    "    sample = test_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        plogits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(plogits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n[Inference on TEST sample]\")\n",
    "    print(\"原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "15820b24",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_2891950/2215564017.py:195: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== TEST Results ===\n",
      "acc_16: 0.8554\n",
      "acc_ei: 0.9354  acc_ns: 0.9354  acc_tf: 0.9569  acc_jp: 0.9323  acc_4D: 0.8554\n",
      "F1-16(micro/macro/weighted): 0.8554 / 0.8524 / 0.8559\n",
      "Recall-16(micro/macro/weighted): 0.8554 / 0.8437 / 0.8554\n",
      "F1-4D(overall micro/macro): 0.9408 / 0.9375\n",
      "Recall-4D(overall micro/macro): 0.9466 / 0.9402\n",
      "[EI]  F1: 0.9170  Recall: 0.8855\n",
      "[NS]  F1: 0.9526  Recall: 0.9814\n",
      "[TF]  F1: 0.9591  Recall: 0.9591\n",
      "[JP]  F1: 0.9214  Recall: 0.9348\n",
      "Saved figs to: qwen-test-on-pandora_new/lora_adapter/kaggle测kaggle/confusion_matrix_test.png, qwen-test-on-pandora_new/lora_adapter/kaggle测kaggle/roc_micro_macro_test.png\n",
      "\n",
      "[Inference on TEST sample]\n",
      "原标签: INTJ  | 预测: INTJ\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*- \n",
    "\"\"\"\n",
    "Evaluate LoRA adapter on TEST ONLY.\n",
    "依赖：transformers==4.55, peft, scikit-learn, matplotlib, torch, bitsandbytes(如用4bit)\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import (\n",
    "    confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc,\n",
    "    f1_score, recall_score\n",
    ")\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# ======== Paths to confirm before running ========\n",
    "CKPT_DIR = \"qwen-test-on-pandora_new/lora_adapter\"   # directory holding the saved LoRA adapter\n",
    "TEST_JSON = \"picked_balanced_around30.json\"          # the only split evaluated by this cell\n",
    "OUTPUT_DIR = os.path.join(CKPT_DIR, \"kaggle测kaggle\")\n",
    "\n",
    "# ======== Must stay identical to the training configuration ========\n",
    "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "MAX_LEN = 320\n",
    "# Per-view token budgets used by build_input().\n",
    "# NOTE(review): the budgets add up to 312, leaving only 8 of MAX_LEN's tokens for the\n",
    "# section headers and the [TASK] line, so the prompt tail is likely truncated. Kept\n",
    "# unchanged because evaluation must reproduce the training-time prompt.\n",
    "BUDGET = dict(posts_cleaned=192, semantic_view=64, sentiment_view=32, linguistic_view=24)\n",
    "\n",
    "# The list order defines the classifier's label ids (index 0..15).\n",
    "MBTI_16 = [\n",
    "    \"INTJ\", \"INTP\", \"ENTJ\", \"ENTP\",\n",
    "    \"INFJ\", \"INFP\", \"ENFJ\", \"ENFP\",\n",
    "    \"ISTJ\", \"ISFJ\", \"ESTJ\", \"ESFJ\",\n",
    "    \"ISTP\", \"ISFP\", \"ESTP\", \"ESFP\",\n",
    "]\n",
    "MBTI2ID = dict(zip(MBTI_16, range(len(MBTI_16))))\n",
    "\n",
    "USE_4BIT = True\n",
    "SEED = 42\n",
    "\n",
    "# Optional Hugging Face token, forwarded to from_pretrained() calls via HF_KW\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = dict(token=HF_TOKEN) if HF_TOKEN else {}\n",
    "\n",
    "# ======== 工具函数 ========\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Read a JSON list of samples, keeping only dicts whose \"type\" is a known MBTI label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if nothing usable is left after filtering.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        raw = json.load(fh)\n",
    "    valid = []\n",
    "    for rec in raw:\n",
    "        if isinstance(rec, dict) and rec.get(\"type\") in MBTI2ID:\n",
    "            valid.append(rec)\n",
    "    if len(valid) == 0:\n",
    "        raise ValueError(f\"{path} 中没有合法样本。\")\n",
    "    return valid\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    \"\"\"Encode an MBTI string as four binary axis codes, in EI, NS, TF, JP order.\n",
    "\n",
    "    Encoding (as implemented below):\n",
    "        EI: I=0, E=1    NS: S=0, N=1    TF: F=0, T=1    JP: P=0, J=1\n",
    "    Any letter other than the 0-letter maps to 1; only m[0..3] are read.\n",
    "    NOTE(review): if binary F1/recall is computed on these codes with the default\n",
    "    positive label 1, T and J (not F and P) are the positive classes -- confirm against\n",
    "    the metric code.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Keep only the first `budget` tokens of `text` (no special tokens) and decode them.\n",
    "\n",
    "    None or \"\" is treated as empty text.\n",
    "    \"\"\"\n",
    "    token_ids = tok(text or \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Render one sample as the tagged prompt text fed to the classifier.\n",
    "\n",
    "    Each field is clipped to its BUDGET token count. The raw 'posts' key is used only\n",
    "    when 'posts_cleaned' is missing; None or empty values become empty strings.\n",
    "    \"\"\"\n",
    "    def clip(raw, field):\n",
    "        return truncate_to_budget(tok, raw or \"\", BUDGET[field])\n",
    "\n",
    "    posts_raw = item[\"posts_cleaned\"] if \"posts_cleaned\" in item else item.get(\"posts\", \"\")\n",
    "    posts = clip(posts_raw, \"posts_cleaned\")\n",
    "    semantic = clip(item.get(\"semantic_view\", \"\"), \"semantic_view\")\n",
    "    sentiment = clip(item.get(\"sentiment_view\", \"\"), \"sentiment_view\")\n",
    "    linguistic = clip(item.get(\"linguistic_view\", \"\"), \"linguistic_view\")\n",
    "    label_list = \", \".join(MBTI_16)\n",
    "    return (\n",
    "        f\"[POSTS]\\n{posts}\\n[SEMANTIC]\\n{semantic}\\n[SENTIMENT]\\n{sentiment}\\n[LINGUISTIC]\\n{linguistic}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {label_list}.\"\n",
    "    )\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset: each row becomes a tokenized prompt plus its integer MBTI label.\"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok = tokenizer\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        record = self.rows[idx]\n",
    "        prompt = build_input(record, self.tok)\n",
    "        label = MBTI2ID[record[\"type\"]]\n",
    "        # Truncate only; padding is left to the batch collator.\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        return dict(\n",
    "            input_ids=encoded[\"input_ids\"],\n",
    "            attention_mask=encoded[\"attention_mask\"],\n",
    "            labels=label,\n",
    "        )\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=\"\", title_tag=\"V:Kaggle,T:Kaggle\"):\n",
    "    \"\"\"Save a confusion-matrix PNG and a micro/macro-averaged one-vs-rest ROC PNG.\n",
    "\n",
    "    Args:\n",
    "        y_true: integer class ids, one per sample.\n",
    "        y_prob: per-class scores of shape (n_samples, n_classes); argmax is the prediction.\n",
    "        class_names: display names indexed by class id.\n",
    "        out_dir: output directory (created if missing).\n",
    "        suffix: appended to both file names, e.g. \"_test\".\n",
    "        title_tag: dataset tag appended to both figure titles.\n",
    "\n",
    "    Writes confusion_matrix{suffix}.png and roc_micro_macro{suffix}.png into out_dir.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    labels = list(range(len(class_names)))\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_prob = np.asarray(y_prob)\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    # Confusion matrix over all classes, including any absent from y_true.\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=labels)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix-{title_tag}\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"confusion_matrix{suffix}.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # One-vs-rest ROC per class.\n",
    "    Y_true_bin = label_binarize(y_true, classes=labels)\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    # A class with no positive (or no negative) samples has an undefined ROC curve\n",
    "    # (roc_curve yields NaN rates); skip it so it cannot turn the macro average NaN.\n",
    "    n_samples = Y_true_bin.shape[0]\n",
    "    positives = Y_true_bin.sum(axis=0)\n",
    "    scored = [i for i in labels if positives[i] not in (0, n_samples)]\n",
    "    for i in scored:\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "\n",
    "    # Micro average: pool every (sample, class) decision into one binary problem.\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"], label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    # Macro average: mean of the per-class TPRs interpolated on a shared FPR grid.\n",
    "    if scored:\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in scored]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in scored:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(scored)\n",
    "        fpr[\"macro\"], tpr[\"macro\"] = all_fpr, mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"], label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(f\"Multiclass ROC (micro & macro)-{title_tag}\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, f\"roc_micro_macro{suffix}.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ======== Main flow (TEST only) ========\n",
    "def main():\n",
    "    \"\"\"Evaluate the LoRA adapter on TEST_JSON: print 16-class and per-axis metrics, save figures, run one sample.\"\"\"\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    # Load the test split.\n",
    "    test_rows = load_rows(TEST_JSON)\n",
    "\n",
    "    # Tokenizer: reuse EOS as pad when no pad token exists; pad on the right.\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # Base model with optional 4-bit NF4 quantization (must match the training setup).\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    base_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        MODEL_NAME,\n",
    "        num_labels=len(MBTI_16),\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": device},\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW\n",
    "    )\n",
    "    # Keep the model's pad id in sync with the tokenizer (pad may be EOS, set above).\n",
    "    base_model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    base_model.config.use_cache = False\n",
    "\n",
    "    # Attach the trained LoRA adapter in inference mode.\n",
    "    from peft import PeftModel\n",
    "    model = PeftModel.from_pretrained(base_model, CKPT_DIR, is_trainable=False)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # Test dataset; the collator pads each batch dynamically to a multiple of 8.\n",
    "    test_ds = MBTIDataset(test_rows, tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # Inference-only Trainer configuration (nothing is trained).\n",
    "    args = TrainingArguments(\n",
    "        output_dir=os.path.join(OUTPUT_DIR, \"tmp_test\"),\n",
    "        per_device_eval_batch_size=4,\n",
    "        eval_accumulation_steps=12,\n",
    "        report_to=\"none\",\n",
    "    )\n",
    "\n",
    "    # `processing_class` replaces the deprecated `tokenizer=` argument, which emits a\n",
    "    # FutureWarning and is scheduled for removal in transformers v5.\n",
    "    trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n",
    "                      processing_class=tokenizer, data_collator=collator)\n",
    "\n",
    "    # Predict; predictions may be a tuple whose first element holds the logits.\n",
    "    output = trainer.predict(test_ds)\n",
    "    logits = output.predictions[0] if isinstance(output.predictions, (list, tuple)) else output.predictions\n",
    "    probs  = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = output.label_ids\n",
    "\n",
    "    # Metrics: 16-way and per-axis (4D).\n",
    "    pred_ids = logits.argmax(-1)\n",
    "    acc16 = float((pred_ids == y_true).mean())\n",
    "\n",
    "    pred_types = [MBTI_16[i] for i in pred_ids]\n",
    "    true_types = [MBTI_16[i] for i in y_true]\n",
    "\n",
    "    # Per-axis accuracy counters.\n",
    "    c_ei=c_ns=c_tf=c_jp=c_all=0\n",
    "    # Collected 0/1 axis labels for the binary F1 / recall below.\n",
    "    ei_t, ns_t, tf_t, jp_t = [], [], [], []\n",
    "    ei_p, ns_p, tf_p, jp_p = [], [], [], []\n",
    "\n",
    "    for pt, tt in zip(pred_types, true_types):\n",
    "        pei,pns,ptf,pjp = mbti_to_4d(pt)\n",
    "        tei,tns,ttf,tjp = mbti_to_4d(tt)\n",
    "        # Accuracy counts; c_all requires all four axes to match.\n",
    "        c_ei += (pei==tei); c_ns += (pns==tns); c_tf += (ptf==ttf); c_jp += (pjp==tjp)\n",
    "        c_all+= (pei==tei and pns==tns and ptf==ttf and pjp==tjp)\n",
    "        # Record the binary labels.\n",
    "        ei_t.append(tei); ns_t.append(tns); tf_t.append(ttf); jp_t.append(tjp)\n",
    "        ei_p.append(pei); ns_p.append(pns); tf_p.append(ptf); jp_p.append(pjp)\n",
    "\n",
    "    n = len(y_true)\n",
    "\n",
    "    # ===== 16-class F1 / recall =====\n",
    "    f1_micro_16     = f1_score(y_true, pred_ids, average=\"micro\")\n",
    "    f1_macro_16     = f1_score(y_true, pred_ids, average=\"macro\")\n",
    "    f1_weighted_16  = f1_score(y_true, pred_ids, average=\"weighted\")\n",
    "\n",
    "    rec_micro_16    = recall_score(y_true, pred_ids, average=\"micro\")\n",
    "    rec_macro_16    = recall_score(y_true, pred_ids, average=\"macro\")\n",
    "    rec_weighted_16 = recall_score(y_true, pred_ids, average=\"weighted\")\n",
    "\n",
    "    # ===== Per-axis binary F1 / recall; positive class is label 1 as encoded by mbti_to_4d =====\n",
    "    ei_f1  = f1_score(ei_t, ei_p, average=\"binary\", pos_label=1)\n",
    "    ns_f1  = f1_score(ns_t, ns_p, average=\"binary\", pos_label=1)\n",
    "    tf_f1  = f1_score(tf_t, tf_p, average=\"binary\", pos_label=1)\n",
    "    jp_f1  = f1_score(jp_t, jp_p, average=\"binary\", pos_label=1)\n",
    "\n",
    "    ei_rec = recall_score(ei_t, ei_p, average=\"binary\", pos_label=1)\n",
    "    ns_rec = recall_score(ns_t, ns_p, average=\"binary\", pos_label=1)\n",
    "    tf_rec = recall_score(tf_t, tf_p, average=\"binary\", pos_label=1)\n",
    "    jp_rec = recall_score(jp_t, jp_p, average=\"binary\", pos_label=1)\n",
    "\n",
    "    # ===== Overall 4D scores =====\n",
    "    # micro: pool the labels of all four axes and score them as one binary problem\n",
    "    y4_true = np.concatenate([ei_t, ns_t, tf_t, jp_t])\n",
    "    y4_pred = np.concatenate([ei_p, ns_p, tf_p, jp_p])\n",
    "    f1_micro_4d  = f1_score(y4_true, y4_pred, average=\"binary\", pos_label=1)\n",
    "    rec_micro_4d = recall_score(y4_true, y4_pred, average=\"binary\", pos_label=1)\n",
    "\n",
    "    # macro: unweighted mean of the four per-axis scores\n",
    "    f1_macro_4d  = float(np.mean([ei_f1, ns_f1, tf_f1, jp_f1]))\n",
    "    rec_macro_4d = float(np.mean([ei_rec, ns_rec, tf_rec, jp_rec]))\n",
    "\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, suffix=\"_test\")\n",
    "\n",
    "    print(\"\\n=== TEST Results ===\")\n",
    "    print(f\"acc_16: {acc16:.4f}\")\n",
    "    print(f\"acc_ei: {c_ei/n:.4f}  acc_ns: {c_ns/n:.4f}  acc_tf: {c_tf/n:.4f}  acc_jp: {c_jp/n:.4f}  acc_4D: {c_all/n:.4f}\")\n",
    "\n",
    "    # 16-class overall\n",
    "    print(f\"F1-16(micro/macro/weighted): {f1_micro_16:.4f} / {f1_macro_16:.4f} / {f1_weighted_16:.4f}\")\n",
    "    print(f\"Recall-16(micro/macro/weighted): {rec_micro_16:.4f} / {rec_macro_16:.4f} / {rec_weighted_16:.4f}\")\n",
    "\n",
    "    # 4D overall (pooled micro and per-axis-mean macro)\n",
    "    print(f\"F1-4D(overall micro/macro): {f1_micro_4d:.4f} / {f1_macro_4d:.4f}\")\n",
    "    print(f\"Recall-4D(overall micro/macro): {rec_micro_4d:.4f} / {rec_macro_4d:.4f}\")\n",
    "\n",
    "    # 4D per axis\n",
    "    print(f\"[EI]  F1: {ei_f1:.4f}  Recall: {ei_rec:.4f}\")\n",
    "    print(f\"[NS]  F1: {ns_f1:.4f}  Recall: {ns_rec:.4f}\")\n",
    "    print(f\"[TF]  F1: {tf_f1:.4f}  Recall: {tf_rec:.4f}\")\n",
    "    print(f\"[JP]  F1: {jp_f1:.4f}  Recall: {jp_rec:.4f}\")\n",
    "\n",
    "    print(f\"Saved figs to: {OUTPUT_DIR}/confusion_matrix_test.png, {OUTPUT_DIR}/roc_micro_macro_test.png\")\n",
    "\n",
    "    # Single-sample inference demo on the first test row.\n",
    "    sample = test_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        plogits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(plogits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n[Inference on TEST sample]\")\n",
    "    print(\"原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "aefe3e73",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_2891950/1061593653.py:191: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== TEST Results ===\n",
      "acc_16: 0.8115\n",
      "acc_ei: 0.9044  acc_ns: 0.9125  acc_tf: 0.9135  acc_jp: 0.9044  acc_4D: 0.8115\n",
      "Saved figs to: qwen-test-on-pandora_new/lora_adapter/kaggle测pandora/confusion_matrix_test.png, qwen-test-on-pandora_new/lora_adapter/kaggle测pandora/roc_micro_macro_test.png\n",
      "\n",
      "[Inference on TEST sample]\n",
      "原标签: INTJ  | 预测: INTJ\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Evaluate LoRA adapter on TEST ONLY.\n",
    "依赖：transformers==4.55, peft, scikit-learn, matplotlib, torch, bitsandbytes(如用4bit)\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# ======== 需要确认的两处路径 ========\n",
    "CKPT_DIR  = \"qwen-test-on-pandora_new/lora_adapter\"   # 你保存LoRA适配器的目录\n",
    "TEST_JSON = \"pandora_testdataset.json\"                               # 仅评测测试集\n",
    "\n",
    "# ======== 与训练保持一致的配置 ========\n",
    "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "MAX_LEN      = 320\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "USE_4BIT = True\n",
    "SEED = 42\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "OUTPUT_DIR = os.path.join(CKPT_DIR, \"kaggle测pandora\")\n",
    "\n",
    "# ======== 工具函数 ========\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load samples from a JSON file, keeping only dicts whose 'type' is a key of MBTI2ID.\n",
    "\n",
    "    Raises ValueError when no valid sample remains after filtering.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        raw = json.load(fh)\n",
    "    valid = []\n",
    "    for rec in raw:\n",
    "        if isinstance(rec, dict) and rec.get(\"type\") in MBTI2ID:\n",
    "            valid.append(rec)\n",
    "    if len(valid) == 0:\n",
    "        raise ValueError(f\"{path} 中没有合法样本。\")\n",
    "    return valid\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Return `text` cut to its first `budget` tokens (no special tokens), decoded back to a string.\"\"\"\n",
    "    token_ids = tok(text or \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Render one sample as the tagged prompt text fed to the classifier.\n",
    "\n",
    "    Each field is clipped to its BUDGET token count. The raw 'posts' key is used only\n",
    "    when 'posts_cleaned' is missing; None or empty values become empty strings.\n",
    "    \"\"\"\n",
    "    def clip(raw, field):\n",
    "        return truncate_to_budget(tok, raw or \"\", BUDGET[field])\n",
    "\n",
    "    posts_raw = item[\"posts_cleaned\"] if \"posts_cleaned\" in item else item.get(\"posts\", \"\")\n",
    "    posts = clip(posts_raw, \"posts_cleaned\")\n",
    "    semantic = clip(item.get(\"semantic_view\", \"\"), \"semantic_view\")\n",
    "    sentiment = clip(item.get(\"sentiment_view\", \"\"), \"sentiment_view\")\n",
    "    linguistic = clip(item.get(\"linguistic_view\", \"\"), \"linguistic_view\")\n",
    "    label_list = \", \".join(MBTI_16)\n",
    "    return (\n",
    "        f\"[POSTS]\\n{posts}\\n[SEMANTIC]\\n{semantic}\\n[SENTIMENT]\\n{sentiment}\\n[LINGUISTIC]\\n{linguistic}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {label_list}.\"\n",
    "    )\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset: each row becomes a tokenized prompt plus its integer MBTI label.\"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok = tokenizer\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        record = self.rows[idx]\n",
    "        prompt = build_input(record, self.tok)\n",
    "        label = MBTI2ID[record[\"type\"]]\n",
    "        # Truncate only; padding is left to the batch collator.\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        return dict(\n",
    "            input_ids=encoded[\"input_ids\"],\n",
    "            attention_mask=encoded[\"attention_mask\"],\n",
    "            labels=label,\n",
    "        )\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=\"\", title_tag=\"V:Kaggle,T:Pandora\"):\n",
    "    \"\"\"Save a confusion-matrix PNG and a micro/macro-averaged one-vs-rest ROC PNG.\n",
    "\n",
    "    Args:\n",
    "        y_true: integer class ids, one per sample.\n",
    "        y_prob: per-class scores of shape (n_samples, n_classes); argmax is the prediction.\n",
    "        class_names: display names indexed by class id.\n",
    "        out_dir: output directory (created if missing).\n",
    "        suffix: appended to both file names, e.g. \"_test\".\n",
    "        title_tag: dataset tag appended to both figure titles.\n",
    "\n",
    "    Writes confusion_matrix{suffix}.png and roc_micro_macro{suffix}.png into out_dir.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    labels = list(range(len(class_names)))\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_prob = np.asarray(y_prob)\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    # Confusion matrix over all classes, including any absent from y_true.\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=labels)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix-{title_tag}\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"confusion_matrix{suffix}.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # One-vs-rest ROC per class.\n",
    "    Y_true_bin = label_binarize(y_true, classes=labels)\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    # A class with no positive (or no negative) samples has an undefined ROC curve\n",
    "    # (roc_curve yields NaN rates); skip it so it cannot turn the macro average NaN.\n",
    "    n_samples = Y_true_bin.shape[0]\n",
    "    positives = Y_true_bin.sum(axis=0)\n",
    "    scored = [i for i in labels if positives[i] not in (0, n_samples)]\n",
    "    for i in scored:\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "\n",
    "    # Micro average: pool every (sample, class) decision into one binary problem.\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"], label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    # Macro average: mean of the per-class TPRs interpolated on a shared FPR grid.\n",
    "    if scored:\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in scored]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in scored:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(scored)\n",
    "        fpr[\"macro\"], tpr[\"macro\"] = all_fpr, mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"], label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(f\"Multiclass ROC (micro & macro)-{title_tag}\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, f\"roc_micro_macro{suffix}.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ======== Main flow (TEST only) ========\n",
    "def main():\n",
    "    \"\"\"Evaluate the LoRA adapter on TEST_JSON: print 16-class and per-axis accuracy, save figures, run one sample.\"\"\"\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    # Load the test split.\n",
    "    test_rows = load_rows(TEST_JSON)\n",
    "\n",
    "    # Tokenizer: reuse EOS as pad when no pad token exists; pad on the right.\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # Base model with optional 4-bit NF4 quantization (must match the training setup).\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    base_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        MODEL_NAME,\n",
    "        num_labels=len(MBTI_16),\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": device},\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW\n",
    "    )\n",
    "    # Keep the model's pad id in sync with the tokenizer (pad may be EOS, set above).\n",
    "    base_model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    base_model.config.use_cache = False\n",
    "\n",
    "    # Attach the trained LoRA adapter in inference mode.\n",
    "    from peft import PeftModel\n",
    "    model = PeftModel.from_pretrained(base_model, CKPT_DIR, is_trainable=False)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # Test dataset; the collator pads each batch dynamically to a multiple of 8.\n",
    "    test_ds = MBTIDataset(test_rows, tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # Inference-only Trainer configuration (nothing is trained).\n",
    "    args = TrainingArguments(\n",
    "        output_dir=os.path.join(OUTPUT_DIR, \"tmp_test\"),\n",
    "        per_device_eval_batch_size=4,\n",
    "        eval_accumulation_steps=12,\n",
    "        report_to=\"none\",\n",
    "    )\n",
    "\n",
    "    # `processing_class` replaces the deprecated `tokenizer=` argument, which emits a\n",
    "    # FutureWarning and is scheduled for removal in transformers v5.\n",
    "    trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n",
    "                      processing_class=tokenizer, data_collator=collator)\n",
    "\n",
    "    # Predict; predictions may be a tuple whose first element holds the logits.\n",
    "    output = trainer.predict(test_ds)\n",
    "    logits = output.predictions[0] if isinstance(output.predictions, (list, tuple)) else output.predictions\n",
    "    probs  = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = output.label_ids\n",
    "\n",
    "    # Metrics: 16-way accuracy and per-axis (4D) accuracy.\n",
    "    pred_ids = logits.argmax(-1)\n",
    "    acc16 = float((pred_ids == y_true).mean())\n",
    "    pred_types = [MBTI_16[i] for i in pred_ids]\n",
    "    true_types = [MBTI_16[i] for i in y_true]\n",
    "    c_ei=c_ns=c_tf=c_jp=c_all=0\n",
    "    for pt, tt in zip(pred_types, true_types):\n",
    "        pei,pns,ptf,pjp = mbti_to_4d(pt)\n",
    "        tei,tns,ttf,tjp = mbti_to_4d(tt)\n",
    "        c_ei += (pei==tei); c_ns += (pns==tns); c_tf += (ptf==ttf); c_jp += (pjp==tjp)\n",
    "        # c_all requires all four axes to match.\n",
    "        c_all+= (pei==tei and pns==tns and ptf==ttf and pjp==tjp)\n",
    "    n = len(y_true)\n",
    "\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, suffix=\"_test\")\n",
    "\n",
    "    print(\"\\n=== TEST Results ===\")\n",
    "    print(f\"acc_16: {acc16:.4f}\")\n",
    "    print(f\"acc_ei: {c_ei/n:.4f}  acc_ns: {c_ns/n:.4f}  acc_tf: {c_tf/n:.4f}  acc_jp: {c_jp/n:.4f}  acc_4D: {c_all/n:.4f}\")\n",
    "    print(f\"Saved figs to: {OUTPUT_DIR}/confusion_matrix_test.png, {OUTPUT_DIR}/roc_micro_macro_test.png\")\n",
    "\n",
    "    # Single-sample inference demo on the first test row.\n",
    "    sample = test_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        plogits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(plogits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n[Inference on TEST sample]\")\n",
    "    print(\"原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "de9f0480",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_2891950/259148812.py:195: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== TEST Results ===\n",
      "acc_16: 0.8115\n",
      "acc_ei: 0.9044  acc_ns: 0.9125  acc_tf: 0.9135  acc_jp: 0.9044  acc_4D: 0.8115\n",
      "F1-16(micro/macro/weighted): 0.8115 / 0.8118 / 0.8118\n",
      "Recall-16(micro/macro/weighted): 0.8115 / 0.8115 / 0.8115\n",
      "F1-4D(overall micro/macro): 0.9088 / 0.9088\n",
      "Recall-4D(overall micro/macro): 0.9097 / 0.9097\n",
      "[EI]  F1: 0.9044  Recall: 0.9046\n",
      "[NS]  F1: 0.9133  Recall: 0.9217\n",
      "[TF]  F1: 0.9133  Recall: 0.9108\n",
      "[JP]  F1: 0.9041  Recall: 0.9017\n",
      "Saved figs to: qwen-test-on-pandora_new/lora_adapter/kaggle测pandora/confusion_matrix_test.png, qwen-test-on-pandora_new/lora_adapter/kaggle测pandora/roc_micro_macro_test.png\n",
      "\n",
      "[Inference on TEST sample]\n",
      "原标签: INTJ  | 预测: INTJ\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*- \n",
    "\"\"\"\n",
    "Evaluate LoRA adapter on TEST ONLY.\n",
    "依赖：transformers==4.55, peft, scikit-learn, matplotlib, torch, bitsandbytes(如用4bit)\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.metrics import (\n",
    "    confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc,\n",
    "    f1_score, recall_score\n",
    ")\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# ======== 需要确认的两处路径 ========\n",
    "CKPT_DIR  = \"qwen-test-on-pandora_new/lora_adapter\"   # 你保存LoRA适配器的目录\n",
    "TEST_JSON = \"pandora_testdataset.json\"           # 仅评测测试集\n",
    "\n",
    "# ======== 与训练保持一致的配置 ========\n",
    "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "MAX_LEN      = 320\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t: i for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "USE_4BIT = True\n",
    "SEED = 42\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "OUTPUT_DIR = os.path.join(CKPT_DIR, \"kaggle测pandora\")\n",
    "\n",
    "# ======== 工具函数 ========\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load samples from a JSON file, keeping only dicts whose 'type' is a key of MBTI2ID.\n",
    "\n",
    "    Raises ValueError when no valid sample remains after filtering.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        raw = json.load(fh)\n",
    "    valid = []\n",
    "    for rec in raw:\n",
    "        if isinstance(rec, dict) and rec.get(\"type\") in MBTI2ID:\n",
    "            valid.append(rec)\n",
    "    if len(valid) == 0:\n",
    "        raise ValueError(f\"{path} 中没有合法样本。\")\n",
    "    return valid\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    # E=1/I=0, N=1/S=0, F=1/T=0, P=1/J=0\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Return `text` cut to its first `budget` tokens (no special tokens), decoded back to a string.\"\"\"\n",
    "    token_ids = tok(text or \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Render one sample as the tagged prompt text fed to the classifier.\n",
    "\n",
    "    Each field is clipped to its BUDGET token count. The raw 'posts' key is used only\n",
    "    when 'posts_cleaned' is missing; None or empty values become empty strings.\n",
    "    \"\"\"\n",
    "    def clip(raw, field):\n",
    "        return truncate_to_budget(tok, raw or \"\", BUDGET[field])\n",
    "\n",
    "    posts_raw = item[\"posts_cleaned\"] if \"posts_cleaned\" in item else item.get(\"posts\", \"\")\n",
    "    posts = clip(posts_raw, \"posts_cleaned\")\n",
    "    semantic = clip(item.get(\"semantic_view\", \"\"), \"semantic_view\")\n",
    "    sentiment = clip(item.get(\"sentiment_view\", \"\"), \"sentiment_view\")\n",
    "    linguistic = clip(item.get(\"linguistic_view\", \"\"), \"linguistic_view\")\n",
    "    label_list = \", \".join(MBTI_16)\n",
    "    return (\n",
    "        f\"[POSTS]\\n{posts}\\n[SEMANTIC]\\n{semantic}\\n[SENTIMENT]\\n{sentiment}\\n[LINGUISTIC]\\n{linguistic}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {label_list}.\"\n",
    "    )\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset: each row becomes a tokenized prompt plus its integer MBTI label.\"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok = tokenizer\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        record = self.rows[idx]\n",
    "        prompt = build_input(record, self.tok)\n",
    "        label = MBTI2ID[record[\"type\"]]\n",
    "        # Truncate only; padding is left to the batch collator.\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        return dict(\n",
    "            input_ids=encoded[\"input_ids\"],\n",
    "            attention_mask=encoded[\"attention_mask\"],\n",
    "            labels=label,\n",
    "        )\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, suffix=\"\", title_tag=\"V:Kaggle,T:Pandora\"):\n",
    "    \"\"\"Save a confusion-matrix PNG and a micro/macro-averaged one-vs-rest ROC PNG.\n",
    "\n",
    "    Args:\n",
    "        y_true: integer class ids, one per sample.\n",
    "        y_prob: per-class scores of shape (n_samples, n_classes); argmax is the prediction.\n",
    "        class_names: display names indexed by class id.\n",
    "        out_dir: output directory (created if missing).\n",
    "        suffix: appended to both file names, e.g. \"_test\".\n",
    "        title_tag: dataset tag appended to both figure titles.\n",
    "\n",
    "    Writes confusion_matrix{suffix}.png and roc_micro_macro{suffix}.png into out_dir.\n",
    "    \"\"\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    labels = list(range(len(class_names)))\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_prob = np.asarray(y_prob)\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "\n",
    "    # Confusion matrix over all classes, including any absent from y_true.\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=labels)\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix-{title_tag}\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"confusion_matrix{suffix}.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # One-vs-rest ROC per class.\n",
    "    Y_true_bin = label_binarize(y_true, classes=labels)\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    # A class with no positive (or no negative) samples has an undefined ROC curve\n",
    "    # (roc_curve yields NaN rates); skip it so it cannot turn the macro average NaN.\n",
    "    n_samples = Y_true_bin.shape[0]\n",
    "    positives = Y_true_bin.sum(axis=0)\n",
    "    scored = [i for i in labels if positives[i] not in (0, n_samples)]\n",
    "    for i in scored:\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "\n",
    "    # Micro average: pool every (sample, class) decision into one binary problem.\n",
    "    fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(Y_true_bin.ravel(), y_prob.ravel())\n",
    "    roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "\n",
    "    fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "    ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"], label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "    # Macro average: mean of the per-class TPRs interpolated on a shared FPR grid.\n",
    "    if scored:\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in scored]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in scored:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(scored)\n",
    "        fpr[\"macro\"], tpr[\"macro\"] = all_fpr, mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"], label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "    ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "    ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "    ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "    ax_roc.set_title(f\"Multiclass ROC (micro & macro)-{title_tag}\")\n",
    "    ax_roc.legend(loc=\"lower right\")\n",
    "    fig_roc.tight_layout()\n",
    "    fig_roc.savefig(os.path.join(out_dir, f\"roc_micro_macro{suffix}.png\"))\n",
    "    plt.close(fig_roc)\n",
    "\n",
    "# ======== 主流程（仅TEST） ========\n",
    "def main():\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "    # 读取 test.json\n",
    "    test_rows = load_rows(TEST_JSON)\n",
    "\n",
    "    # tokenizer\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # 基座模型 + 量化（与训练保持一致）\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    base_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        MODEL_NAME,\n",
    "        num_labels=16,\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map={\"\": device},\n",
    "        low_cpu_mem_usage=True,\n",
    "        **HF_KW\n",
    "    )\n",
    "    base_model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    base_model.config.use_cache = False\n",
    "\n",
    "    # 叠加 LoRA 适配器\n",
    "    from peft import PeftModel\n",
    "    model = PeftModel.from_pretrained(base_model, CKPT_DIR, is_trainable=False)\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # 构建测试数据集\n",
    "    test_ds = MBTIDataset(test_rows, tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # 仅推理配置（不训练）\n",
    "    args = TrainingArguments(\n",
    "        output_dir=os.path.join(OUTPUT_DIR, \"tmp_test\"),\n",
    "        per_device_eval_batch_size=4,\n",
    "        eval_accumulation_steps=12,\n",
    "        report_to=\"none\",\n",
    "    )\n",
    "\n",
    "    trainer = Trainer(model=model, args=args, eval_dataset=test_ds,\n",
    "                      tokenizer=tokenizer, data_collator=collator)\n",
    "\n",
    "    # 预测\n",
    "    output = trainer.predict(test_ds)\n",
    "    logits = output.predictions[0] if isinstance(output.predictions, (list, tuple)) else output.predictions\n",
    "    probs  = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    y_true = output.label_ids\n",
    "\n",
    "    # 指标（16类 + 4D）\n",
    "    pred_ids = logits.argmax(-1)\n",
    "    acc16 = float((pred_ids == y_true).mean())\n",
    "\n",
    "    pred_types = [MBTI_16[i] for i in pred_ids]\n",
    "    true_types = [MBTI_16[i] for i in y_true]\n",
    "\n",
    "    # 4D 准确率（与你原来一致）\n",
    "    c_ei=c_ns=c_tf=c_jp=c_all=0\n",
    "    # 4D 的二分类标签收集（0/1）\n",
    "    ei_t, ns_t, tf_t, jp_t = [], [], [], []\n",
    "    ei_p, ns_p, tf_p, jp_p = [], [], [], []\n",
    "\n",
    "    for pt, tt in zip(pred_types, true_types):\n",
    "        pei,pns,ptf,pjp = mbti_to_4d(pt)\n",
    "        tei,tns,ttf,tjp = mbti_to_4d(tt)\n",
    "        # 准确率计数\n",
    "        c_ei += (pei==tei); c_ns += (pns==tns); c_tf += (ptf==ttf); c_jp += (pjp==tjp)\n",
    "        c_all+= (pei==tei and pns==tns and ptf==ttf and pjp==tjp)\n",
    "        # 记录二分类标签\n",
    "        ei_t.append(tei); ns_t.append(tns); tf_t.append(ttf); jp_t.append(tjp)\n",
    "        ei_p.append(pei); ns_p.append(pns); tf_p.append(ptf); jp_p.append(pjp)\n",
    "\n",
    "    n = len(y_true)\n",
    "\n",
    "    # ===== 新增：16类整体 F1 / Recall =====\n",
    "    f1_micro_16     = f1_score(y_true, pred_ids, average=\"micro\")\n",
    "    f1_macro_16     = f1_score(y_true, pred_ids, average=\"macro\")\n",
    "    f1_weighted_16  = f1_score(y_true, pred_ids, average=\"weighted\")\n",
    "\n",
    "    rec_micro_16    = recall_score(y_true, pred_ids, average=\"micro\")\n",
    "    rec_macro_16    = recall_score(y_true, pred_ids, average=\"macro\")\n",
    "    rec_weighted_16 = recall_score(y_true, pred_ids, average=\"weighted\")\n",
    "\n",
    "    # ===== 新增：四个维度的二分类 F1 / Recall（正类统一取 1，对应 E/N/F/P）=====\n",
    "    ei_f1  = f1_score(ei_t, ei_p, average=\"binary\", pos_label=1)\n",
    "    ns_f1  = f1_score(ns_t, ns_p, average=\"binary\", pos_label=1)\n",
    "    tf_f1  = f1_score(tf_t, tf_p, average=\"binary\", pos_label=1)\n",
    "    jp_f1  = f1_score(jp_t, jp_p, average=\"binary\", pos_label=1)\n",
    "\n",
    "    ei_rec = recall_score(ei_t, ei_p, average=\"binary\", pos_label=1)\n",
    "    ns_rec = recall_score(ns_t, ns_p, average=\"binary\", pos_label=1)\n",
    "    tf_rec = recall_score(tf_t, tf_p, average=\"binary\", pos_label=1)\n",
    "    jp_rec = recall_score(jp_t, jp_p, average=\"binary\", pos_label=1)\n",
    "\n",
    "    # ===== 新增：4D 的总体分数 =====\n",
    "    # micro：把四个维度的标签都拼接在一起计算\n",
    "    y4_true = np.concatenate([ei_t, ns_t, tf_t, jp_t])\n",
    "    y4_pred = np.concatenate([ei_p, ns_p, tf_p, jp_p])\n",
    "    f1_micro_4d  = f1_score(y4_true, y4_pred, average=\"binary\", pos_label=1)\n",
    "    rec_micro_4d = recall_score(y4_true, y4_pred, average=\"binary\", pos_label=1)\n",
    "\n",
    "    # macro：四个维度分数的平均\n",
    "    f1_macro_4d  = float(np.mean([ei_f1, ns_f1, tf_f1, jp_f1]))\n",
    "    rec_macro_4d = float(np.mean([ei_rec, ns_rec, tf_rec, jp_rec]))\n",
    "\n",
    "    os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "    plot_confusion_and_roc(y_true, probs, MBTI_16, OUTPUT_DIR, suffix=\"_test\")\n",
    "\n",
    "    print(\"\\n=== TEST Results ===\")\n",
    "    print(f\"acc_16: {acc16:.4f}\")\n",
    "    print(f\"acc_ei: {c_ei/n:.4f}  acc_ns: {c_ns/n:.4f}  acc_tf: {c_tf/n:.4f}  acc_jp: {c_jp/n:.4f}  acc_4D: {c_all/n:.4f}\")\n",
    "\n",
    "    # 16类总体\n",
    "    print(f\"F1-16(micro/macro/weighted): {f1_micro_16:.4f} / {f1_macro_16:.4f} / {f1_weighted_16:.4f}\")\n",
    "    print(f\"Recall-16(micro/macro/weighted): {rec_micro_16:.4f} / {rec_macro_16:.4f} / {rec_weighted_16:.4f}\")\n",
    "\n",
    "    # 4D 总体（把四个二分类合在一起的 micro，以及四维平均的 macro）\n",
    "    print(f\"F1-4D(overall micro/macro): {f1_micro_4d:.4f} / {f1_macro_4d:.4f}\")\n",
    "    print(f\"Recall-4D(overall micro/macro): {rec_micro_4d:.4f} / {rec_macro_4d:.4f}\")\n",
    "\n",
    "    # 4D 各维度\n",
    "    print(f\"[EI]  F1: {ei_f1:.4f}  Recall: {ei_rec:.4f}\")\n",
    "    print(f\"[NS]  F1: {ns_f1:.4f}  Recall: {ns_rec:.4f}\")\n",
    "    print(f\"[TF]  F1: {tf_f1:.4f}  Recall: {tf_rec:.4f}\")\n",
    "    print(f\"[JP]  F1: {jp_f1:.4f}  Recall: {jp_rec:.4f}\")\n",
    "\n",
    "    print(f\"Saved figs to: {OUTPUT_DIR}/confusion_matrix_test.png, {OUTPUT_DIR}/roc_micro_macro_test.png\")\n",
    "\n",
    "    # 推理示例\n",
    "    sample = test_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        plogits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(plogits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n[Inference on TEST sample]\")\n",
    "    print(\"原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5e212a43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== 每类数量（最多30，不够就全保）==\n",
      "INTJ: 30\n",
      "INTP: 30\n",
      "ENTJ: 30\n",
      "ENTP: 30\n",
      "INFJ: 30\n",
      "INFP: 30\n",
      "ENFJ: 10\n",
      "ENFP: 30\n",
      "ISTJ: 11\n",
      "ISFJ: 15\n",
      "ESTJ: 5\n",
      "ESFJ: 6\n",
      "ISTP: 30\n",
      "ISFP: 23\n",
      "ESTP: 10\n",
      "ESFP: 5\n",
      "Total picked: 325\n",
      "\n",
      "== 四维统计（左,右） & 与理想差距(各侧理想=总数/2) ==\n",
      "E/I (I=0,E=1): [199, 126]    偏差和=73.0\n",
      "N/S (S=0,N=1): [105, 220]    偏差和=115.0\n",
      "T/F (F=0,T=1): [149, 176]    偏差和=27.0\n",
      "J/P (P=0,J=1): [188, 137]    偏差和=51.0\n",
      "\n",
      "✅ 已保存：picked_30_per_type.json\n"
     ]
    }
   ],
   "source": [
    "# 改这里：输入/输出文件名 & 每类上限\n",
    "INPUT_JSON  = \"test对应的原始数据.json\"          # e.g., \"test对应的原始数据.json\"\n",
    "OUTPUT_JSON = \"picked_30_per_type.json\"\n",
    "PER_CLASS   = 30\n",
    "\n",
    "import json, hashlib\n",
    "from collections import defaultdict, Counter\n",
    "\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    m = (m or \"\").upper()\n",
    "    # I/S/F/P 记为0, E/N/T/J 记为1\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,  # EI\n",
    "        0 if m[1]==\"S\" else 1,  # NS\n",
    "        0 if m[2]==\"F\" else 1,  # TF\n",
    "        0 if m[3]==\"P\" else 1,  # JP\n",
    "    )\n",
    "\n",
    "def stable_key(ex):\n",
    "    \"\"\"type + 文本哈希，保证选择确定性；也用于去重\"\"\"\n",
    "    t = (ex.get(\"type\") or \"\").upper()\n",
    "    txt = ex.get(\"query_text\") or ex.get(\"posts_cleaned\") or ex.get(\"posts\") or \"\"\n",
    "    h = hashlib.sha1(txt.strip().lower().encode(\"utf-8\")).hexdigest()\n",
    "    return f\"{t}::{h}\"\n",
    "\n",
    "# 读取\n",
    "with open(INPUT_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# 过滤非法 & 去重（按 stable_key）\n",
    "items, seen = [], set()\n",
    "for ex in data:\n",
    "    t = (ex.get(\"type\") or \"\").upper()\n",
    "    if t not in MBTI2ID:\n",
    "        continue\n",
    "    ex[\"type\"] = t\n",
    "    key = stable_key(ex)\n",
    "    if key in seen:\n",
    "        continue\n",
    "    seen.add(key)\n",
    "    ex[\"_key\"] = key\n",
    "    items.append(ex)\n",
    "\n",
    "# 按类型分桶并确定性排序（键字典序）\n",
    "buckets = defaultdict(list)\n",
    "for ex in items:\n",
    "    buckets[ex[\"type\"]].append(ex)\n",
    "for t in buckets:\n",
    "    buckets[t].sort(key=lambda x: x[\"_key\"])\n",
    "\n",
    "# 每类最多取 PER_CLASS 条（不够就不够）\n",
    "picked = []\n",
    "for t in MBTI_16:\n",
    "    picked.extend(buckets.get(t, [])[:PER_CLASS])\n",
    "\n",
    "# 报告：16类分布\n",
    "type_dist = Counter(ex[\"type\"] for ex in picked)\n",
    "print(\"== 每类数量（最多30，不够就全保）==\")\n",
    "for t in MBTI_16:\n",
    "    print(f\"{t}: {type_dist.get(t,0)}\")\n",
    "total = len(picked)\n",
    "print(f\"Total picked: {total}\")\n",
    "\n",
    "# 报告：4维分布与偏差\n",
    "four = [[0,0],[0,0],[0,0],[0,0]]  # EI, NS, TF, JP\n",
    "for ex in picked:\n",
    "    b = mbti_to_4d(ex[\"type\"])\n",
    "    for d in range(4):\n",
    "        four[d][b[d]] += 1\n",
    "\n",
    "names = [\"E/I (I=0,E=1)\", \"N/S (S=0,N=1)\", \"T/F (F=0,T=1)\", \"J/P (P=0,J=1)\"]\n",
    "ideal = total / 2.0\n",
    "print(\"\\n== 四维统计（左,右） & 与理想差距(各侧理想=总数/2) ==\")\n",
    "for d, nm in enumerate(names):\n",
    "    left, right = four[d]\n",
    "    gap = abs(left - ideal) + abs(right - ideal)\n",
    "    print(f\"{nm}: {four[d]}    偏差和={gap:.1f}\")\n",
    "\n",
    "# 清理临时键并保存\n",
    "for ex in picked:\n",
    "    ex.pop(\"_key\", None)\n",
    "\n",
    "with open(OUTPUT_JSON, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(picked, f, ensure_ascii=False, indent=2)\n",
    "print(f\"\\n✅ 已保存：{OUTPUT_JSON}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bb09f8d8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "== 最终每类数量（≈30，允许±5）==\n",
      "INTJ: 25\n",
      "INTP: 25\n",
      "ENTJ: 31\n",
      "ENTP: 29\n",
      "INFJ: 35\n",
      "INFP: 25\n",
      "ENFJ: 10\n",
      "ENFP: 35\n",
      "ISTJ: 11\n",
      "ISFJ: 15\n",
      "ESTJ: 5\n",
      "ESFJ: 6\n",
      "ISTP: 35\n",
      "ISFP: 23\n",
      "ESTP: 10\n",
      "ESFP: 5\n",
      "Total picked: 325\n",
      "\n",
      "== 四维统计（左,右） & 与理想差距(各侧理想=总数/2) ==\n",
      "E/I (I=0,E=1): [194, 131]    偏差和=63.0\n",
      "N/S (S=0,N=1): [110, 215]    偏差和=105.0\n",
      "T/F (F=0,T=1): [154, 171]    偏差和=17.0\n",
      "J/P (P=0,J=1): [187, 138]    偏差和=49.0\n",
      "\n",
      "✅ 已保存：picked_balanced_around30.json\n"
     ]
    }
   ],
   "source": [
    "# ======= 配置（改这里） =======\n",
    "INPUT_JSON  = \"test对应的原始数据.json\"   # 你的输入数据\n",
    "OUTPUT_JSON = \"picked_balanced_around30.json\"\n",
    "PER_CLASS   = 30      # 目标每类数量（中心值）\n",
    "WIGGLE      = 5       # 允许上下浮动范围：即每类 ∈ [PER_CLASS-WIGGLE, PER_CLASS+WIGGLE]\n",
    "# ============================\n",
    "\n",
    "import json, hashlib, copy\n",
    "from collections import defaultdict, Counter\n",
    "\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "def mbti_bits(m: str):\n",
    "    m = m.upper()\n",
    "    # I/S/F/P 记为0, E/N/T/J 记为1\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,  # EI\n",
    "        0 if m[1]==\"S\" else 1,  # NS\n",
    "        0 if m[2]==\"F\" else 1,  # TF\n",
    "        0 if m[3]==\"P\" else 1,  # JP\n",
    "    )\n",
    "\n",
    "# 确定性键：type + 文本哈希（用于排序与去重）\n",
    "def stable_key(ex):\n",
    "    t = (ex.get(\"type\") or \"\").upper()\n",
    "    txt = ex.get(\"query_text\") or ex.get(\"posts_cleaned\") or ex.get(\"posts\") or \"\"\n",
    "    h = hashlib.sha1(txt.strip().lower().encode(\"utf-8\")).hexdigest()\n",
    "    return f\"{t}::{h}\"\n",
    "\n",
    "# 读取并去重\n",
    "with open(INPUT_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "    raw = json.load(f)\n",
    "\n",
    "items, seen = [], set()\n",
    "for ex in raw:\n",
    "    t = (ex.get(\"type\") or \"\").upper()\n",
    "    if t not in MBTI2ID:\n",
    "        continue\n",
    "    ex[\"type\"] = t\n",
    "    key = stable_key(ex)\n",
    "    if key in seen:\n",
    "        continue\n",
    "    seen.add(key)\n",
    "    ex[\"_key\"] = key\n",
    "    items.append(ex)\n",
    "\n",
    "# 分桶并确定性排序\n",
    "buckets = defaultdict(list)\n",
    "for ex in items:\n",
    "    buckets[ex[\"type\"]].append(ex)\n",
    "for t in buckets:\n",
    "    buckets[t].sort(key=lambda x: x[\"_key\"])  # 稳定顺序\n",
    "\n",
    "# 统计每类可用数 & 位向量\n",
    "avail = {t: len(buckets.get(t, [])) for t in MBTI_16}\n",
    "bits  = {t: mbti_bits(t) for t in MBTI_16}\n",
    "\n",
    "# 初始配额：先取 min(PER_CLASS, avail)（不足的全保留）\n",
    "k = {t: min(PER_CLASS, avail[t]) for t in MBTI_16}\n",
    "\n",
    "# 每类上下限：不足的类下限=其可用数（不减少）；其余允许在 [PER_CLASS-WIGGLE, PER_CLASS+WIGGLE]\n",
    "min_k = {}\n",
    "max_k = {}\n",
    "for t in MBTI_16:\n",
    "    if avail[t] < PER_CLASS:   # 稀有/不足类：全保留，不减少\n",
    "        min_k[t] = avail[t]\n",
    "        max_k[t] = avail[t]    # 也不增加（没有更多）\n",
    "        k[t]     = avail[t]\n",
    "    else:\n",
    "        min_k[t] = max(0, PER_CLASS - WIGGLE)\n",
    "        max_k[t] = min(avail[t], PER_CLASS + WIGGLE)\n",
    "        k[t]     = max(min_k[t], min(k[t], max_k[t]))\n",
    "\n",
    "# 目标函数：四维平方偏差（越小越均衡）\n",
    "def objective(kdict):\n",
    "    total = sum(kdict.values())\n",
    "    if total == 0:\n",
    "        return 0.0\n",
    "    # side1 是 E/N/T/J 的数量； side0 = total - side1\n",
    "    obj = 0.0\n",
    "    for d in range(4):\n",
    "        side1 = sum(kdict[t] for t in MBTI_16 if bits[t][d]==1)\n",
    "        obj += (side1 - total/2.0)**2\n",
    "    return obj\n",
    "\n",
    "# 贪心交换：每次尝试把一个“过多侧”的类型 -1，和一个“过少侧”的类型 +1（均在上下限内）\n",
    "def rebalance(k):\n",
    "    k = k.copy()\n",
    "    best = objective(k)\n",
    "    improved = True\n",
    "    iters = 0\n",
    "    while improved:\n",
    "        improved = False\n",
    "        iters += 1\n",
    "        total = sum(k.values())\n",
    "        if total == 0: break\n",
    "\n",
    "        # 当前四维的 side1 数量 & 失衡度\n",
    "        side1 = [sum(k[t] for t in MBTI_16 if bits[t][d]==1) for d in range(4)]\n",
    "        # 对每个维度，确定“过多侧”与“过少侧” (1侧与0侧)\n",
    "        over_under = []\n",
    "        for d in range(4):\n",
    "            over_side = 1 if side1[d] > total/2.0 else 0\n",
    "            under_side = 1 - over_side\n",
    "            gap = abs(side1[d] - total/2.0)\n",
    "            over_under.append((gap, d, over_side, under_side))\n",
    "        # 按 gap 从大到小尝试修正\n",
    "        over_under.sort(reverse=True)\n",
    "\n",
    "        for gap, d, over_side, under_side in over_under:\n",
    "            if gap <= 0:  # 已经均衡\n",
    "                continue\n",
    "            # 候选可以 -1 的类型：在过多侧、且 k[t] > min_k[t]\n",
    "            cands_down = [t for t in MBTI_16 if bits[t][d]==over_side and k[t] > min_k[t]]\n",
    "            # 候选可以 +1 的类型：在过少侧、且 k[t] < max_k[t]\n",
    "            cands_up   = [t for t in MBTI_16 if bits[t][d]==under_side and k[t] < max_k[t]]\n",
    "            if not cands_down or not cands_up:\n",
    "                continue\n",
    "\n",
    "            # 穷举所有 (down, up) 组合（16*16 最多 256 种），找最优改进\n",
    "            local_best_impr = 0.0\n",
    "            local_best_pair = None\n",
    "            for t_down in cands_down:\n",
    "                for t_up in cands_up:\n",
    "                    if t_down == t_up: \n",
    "                        continue\n",
    "                    k_try = k.copy()\n",
    "                    k_try[t_down] -= 1\n",
    "                    k_try[t_up]   += 1\n",
    "                    new_obj = objective(k_try)\n",
    "                    impr = best - new_obj\n",
    "                    # 二级偏好：尽量贴近 PER_CLASS（让各类“30左右”）\n",
    "                    # 若改进相同，则优先让 |k-30| 更小\n",
    "                    if impr > local_best_impr + 1e-9:\n",
    "                        local_best_impr = impr\n",
    "                        local_best_pair = (t_down, t_up, new_obj)\n",
    "                    elif abs(impr - local_best_impr) <= 1e-9 and local_best_pair is not None:\n",
    "                        old_dev = abs(k[local_best_pair[0]]-PER_CLASS)+abs(k[local_best_pair[1]]-PER_CLASS)\n",
    "                        new_dev = abs((k[t_down]-1)-PER_CLASS)+abs((k[t_up]+1)-PER_CLASS)\n",
    "                        if new_dev < old_dev:\n",
    "                            local_best_pair = (t_down, t_up, new_obj)\n",
    "\n",
    "            if local_best_pair is not None and local_best_impr > 1e-9:\n",
    "                t_down, t_up, new_obj = local_best_pair\n",
    "                k[t_down] -= 1\n",
    "                k[t_up]   += 1\n",
    "                best = new_obj\n",
    "                improved = True\n",
    "                break  # 先应用一次改进，再重新评估四维\n",
    "        if iters > 2000:  # 安全退出\n",
    "            break\n",
    "    return k\n",
    "\n",
    "k_bal = rebalance(k)\n",
    "\n",
    "# 依据最终配额，按确定性顺序取样\n",
    "picked = []\n",
    "for t in MBTI_16:\n",
    "    bucket = buckets.get(t, [])\n",
    "    picked.extend(bucket[:k_bal.get(t, 0)])\n",
    "\n",
    "# 报告\n",
    "dist = Counter(ex[\"type\"] for ex in picked)\n",
    "print(\"== 最终每类数量（≈30，允许±%d）==\" % WIGGLE)\n",
    "for t in MBTI_16:\n",
    "    print(f\"{t}: {dist.get(t,0)}\")\n",
    "total = len(picked)\n",
    "print(f\"Total picked: {total}\")\n",
    "\n",
    "# 四维统计\n",
    "def four_stats(ex_list):\n",
    "    four = [[0,0],[0,0],[0,0],[0,0]]  # EI, NS, TF, JP\n",
    "    for ex in ex_list:\n",
    "        b = mbti_bits(ex[\"type\"])\n",
    "        for d in range(4):\n",
    "            four[d][b[d]] += 1\n",
    "    return four\n",
    "\n",
    "four = four_stats(picked)\n",
    "names = [\"E/I (I=0,E=1)\", \"N/S (S=0,N=1)\", \"T/F (F=0,T=1)\", \"J/P (P=0,J=1)\"]\n",
    "ideal = total / 2.0\n",
    "print(\"\\n== 四维统计（左,右） & 与理想差距(各侧理想=总数/2) ==\")\n",
    "for d, nm in enumerate(names):\n",
    "    left, right = four[d]\n",
    "    gap = abs(left - ideal) + abs(right - ideal)\n",
    "    print(f\"{nm}: {four[d]}    偏差和={gap:.1f}\")\n",
    "\n",
    "# 保存\n",
    "for ex in picked:\n",
    "    ex.pop(\"_key\", None)\n",
    "with open(OUTPUT_JSON, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(picked, f, ensure_ascii=False, indent=2)\n",
    "print(f\"\\n✅ 已保存：{OUTPUT_JSON}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
