{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6007b949",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import json\n",
    "# import torch\n",
    "# import torch.nn as nn\n",
    "# import torch.nn.functional as F\n",
    "# from torch.utils.data import Dataset, DataLoader\n",
    "# from transformers import AutoTokenizer, AutoModel\n",
    "# from tqdm import tqdm\n",
    "# from sklearn.model_selection import train_test_split\n",
    "# from sklearn.manifold import TSNE\n",
    "# import matplotlib.pyplot as plt\n",
    "# import pandas as pd\n",
    "\n",
    "# # ====== 配置 ======\n",
    "# MODEL_NAME = \"microsoft/deberta-v3-base\"\n",
    "# MAX_LEN = 256\n",
    "# BATCH_SIZE = 4\n",
    "# EPOCHS = 5\n",
    "# LR = 2e-5\n",
    "# LAMBDA_CL = 1.0   # 对比学习权重\n",
    "# SAVE_PATH = \"best_deberta_contrastive.pt\"\n",
    "# DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# # ====== 工具函数 ======\n",
    "# def mbti_to_binary(mbti_type):\n",
    "#     return {\n",
    "#         \"ei\": 0 if mbti_type[0] == \"I\" else 1,\n",
    "#         \"ns\": 0 if mbti_type[1] == \"S\" else 1,\n",
    "#         \"tf\": 0 if mbti_type[2] == \"F\" else 1,\n",
    "#         \"jp\": 0 if mbti_type[3] == \"P\" else 1\n",
    "#     }\n",
    "\n",
    "# # ====== Dataset ======\n",
    "# class JointMBTIDataset(Dataset):\n",
    "#     def __init__(self, data, tokenizer, max_length=256):\n",
    "#         self.data = data\n",
    "#         self.tokenizer = tokenizer\n",
    "#         self.max_length = max_length\n",
    "\n",
    "#     def _encode(self, text):\n",
    "#         return self.tokenizer(\n",
    "#             text,\n",
    "#             padding=\"max_length\",\n",
    "#             truncation=True,\n",
    "#             max_length=self.max_length,\n",
    "#             return_tensors=\"pt\"\n",
    "#         )\n",
    "\n",
    "#     def __len__(self):\n",
    "#         return len(self.data)\n",
    "\n",
    "#     def __getitem__(self, idx):\n",
    "#         item = self.data[idx]\n",
    "#         base_text = item.get(\"posts_cleaned\", item[\"posts\"]) or \"\"\n",
    "#         semantic = item.get(\"semantic_view\", \"\")\n",
    "#         sentiment = item.get(\"sentiment_view\", \"\")\n",
    "#         linguistic = item.get(\"linguistic_view\", \"\")\n",
    "\n",
    "#         mbti_label = mbti_to_binary(item[\"type\"])\n",
    "\n",
    "#         enc_base = self._encode(base_text)\n",
    "#         enc_sem = self._encode(semantic or base_text)\n",
    "#         enc_sent = self._encode(sentiment or base_text)\n",
    "#         enc_ling = self._encode(linguistic or base_text)\n",
    "\n",
    "#         return {\n",
    "#             # 原文\n",
    "#             \"input_ids\": enc_base[\"input_ids\"].squeeze(0),\n",
    "#             \"attention_mask\": enc_base[\"attention_mask\"].squeeze(0),\n",
    "#             # 三视角\n",
    "#             \"semantic_ids\": enc_sem[\"input_ids\"].squeeze(0),\n",
    "#             \"semantic_mask\": enc_sem[\"attention_mask\"].squeeze(0),\n",
    "#             \"sentiment_ids\": enc_sent[\"input_ids\"].squeeze(0),\n",
    "#             \"sentiment_mask\": enc_sent[\"attention_mask\"].squeeze(0),\n",
    "#             \"linguistic_ids\": enc_ling[\"input_ids\"].squeeze(0),\n",
    "#             \"linguistic_mask\": enc_ling[\"attention_mask\"].squeeze(0),\n",
    "#             # 标签\n",
    "#             \"ei\": torch.tensor(mbti_label[\"ei\"], dtype=torch.float),\n",
    "#             \"ns\": torch.tensor(mbti_label[\"ns\"], dtype=torch.float),\n",
    "#             \"tf\": torch.tensor(mbti_label[\"tf\"], dtype=torch.float),\n",
    "#             \"jp\": torch.tensor(mbti_label[\"jp\"], dtype=torch.float),\n",
    "#             \"label\": item[\"type\"]  # 用于可视化\n",
    "#         }\n",
    "\n",
    "# # ====== 模型 ======\n",
    "# class JointMBTIModel(nn.Module):\n",
    "#     def __init__(self, encoder_name=MODEL_NAME, dropout=0.1):\n",
    "#         super().__init__()\n",
    "#         self.encoder = AutoModel.from_pretrained(encoder_name)\n",
    "#         hidden_size = self.encoder.config.hidden_size\n",
    "#         self.dropout = nn.Dropout(dropout)\n",
    "#         self.classifier_ei = nn.Linear(hidden_size, 1)\n",
    "#         self.classifier_ns = nn.Linear(hidden_size, 1)\n",
    "#         self.classifier_tf = nn.Linear(hidden_size, 1)\n",
    "#         self.classifier_jp = nn.Linear(hidden_size, 1)\n",
    "\n",
    "#     def get_embedding(self, input_ids, attention_mask):\n",
    "#         out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)\n",
    "#         cls_emb = self.dropout(out.last_hidden_state[:, 0, :])\n",
    "#         return cls_emb\n",
    "\n",
    "#     def forward(self, input_ids, attention_mask):\n",
    "#         cls_emb = self.get_embedding(input_ids, attention_mask)\n",
    "#         return {\n",
    "#             \"embedding\": cls_emb,\n",
    "#             \"ei\": self.classifier_ei(cls_emb),\n",
    "#             \"ns\": self.classifier_ns(cls_emb),\n",
    "#             \"tf\": self.classifier_tf(cls_emb),\n",
    "#             \"jp\": self.classifier_jp(cls_emb)\n",
    "#         }\n",
    "\n",
    "# # ====== 损失函数 ======\n",
    "# def compute_mbti_loss(preds, labels):\n",
    "#     bce = nn.BCEWithLogitsLoss()\n",
    "#     return (\n",
    "#         bce(preds[\"ei\"].squeeze(), labels[\"ei\"]) +\n",
    "#         bce(preds[\"ns\"].squeeze(), labels[\"ns\"]) +\n",
    "#         bce(preds[\"tf\"].squeeze(), labels[\"tf\"]) +\n",
    "#         bce(preds[\"jp\"].squeeze(), labels[\"jp\"])\n",
    "#     ) / 4\n",
    "\n",
    "# def contrastive_loss(anchor, positives, temperature=0.07):\n",
    "#     anchor = F.normalize(anchor, dim=1)\n",
    "#     losses = []\n",
    "#     for pos in positives:\n",
    "#         pos = F.normalize(pos, dim=1)\n",
    "#         logits = torch.matmul(anchor, pos.T) / temperature\n",
    "#         labels = torch.arange(anchor.size(0), device=anchor.device)\n",
    "#         losses.append(F.cross_entropy(logits, labels))\n",
    "#     return sum(losses) / len(losses)\n",
    "\n",
    "# # ====== 训练/评估 ======\n",
    "# def train_model(model, train_loader, val_loader, optimizer, device, epochs=5, save_path=\"best.pt\"):\n",
    "#     model.to(device)\n",
    "#     best_val_acc_all = 0.0\n",
    "\n",
    "#     for epoch in range(epochs):\n",
    "#         model.train()\n",
    "#         total_loss = 0\n",
    "\n",
    "#         for batch in tqdm(train_loader, desc=f\"[Train] Epoch {epoch+1}/{epochs}\"):\n",
    "#             # 输入\n",
    "#             input_ids = batch[\"input_ids\"].to(device)\n",
    "#             attention_mask = batch[\"attention_mask\"].to(device)\n",
    "#             semantic_ids = batch[\"semantic_ids\"].to(device)\n",
    "#             semantic_mask = batch[\"semantic_mask\"].to(device)\n",
    "#             sentiment_ids = batch[\"sentiment_ids\"].to(device)\n",
    "#             sentiment_mask = batch[\"sentiment_mask\"].to(device)\n",
    "#             linguistic_ids = batch[\"linguistic_ids\"].to(device)\n",
    "#             linguistic_mask = batch[\"linguistic_mask\"].to(device)\n",
    "\n",
    "#             labels = {k: batch[k].to(device) for k in [\"ei\", \"ns\", \"tf\", \"jp\"]}\n",
    "\n",
    "#             optimizer.zero_grad()\n",
    "\n",
    "#             # 分类损失\n",
    "#             out_base = model(input_ids=input_ids, attention_mask=attention_mask)\n",
    "#             loss_cls = compute_mbti_loss(out_base, labels)\n",
    "\n",
    "#             # 对比损失\n",
    "#             emb_base = out_base[\"embedding\"]\n",
    "#             emb_sem = model.get_embedding(semantic_ids, semantic_mask)\n",
    "#             emb_sent = model.get_embedding(sentiment_ids, sentiment_mask)\n",
    "#             emb_ling = model.get_embedding(linguistic_ids, linguistic_mask)\n",
    "#             loss_cl = contrastive_loss(emb_base, [emb_sem, emb_sent, emb_ling])\n",
    "\n",
    "#             loss = loss_cls + LAMBDA_CL * loss_cl\n",
    "#             loss.backward()\n",
    "#             optimizer.step()\n",
    "#             total_loss += loss.item()\n",
    "\n",
    "#         avg_loss = total_loss / len(train_loader)\n",
    "#         print(f\"\\nEpoch {epoch+1} - Train Loss: {avg_loss:.4f}\")\n",
    "\n",
    "#         # 验证（返回各维准确率 + 4D）\n",
    "#         metrics = evaluate(model, val_loader, device)\n",
    "#         print(\n",
    "#             \"Validation Accuracy | \"\n",
    "#             f\"EI: {metrics['acc_ei']:.2%}  \"\n",
    "#             f\"NS: {metrics['acc_ns']:.2%}  \"\n",
    "#             f\"TF: {metrics['acc_tf']:.2%}  \"\n",
    "#             f\"JP: {metrics['acc_jp']:.2%}  \"\n",
    "#             f\"4D: {metrics['acc_all']:.2%}\"\n",
    "#         )\n",
    "\n",
    "#         # 保存最优（按 4D match）\n",
    "#         if metrics[\"acc_all\"] > best_val_acc_all:\n",
    "#             best_val_acc_all = metrics[\"acc_all\"]\n",
    "#             torch.save(model.state_dict(), save_path)\n",
    "#             print(f\"✅ Best model saved to {save_path} (4D match: {best_val_acc_all:.2%})\")\n",
    "\n",
    "# def evaluate(model, dataloader, device):\n",
    "#     model.eval()\n",
    "#     total = 0\n",
    "#     correct_ei = correct_ns = correct_tf = correct_jp = correct_all = 0\n",
    "\n",
    "#     with torch.no_grad():\n",
    "#         for batch in dataloader:\n",
    "#             input_ids = batch[\"input_ids\"].to(device)\n",
    "#             attention_mask = batch[\"attention_mask\"].to(device)\n",
    "#             labels = {k: batch[k].to(device) for k in [\"ei\", \"ns\", \"tf\", \"jp\"]}\n",
    "\n",
    "#             outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
    "#             pred_ei = torch.sigmoid(outputs[\"ei\"]).round().squeeze()\n",
    "#             pred_ns = torch.sigmoid(outputs[\"ns\"]).round().squeeze()\n",
    "#             pred_tf = torch.sigmoid(outputs[\"tf\"]).round().squeeze()\n",
    "#             pred_jp = torch.sigmoid(outputs[\"jp\"]).round().squeeze()\n",
    "\n",
    "#             correct_ei += (pred_ei == labels[\"ei\"]).sum().item()\n",
    "#             correct_ns += (pred_ns == labels[\"ns\"]).sum().item()\n",
    "#             correct_tf += (pred_tf == labels[\"tf\"]).sum().item()\n",
    "#             correct_jp += (pred_jp == labels[\"jp\"]).sum().item()\n",
    "\n",
    "#             correct_all += (\n",
    "#                 (pred_ei == labels[\"ei\"]) &\n",
    "#                 (pred_ns == labels[\"ns\"]) &\n",
    "#                 (pred_tf == labels[\"tf\"]) &\n",
    "#                 (pred_jp == labels[\"jp\"])\n",
    "#             ).sum().item()\n",
    "\n",
    "#             total += input_ids.size(0)\n",
    "\n",
    "#     acc_ei = correct_ei / total\n",
    "#     acc_ns = correct_ns / total\n",
    "#     acc_tf = correct_tf / total\n",
    "#     acc_jp = correct_jp / total\n",
    "#     acc_all = correct_all / total\n",
    "\n",
    "#     return {\n",
    "#         \"acc_ei\": acc_ei,\n",
    "#         \"acc_ns\": acc_ns,\n",
    "#         \"acc_tf\": acc_tf,\n",
    "#         \"acc_jp\": acc_jp,\n",
    "#         \"acc_all\": acc_all\n",
    "#     }\n",
    "\n",
    "# # ====== 可视化（t-SNE）======\n",
    "# @torch.no_grad()\n",
    "# def visualize_embeddings(model, dataloader, device, title=\"MBTI Embedding Space (t-SNE)\"):\n",
    "#     model.eval()\n",
    "#     all_embeds, all_labels = [], []\n",
    "\n",
    "#     for batch in dataloader:\n",
    "#         input_ids = batch[\"input_ids\"].to(device)\n",
    "#         attention_mask = batch[\"attention_mask\"].to(device)\n",
    "#         embs = model.get_embedding(input_ids, attention_mask)\n",
    "#         all_embeds.append(embs.cpu())\n",
    "#         all_labels.extend(batch[\"label\"])  # MBTI 字符串类型\n",
    "\n",
    "#     embeddings = torch.cat(all_embeds, dim=0).numpy()\n",
    "#     tsne = TSNE(n_components=2, random_state=42, init=\"pca\", learning_rate=\"auto\")\n",
    "#     xy = tsne.fit_transform(embeddings)\n",
    "\n",
    "#     df = pd.DataFrame(xy, columns=[\"x\", \"y\"])\n",
    "#     df[\"label\"] = all_labels\n",
    "\n",
    "#     plt.figure(figsize=(10, 8))\n",
    "#     for label in sorted(set(all_labels)):\n",
    "#         pts = df[df[\"label\"] == label]\n",
    "#         plt.scatter(pts[\"x\"], pts[\"y\"], alpha=0.65, s=18, label=label)\n",
    "#     plt.legend(markerscale=1.5, bbox_to_anchor=(1.05, 1), loc=\"upper left\", ncol=1)\n",
    "#     plt.title(title)\n",
    "#     plt.tight_layout()\n",
    "#     plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "24c43577",
   "metadata": {},
   "outputs": [],
   "source": [
    "# if __name__ == \"__main__\":\n",
    "#     with open(\"mbti_sample_with_all_views.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "#         full_data = json.load(f)\n",
    "\n",
    "#     train_data, val_data = train_test_split(full_data, test_size=0.2, random_state=42)\n",
    "\n",
    "#     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "#     train_dataset = JointMBTIDataset(train_data, tokenizer, max_length=MAX_LEN)\n",
    "#     val_dataset = JointMBTIDataset(val_data, tokenizer, max_length=MAX_LEN)\n",
    "\n",
    "#     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "#     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)\n",
    "\n",
    "#     model = JointMBTIModel(MODEL_NAME)\n",
    "#     optimizer = torch.optim.AdamW(model.parameters(), lr=LR)\n",
    "\n",
    "#     # 训练\n",
    "#     train_model(model, train_loader, val_loader, optimizer, DEVICE, epochs=50, save_path=SAVE_PATH)\n",
    "\n",
    "#     # 可视化（用验证集）\n",
    "#     print(\"\\n📊 Visualizing embeddings on the validation set...\")\n",
    "#     visualize_embeddings(model, val_loader, DEVICE, title=\"MBTI CLS Embedding Space (t-SNE)\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da7385a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "#TYB模型\n",
    "# import json\n",
    "# import torch\n",
    "# import torch.nn as nn\n",
    "# import torch.nn.functional as F\n",
    "# from torch.utils.data import Dataset, DataLoader\n",
    "# from transformers import AutoTokenizer, AutoModel\n",
    "# from tqdm import tqdm\n",
    "# from sklearn.manifold import TSNE\n",
    "# import matplotlib.pyplot as plt\n",
    "# import pandas as pd\n",
    "# import random\n",
    "\n",
    "# # ====== 配置 ======\n",
    "# MODEL_NAME = \"microsoft/deberta-v3-base\"\n",
    "# MAX_LEN = 256\n",
    "# BATCH_SIZE = 4\n",
    "# EPOCHS = 30\n",
    "# LR = 2e-5\n",
    "# LAMBDA_CL = 0.1   # 对比学习权重\n",
    "# SAVE_PATH = \"best_deberta_supcon.pt\"\n",
    "# DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# # ====== 工具函数 ======\n",
    "# def mbti_to_binary(mbti_type):\n",
    "#     return {\n",
    "#         \"ei\": 0 if mbti_type[0] == \"I\" else 1,\n",
    "#         \"ns\": 0 if mbti_type[1] == \"S\" else 1,\n",
    "#         \"tf\": 0 if mbti_type[2] == \"F\" else 1,\n",
    "#         \"jp\": 0 if mbti_type[3] == \"P\" else 1\n",
    "#     }\n",
    "\n",
    "# def mbti_to_classid(mbti_type):\n",
    "#     \"\"\"将16种MBTI映射为0-15整数标签\"\"\"\n",
    "#     types = [\n",
    "#         \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\n",
    "#         \"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "#         \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\n",
    "#         \"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "#     ]\n",
    "#     return types.index(mbti_type)\n",
    "\n",
    "# def random_drop_words(text, drop_prob=0.1):\n",
    "#     \"\"\"随机丢弃部分token，防止过拟合explain提示词\"\"\"\n",
    "#     words = text.split()\n",
    "#     keep = [w for w in words if random.random() > drop_prob]\n",
    "#     return \" \".join(keep) if keep else text\n",
    "\n",
    "# # ====== Dataset ======\n",
    "# class JointMBTIDataset(Dataset):\n",
    "#     def __init__(self, data, tokenizer, max_length=256, drop_prob=0.1):\n",
    "#         self.data = data\n",
    "#         self.tokenizer = tokenizer\n",
    "#         self.max_length = max_length\n",
    "#         self.drop_prob = drop_prob\n",
    "\n",
    "#     def _encode(self, text):\n",
    "#         return self.tokenizer(\n",
    "#             text,\n",
    "#             padding=\"max_length\",\n",
    "#             truncation=True,\n",
    "#             max_length=self.max_length,\n",
    "#             return_tensors=\"pt\"\n",
    "#         )\n",
    "\n",
    "#     def __len__(self):\n",
    "#         return len(self.data)\n",
    "\n",
    "#     def __getitem__(self, idx):\n",
    "#         item = self.data[idx]\n",
    "#         base_text = item.get(\"posts_cleaned\", item[\"posts\"]) or \"\"\n",
    "\n",
    "#         # 三视角（加随机DropWord）\n",
    "#         semantic = random_drop_words(item.get(\"semantic_view\", \"\"), self.drop_prob)\n",
    "#         sentiment = random_drop_words(item.get(\"sentiment_view\", \"\"), self.drop_prob)\n",
    "#         linguistic = random_drop_words(item.get(\"linguistic_view\", \"\"), self.drop_prob)\n",
    "\n",
    "#         mbti_label = mbti_to_binary(item[\"type\"])\n",
    "#         class_id = mbti_to_classid(item[\"type\"])\n",
    "\n",
    "#         enc_base = self._encode(base_text)\n",
    "#         enc_sem = self._encode(semantic or base_text)\n",
    "#         enc_sent = self._encode(sentiment or base_text)\n",
    "#         enc_ling = self._encode(linguistic or base_text)\n",
    "\n",
    "#         return {\n",
    "#             # 主文本\n",
    "#             \"input_ids\": enc_base[\"input_ids\"].squeeze(0),\n",
    "#             \"attention_mask\": enc_base[\"attention_mask\"].squeeze(0),\n",
    "#             # 三视角\n",
    "#             \"semantic_ids\": enc_sem[\"input_ids\"].squeeze(0),\n",
    "#             \"semantic_mask\": enc_sem[\"attention_mask\"].squeeze(0),\n",
    "#             \"sentiment_ids\": enc_sent[\"input_ids\"].squeeze(0),\n",
    "#             \"sentiment_mask\": enc_sent[\"attention_mask\"].squeeze(0),\n",
    "#             \"linguistic_ids\": enc_ling[\"input_ids\"].squeeze(0),\n",
    "#             \"linguistic_mask\": enc_ling[\"attention_mask\"].squeeze(0),\n",
    "#             # 二进制标签\n",
    "#             \"ei\": torch.tensor(mbti_label[\"ei\"], dtype=torch.float),\n",
    "#             \"ns\": torch.tensor(mbti_label[\"ns\"], dtype=torch.float),\n",
    "#             \"tf\": torch.tensor(mbti_label[\"tf\"], dtype=torch.float),\n",
    "#             \"jp\": torch.tensor(mbti_label[\"jp\"], dtype=torch.float),\n",
    "#             # 分类id\n",
    "#             \"class_id\": torch.tensor(class_id, dtype=torch.long),\n",
    "#             \"label\": item[\"type\"]\n",
    "#         }\n",
    "\n",
    "# # ====== 模型 ======\n",
    "# class JointMBTIModel(nn.Module):\n",
    "#     def __init__(self, encoder_name=MODEL_NAME, dropout=0.3):\n",
    "#         super().__init__()\n",
    "#         self.encoder = AutoModel.from_pretrained(encoder_name)\n",
    "#         hidden_size = self.encoder.config.hidden_size\n",
    "#         self.dropout = nn.Dropout(dropout)\n",
    "#         self.classifier_ei = nn.Linear(hidden_size, 1)\n",
    "#         self.classifier_ns = nn.Linear(hidden_size, 1)\n",
    "#         self.classifier_tf = nn.Linear(hidden_size, 1)\n",
    "#         self.classifier_jp = nn.Linear(hidden_size, 1)\n",
    "\n",
    "#     def get_embedding(self, input_ids, attention_mask):\n",
    "#         out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)\n",
    "#         cls_emb = self.dropout(out.last_hidden_state[:, 0, :])\n",
    "#         return cls_emb\n",
    "\n",
    "#     def forward(self, input_ids, attention_mask):\n",
    "#         cls_emb = self.get_embedding(input_ids, attention_mask)\n",
    "#         return {\n",
    "#             \"embedding\": cls_emb,\n",
    "#             \"ei\": self.classifier_ei(cls_emb),\n",
    "#             \"ns\": self.classifier_ns(cls_emb),\n",
    "#             \"tf\": self.classifier_tf(cls_emb),\n",
    "#             \"jp\": self.classifier_jp(cls_emb)\n",
    "#         }\n",
    "\n",
    "# # ====== 损失函数 ======\n",
    "# def compute_mbti_loss(preds, labels):\n",
    "#     bce = nn.BCEWithLogitsLoss()\n",
    "#     return (\n",
    "#         bce(preds[\"ei\"].squeeze(), labels[\"ei\"]) +\n",
    "#         bce(preds[\"ns\"].squeeze(), labels[\"ns\"]) +\n",
    "#         bce(preds[\"tf\"].squeeze(), labels[\"tf\"]) +\n",
    "#         bce(preds[\"jp\"].squeeze(), labels[\"jp\"])\n",
    "#     ) / 4\n",
    "\n",
    "# def supervised_contrastive_loss(embeddings, labels, temperature=0.07):\n",
    "#     \"\"\"Supervised Contrastive Loss\"\"\"\n",
    "#     device = embeddings.device\n",
    "#     embeddings = F.normalize(embeddings, dim=1)\n",
    "#     similarity = torch.matmul(embeddings, embeddings.T) / temperature\n",
    "\n",
    "#     mask = torch.eye(len(labels), dtype=torch.bool, device=device)\n",
    "#     labels = labels.contiguous().view(-1, 1)\n",
    "#     match_mask = torch.eq(labels, labels.T).to(device) & ~mask\n",
    "\n",
    "#     logits_max, _ = torch.max(similarity, dim=1, keepdim=True)\n",
    "#     logits = similarity - logits_max.detach()\n",
    "#     exp_logits = torch.exp(logits) * (~mask)\n",
    "#     log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))\n",
    "\n",
    "#     mean_log_prob_pos = (match_mask * log_prob).sum(1) / match_mask.sum(1).clamp(min=1)\n",
    "#     loss = -mean_log_prob_pos.mean()\n",
    "#     return loss\n",
    "\n",
    "# # ====== 训练 ======\n",
    "# def train_model(model, train_loader, val_loader, optimizer, device, epochs=5, save_path=\"best.pt\"):\n",
    "#     model.to(device)\n",
    "#     best_val_acc_all = 0.0\n",
    "\n",
    "#     for epoch in range(epochs):\n",
    "#         model.train()\n",
    "#         total_loss = 0\n",
    "\n",
    "#         for batch in tqdm(train_loader, desc=f\"[Train] Epoch {epoch+1}/{epochs}\"):\n",
    "#             optimizer.zero_grad()\n",
    "\n",
    "#             # 主文本\n",
    "#             out_base = model(batch[\"input_ids\"].to(device), batch[\"attention_mask\"].to(device))\n",
    "#             loss_cls = compute_mbti_loss(out_base, {k: batch[k].to(device) for k in [\"ei\",\"ns\",\"tf\",\"jp\"]})\n",
    "\n",
    "#             # 三视角 embedding\n",
    "#             emb_list = [out_base[\"embedding\"]]\n",
    "#             label_list = [batch[\"class_id\"].to(device)]\n",
    "\n",
    "#             for view in [\"semantic\", \"sentiment\", \"linguistic\"]:\n",
    "#                 emb = model.get_embedding(batch[f\"{view}_ids\"].to(device), batch[f\"{view}_mask\"].to(device))\n",
    "#                 emb_list.append(emb)\n",
    "#                 label_list.append(batch[\"class_id\"].to(device))\n",
    "\n",
    "#             # 拼接所有 embedding\n",
    "#             all_embeddings = torch.cat(emb_list, dim=0)\n",
    "#             all_labels = torch.cat(label_list, dim=0)\n",
    "\n",
    "#             loss_cl = supervised_contrastive_loss(all_embeddings, all_labels)\n",
    "#             loss = loss_cls + LAMBDA_CL * loss_cl\n",
    "\n",
    "#             loss.backward()\n",
    "#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "#             optimizer.step()\n",
    "#             total_loss += loss.item()\n",
    "\n",
    "#         avg_loss = total_loss / len(train_loader)\n",
    "#         print(f\"\\nEpoch {epoch+1} - Train Loss: {avg_loss:.4f}\")\n",
    "\n",
    "#         metrics = evaluate(model, val_loader, device)\n",
    "#         print(f\"Validation | EI: {metrics['acc_ei']:.2%} NS: {metrics['acc_ns']:.2%} TF: {metrics['acc_tf']:.2%} JP: {metrics['acc_jp']:.2%} 4D: {metrics['acc_all']:.2%}\")\n",
    "\n",
    "#         if metrics[\"acc_all\"] > best_val_acc_all:\n",
    "#             best_val_acc_all = metrics[\"acc_all\"]\n",
    "#             torch.save(model.state_dict(), save_path)\n",
    "#             print(f\"✅ Best model saved to {save_path} (4D: {best_val_acc_all:.2%})\")\n",
    "\n",
    "# # ====== 评估 ======\n",
    "# def evaluate(model, dataloader, device):\n",
    "#     model.eval()\n",
    "#     total = 0\n",
    "#     correct_ei = correct_ns = correct_tf = correct_jp = correct_all = 0\n",
    "\n",
    "#     with torch.no_grad():\n",
    "#         for batch in dataloader:\n",
    "#             outputs = model(batch[\"input_ids\"].to(device), batch[\"attention_mask\"].to(device))\n",
    "#             labels = {k: batch[k].to(device) for k in [\"ei\", \"ns\", \"tf\", \"jp\"]}\n",
    "\n",
    "#             pred_ei = torch.sigmoid(outputs[\"ei\"]).round().squeeze()\n",
    "#             pred_ns = torch.sigmoid(outputs[\"ns\"]).round().squeeze()\n",
    "#             pred_tf = torch.sigmoid(outputs[\"tf\"]).round().squeeze()\n",
    "#             pred_jp = torch.sigmoid(outputs[\"jp\"]).round().squeeze()\n",
    "\n",
    "#             correct_ei += (pred_ei == labels[\"ei\"]).sum().item()\n",
    "#             correct_ns += (pred_ns == labels[\"ns\"]).sum().item()\n",
    "#             correct_tf += (pred_tf == labels[\"tf\"]).sum().item()\n",
    "#             correct_jp += (pred_jp == labels[\"jp\"]).sum().item()\n",
    "\n",
    "#             correct_all += ((pred_ei == labels[\"ei\"]) &\n",
    "#                             (pred_ns == labels[\"ns\"]) &\n",
    "#                             (pred_tf == labels[\"tf\"]) &\n",
    "#                             (pred_jp == labels[\"jp\"])).sum().item()\n",
    "\n",
    "#             total += batch[\"input_ids\"].size(0)\n",
    "\n",
    "#     return {\n",
    "#         \"acc_ei\": correct_ei / total,\n",
    "#         \"acc_ns\": correct_ns / total,\n",
    "#         \"acc_tf\": correct_tf / total,\n",
    "#         \"acc_jp\": correct_jp / total,\n",
    "#         \"acc_all\": correct_all / total\n",
    "#     }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "97e056af",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 1/30:   0%|          | 0/3600 [00:00<?, ?it/s]You're using a DebertaV2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "[Train] Epoch 1/30: 100%|██████████| 3600/3600 [21:07<00:00,  2.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1 - Train Loss: 1.0725\n",
      "Validation | EI: 62.60% NS: 63.54% TF: 67.90% JP: 60.35% 4D: 19.75%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 19.75%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 2/30: 100%|██████████| 3600/3600 [20:45<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 2 - Train Loss: 0.9754\n",
      "Validation | EI: 64.81% NS: 66.81% TF: 63.31% JP: 65.52% 4D: 21.62%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 21.62%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 3/30: 100%|██████████| 3600/3600 [20:46<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 3 - Train Loss: 0.9045\n",
      "Validation | EI: 70.00% NS: 70.96% TF: 71.44% JP: 68.29% 4D: 30.83%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 30.83%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 4/30: 100%|██████████| 3600/3600 [20:48<00:00,  2.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 4 - Train Loss: 0.8403\n",
      "Validation | EI: 71.65% NS: 73.10% TF: 73.71% JP: 69.17% 4D: 33.65%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 33.65%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 5/30: 100%|██████████| 3600/3600 [20:46<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 5 - Train Loss: 0.7806\n",
      "Validation | EI: 71.17% NS: 73.79% TF: 74.12% JP: 68.62% 4D: 34.23%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 34.23%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 6/30: 100%|██████████| 3600/3600 [20:45<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 6 - Train Loss: 0.7217\n",
      "Validation | EI: 72.77% NS: 73.85% TF: 74.77% JP: 72.40% 4D: 37.56%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 37.56%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 7/30: 100%|██████████| 3600/3600 [20:47<00:00,  2.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 7 - Train Loss: 0.6666\n",
      "Validation | EI: 72.29% NS: 74.46% TF: 74.58% JP: 71.25% 4D: 36.73%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 8/30: 100%|██████████| 3600/3600 [20:47<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 8 - Train Loss: 0.6152\n",
      "Validation | EI: 72.94% NS: 74.48% TF: 75.21% JP: 71.56% 4D: 38.46%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 38.46%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 9/30: 100%|██████████| 3600/3600 [20:46<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 9 - Train Loss: 0.5659\n",
      "Validation | EI: 72.62% NS: 75.33% TF: 75.60% JP: 72.94% 4D: 39.33%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 39.33%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 10/30: 100%|██████████| 3600/3600 [20:46<00:00,  2.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 10 - Train Loss: 0.5256\n",
      "Validation | EI: 73.17% NS: 74.54% TF: 75.38% JP: 71.04% 4D: 38.96%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 11/30:  42%|████▏     | 1497/3600 [08:38<12:08,  2.89it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 341\u001b[39m\n\u001b[32m    338\u001b[39m optimizer = torch.optim.AdamW(model.parameters(), lr=LR)\n\u001b[32m    340\u001b[39m \u001b[38;5;66;03m# 训练\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m341\u001b[39m \u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mDEVICE\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mEPOCHS\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msave_path\u001b[49m\u001b[43m=\u001b[49m\u001b[43mSAVE_PATH\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    343\u001b[39m \u001b[38;5;66;03m# 推理示例（主文本）\u001b[39;00m\n\u001b[32m    344\u001b[39m sample = val_data[\u001b[32m0\u001b[39m]\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 292\u001b[39m, in \u001b[36mtrain_model\u001b[39m\u001b[34m(model, train_loader, val_loader, optimizer, device, epochs, save_path)\u001b[39m\n\u001b[32m    290\u001b[39m scaler.unscale_(optimizer)\n\u001b[32m    291\u001b[39m torch.nn.utils.clip_grad_norm_(model.parameters(), \u001b[32m1.0\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m292\u001b[39m \u001b[43mscaler\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m)\u001b[49m; scaler.update()\n\u001b[32m    293\u001b[39m optimizer.zero_grad(set_to_none=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m    294\u001b[39m scheduler.step()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/torch/amp/grad_scaler.py:461\u001b[39m, in \u001b[36mGradScaler.step\u001b[39m\u001b[34m(self, optimizer, *args, **kwargs)\u001b[39m\n\u001b[32m    455\u001b[39m     \u001b[38;5;28mself\u001b[39m.unscale_(optimizer)\n\u001b[32m    457\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m (\n\u001b[32m    458\u001b[39m     \u001b[38;5;28mlen\u001b[39m(optimizer_state[\u001b[33m\"\u001b[39m\u001b[33mfound_inf_per_device\u001b[39m\u001b[33m\"\u001b[39m]) > \u001b[32m0\u001b[39m\n\u001b[32m    459\u001b[39m ), \u001b[33m\"\u001b[39m\u001b[33mNo inf checks were recorded for this optimizer.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m461\u001b[39m retval = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_maybe_opt_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    463\u001b[39m optimizer_state[\u001b[33m\"\u001b[39m\u001b[33mstage\u001b[39m\u001b[33m\"\u001b[39m] = OptState.STEPPED\n\u001b[32m    465\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m retval\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/torch/amp/grad_scaler.py:355\u001b[39m, in \u001b[36mGradScaler._maybe_opt_step\u001b[39m\u001b[34m(self, optimizer, optimizer_state, *args, **kwargs)\u001b[39m\n\u001b[32m    347\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[34m_maybe_opt_step\u001b[39m(\n\u001b[32m    348\u001b[39m     \u001b[38;5;28mself\u001b[39m,\n\u001b[32m    349\u001b[39m     optimizer: torch.optim.Optimizer,\n\u001b[32m   (...)\u001b[39m\u001b[32m    352\u001b[39m     **kwargs: Any,\n\u001b[32m    353\u001b[39m ) -> Optional[\u001b[38;5;28mfloat\u001b[39m]:\n\u001b[32m    354\u001b[39m     retval: Optional[\u001b[38;5;28mfloat\u001b[39m] = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m355\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;43msum\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43moptimizer_state\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mfound_inf_per_device\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m    356\u001b[39m         retval = optimizer.step(*args, **kwargs)\n\u001b[32m    357\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m retval\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/torch/amp/grad_scaler.py:355\u001b[39m, in \u001b[36m<genexpr>\u001b[39m\u001b[34m(.0)\u001b[39m\n\u001b[32m    347\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[34m_maybe_opt_step\u001b[39m(\n\u001b[32m    348\u001b[39m     \u001b[38;5;28mself\u001b[39m,\n\u001b[32m    349\u001b[39m     optimizer: torch.optim.Optimizer,\n\u001b[32m   (...)\u001b[39m\u001b[32m    352\u001b[39m     **kwargs: Any,\n\u001b[32m    353\u001b[39m ) -> Optional[\u001b[38;5;28mfloat\u001b[39m]:\n\u001b[32m    354\u001b[39m     retval: Optional[\u001b[38;5;28mfloat\u001b[39m] = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m355\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28msum\u001b[39m(\u001b[43mv\u001b[49m\u001b[43m.\u001b[49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m optimizer_state[\u001b[33m\"\u001b[39m\u001b[33mfound_inf_per_device\u001b[39m\u001b[33m\"\u001b[39m].values()):\n\u001b[32m    356\u001b[39m         retval = optimizer.step(*args, **kwargs)\n\u001b[32m    357\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m retval\n",
      "\u001b[31mKeyboardInterrupt\u001b[39m: "
     ]
    }
   ],
   "source": [
    "# ====================== 最优版整体代码（含自定义 collate_fn） ======================\n",
    "import os, json, random\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"  # optional: silence the tokenizers parallelism warning\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from transformers import (\n",
    "    AutoTokenizer, AutoModel, get_linear_schedule_with_warmup\n",
    ")\n",
    "from torch import amp\n",
    "from tqdm import tqdm\n",
    "from sklearn.model_selection import train_test_split\n",
    "import contextlib\n",
    "\n",
    "# ===================== Configuration =====================\n",
    "# Hyper-parameters for the DeBERTa-v3 + supervised-contrastive MBTI run in this cell.\n",
    "MODEL_NAME   = \"microsoft/deberta-v3-base\"\n",
    "FILE_PATH    = \"mbti_sample_with_all_views_pandora.json\"  # JSON list of records (dicts) with a type and text views\n",
    "MAX_LEN      = 320        # token cap per view (suggested 320; try 288 for a faster A/B run)\n",
    "BATCH_SIZE   = 12         # samples per micro-batch, chosen together with MAX_LEN\n",
    "GRAD_ACCUM   = 2          # gradient accumulation so each optimizer update sees enough tokens\n",
    "EPOCHS       = 30\n",
    "LR           = 2e-5\n",
    "WARMUP_RATIO = 0.06       # fraction of optimizer updates used for linear LR warmup\n",
    "LAMBDA_CL    = 0.1        # weight of the supervised contrastive loss term\n",
    "SAVE_PATH    = \"best_deberta_supcon.pt\"  # state_dict of the epoch with the best validation 4D accuracy\n",
    "DEVICE       = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# ===================== 工具函数 =====================\n",
    "def mbti_to_binary(mbti_type):\n",
    "    return {\n",
    "        \"ei\": 0 if mbti_type[0] == \"I\" else 1,\n",
    "        \"ns\": 0 if mbti_type[1] == \"S\" else 1,\n",
    "        \"tf\": 0 if mbti_type[2] == \"F\" else 1,\n",
    "        \"jp\": 0 if mbti_type[3] == \"P\" else 1\n",
    "    }\n",
    "\n",
    "def mbti_to_classid(mbti_type):\n",
    "    types = [\n",
    "        \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\n",
    "        \"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "        \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\n",
    "        \"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "    ]\n",
    "    return types.index(mbti_type)\n",
    "\n",
    "def random_drop_words(text, drop_prob=0.1):\n",
    "    words = text.split()\n",
    "    keep = [w for w in words if random.random() > drop_prob]\n",
    "    return \" \".join(keep) if keep else text\n",
    "\n",
    "# ===================== Dataset（只截断，不在此处padding） =====================\n",
    "class JointMBTIDataset(Dataset):\n",
    "    \"\"\"Multi-view MBTI dataset: tokenizes (truncation only) the main text plus three auxiliary views.\n",
    "\n",
    "    Each item holds token ids/masks for the base text and the semantic, sentiment and\n",
    "    linguistic views, the four binary axis labels, the 16-way class id and the raw type\n",
    "    string. Padding is left to the collate function.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data, tokenizer, max_length=320, drop_prob=0.1):\n",
    "        self.data = data\n",
    "        self.tok = tokenizer\n",
    "        self.max_length = max_length\n",
    "        self.drop_prob = drop_prob  # word-dropout probability for the three auxiliary views\n",
    "\n",
    "    def _encode(self, text):\n",
    "        \"\"\"Tokenize with truncation to max_length (no padding); returns (input_ids, attention_mask).\"\"\"\n",
    "        enc = self.tok(text, truncation=True, max_length=self.max_length)\n",
    "        return enc[\"input_ids\"], enc[\"attention_mask\"]\n",
    "\n",
    "    def _view(self, item, key, fallback):\n",
    "        \"\"\"Word-dropped text of one auxiliary view, or `fallback` when it is missing, null or empty.\"\"\"\n",
    "        # `or \"\"` normalises a JSON null (key present, value None) to an empty string before\n",
    "        # augmentation; item.get(key, default) would hand None through unchanged.\n",
    "        return random_drop_words(item.get(key) or \"\", self.drop_prob) or fallback\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = self.data[idx]\n",
    "        base_text = item.get(\"posts_cleaned\", item.get(\"posts\", \"\")) or \"\"\n",
    "\n",
    "        # Three auxiliary views with random word-dropout augmentation (fallback: base text).\n",
    "        semantic   = self._view(item, \"semantic_view\", base_text)\n",
    "        sentiment  = self._view(item, \"sentiment_view\", base_text)\n",
    "        linguistic = self._view(item, \"linguistic_view\", base_text)\n",
    "\n",
    "        mbti_label = mbti_to_binary(item[\"type\"])\n",
    "        class_id   = mbti_to_classid(item[\"type\"])\n",
    "\n",
    "        base_ids, base_m = self._encode(base_text)\n",
    "        sem_ids,  sem_m  = self._encode(semantic)\n",
    "        sen_ids,  sen_m  = self._encode(sentiment)\n",
    "        lin_ids,  lin_m  = self._encode(linguistic)\n",
    "\n",
    "        return {\n",
    "            \"input_ids\": base_ids, \"attention_mask\": base_m,\n",
    "            \"semantic_ids\": sem_ids, \"semantic_mask\": sem_m,\n",
    "            \"sentiment_ids\": sen_ids, \"sentiment_mask\": sen_m,\n",
    "            \"linguistic_ids\": lin_ids, \"linguistic_mask\": lin_m,\n",
    "            \"ei\": torch.tensor(mbti_label[\"ei\"], dtype=torch.float),\n",
    "            \"ns\": torch.tensor(mbti_label[\"ns\"], dtype=torch.float),\n",
    "            \"tf\": torch.tensor(mbti_label[\"tf\"], dtype=torch.float),\n",
    "            \"jp\": torch.tensor(mbti_label[\"jp\"], dtype=torch.float),\n",
    "            \"class_id\": torch.tensor(class_id, dtype=torch.long),\n",
    "            \"label\": item[\"type\"]\n",
    "        }\n",
    "\n",
    "# ===================== 自定义 collate：四个视角分别动态 padding =====================\n",
    "def make_collate_fn(tokenizer, pad_to_multiple_of=8):\n",
    "    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0\n",
    "\n",
    "    def pad_view(batch, ids_key, mask_key):\n",
    "        feats = [{\"input_ids\": b[ids_key], \"attention_mask\": b[mask_key]} for b in batch]\n",
    "        padded = tokenizer.pad(\n",
    "            feats,\n",
    "            padding=\"longest\",\n",
    "            return_tensors=\"pt\",\n",
    "            pad_to_multiple_of=pad_to_multiple_of\n",
    "        )\n",
    "        return padded[\"input_ids\"], padded[\"attention_mask\"]\n",
    "\n",
    "    def right_pad_to(x, L, value):\n",
    "        # x: (B, Lx) -> (B, L) 右填充\n",
    "        if x.size(1) == L: \n",
    "            return x\n",
    "        return F.pad(x, (0, L - x.size(1)), value=value)\n",
    "\n",
    "    def collate(batch):\n",
    "        # 四个视角分别 pad（到各自视角的最长）\n",
    "        base_ids, base_m = pad_view(batch, \"input_ids\",      \"attention_mask\")\n",
    "        sem_ids,  sem_m  = pad_view(batch, \"semantic_ids\",   \"semantic_mask\")\n",
    "        sen_ids,  sen_m  = pad_view(batch, \"sentiment_ids\",  \"sentiment_mask\")\n",
    "        lin_ids,  lin_m  = pad_view(batch, \"linguistic_ids\", \"linguistic_mask\")\n",
    "\n",
    "        # 统一到同一个 max_len（四个视角的最大）\n",
    "        L = max(base_ids.size(1), sem_ids.size(1), sen_ids.size(1), lin_ids.size(1))\n",
    "        base_ids = right_pad_to(base_ids, L, pad_id);   base_m = right_pad_to(base_m, L, 0)\n",
    "        sem_ids  = right_pad_to(sem_ids,  L, pad_id);   sem_m  = right_pad_to(sem_m,  L, 0)\n",
    "        sen_ids  = right_pad_to(sen_ids,  L, pad_id);   sen_m  = right_pad_to(sen_m,  L, 0)\n",
    "        lin_ids  = right_pad_to(lin_ids,  L, pad_id);   lin_m  = right_pad_to(lin_m,  L, 0)\n",
    "\n",
    "        return {\n",
    "            \"input_ids\":      base_ids, \"attention_mask\":      base_m,\n",
    "            \"semantic_ids\":   sem_ids,  \"semantic_mask\":       sem_m,\n",
    "            \"sentiment_ids\":  sen_ids,  \"sentiment_mask\":      sen_m,\n",
    "            \"linguistic_ids\": lin_ids,  \"linguistic_mask\":     lin_m,\n",
    "            \"ei\": torch.stack([b[\"ei\"] for b in batch]).float(),\n",
    "            \"ns\": torch.stack([b[\"ns\"] for b in batch]).float(),\n",
    "            \"tf\": torch.stack([b[\"tf\"] for b in batch]).float(),\n",
    "            \"jp\": torch.stack([b[\"jp\"] for b in batch]).float(),\n",
    "            \"class_id\": torch.stack([b[\"class_id\"] for b in batch]).long(),\n",
    "            \"label\": [b[\"label\"] for b in batch],\n",
    "        }\n",
    "    return collate\n",
    "\n",
    "# ===================== 模型 =====================\n",
    "class JointMBTIModel(nn.Module):\n",
    "    \"\"\"Pretrained encoder with four single-logit heads, one per MBTI axis (EI, NS, TF, JP).\n",
    "\n",
    "    The first-token hidden state (after dropout) serves both as the contrastive embedding\n",
    "    and as the input of the heads. Gradient checkpointing (non-reentrant) is enabled on the\n",
    "    encoder, trading recomputation for activation memory.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encoder_name=MODEL_NAME, dropout=0.3):\n",
    "        super().__init__()\n",
    "        self.encoder = AutoModel.from_pretrained(encoder_name)\n",
    "        self.encoder.gradient_checkpointing_enable(\n",
    "            gradient_checkpointing_kwargs={\"use_reentrant\": False}\n",
    "        )\n",
    "        width = self.encoder.config.hidden_size\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        # One binary logit per axis; attribute names define the state_dict keys.\n",
    "        self.classifier_ei = nn.Linear(width, 1)\n",
    "        self.classifier_ns = nn.Linear(width, 1)\n",
    "        self.classifier_tf = nn.Linear(width, 1)\n",
    "        self.classifier_jp = nn.Linear(width, 1)\n",
    "\n",
    "    def get_embedding(self, input_ids, attention_mask):\n",
    "        \"\"\"Encode a batch and return the dropout-regularised first-token hidden state (B, H).\"\"\"\n",
    "        hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state\n",
    "        return self.dropout(hidden_states[:, 0, :])\n",
    "\n",
    "    def forward_heads(self, cls_emb):\n",
    "        \"\"\"Apply the axis heads; returns the embedding plus one (B, 1) logit tensor per axis.\"\"\"\n",
    "        heads = {\n",
    "            \"ei\": self.classifier_ei,\n",
    "            \"ns\": self.classifier_ns,\n",
    "            \"tf\": self.classifier_tf,\n",
    "            \"jp\": self.classifier_jp,\n",
    "        }\n",
    "        result = {\"embedding\": cls_emb}\n",
    "        result.update({axis: head(cls_emb) for axis, head in heads.items()})\n",
    "        return result\n",
    "\n",
    "# ===================== 损失 =====================\n",
    "def compute_mbti_loss(preds, labels):\n",
    "    bce = nn.BCEWithLogitsLoss()\n",
    "    return (\n",
    "        bce(preds[\"ei\"].squeeze(), labels[\"ei\"]) +\n",
    "        bce(preds[\"ns\"].squeeze(), labels[\"ns\"]) +\n",
    "        bce(preds[\"tf\"].squeeze(), labels[\"tf\"]) +\n",
    "        bce(preds[\"jp\"].squeeze(), labels[\"jp\"])\n",
    "    ) / 4\n",
    "\n",
    "def supervised_contrastive_loss(embeddings, labels, temperature=0.07):\n",
    "    # embeddings: (N, H) ; labels: (N,)\n",
    "    device = embeddings.device\n",
    "    z = torch.nn.functional.normalize(embeddings, dim=1)\n",
    "    sim = torch.matmul(z, z.T) / temperature  # (N,N)\n",
    "\n",
    "    eye = torch.eye(len(labels), dtype=torch.bool, device=device)\n",
    "    labels = labels.view(-1, 1)\n",
    "    match = torch.eq(labels, labels.T).to(device) & ~eye\n",
    "\n",
    "    logits_max, _ = sim.max(dim=1, keepdim=True)\n",
    "    logits = sim - logits_max.detach()\n",
    "    exp_logits = torch.exp(logits) * (~eye)\n",
    "    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))\n",
    "\n",
    "    mean_log_prob_pos = (match * log_prob).sum(1) / match.sum(1).clamp(min=1)\n",
    "    return -(mean_log_prob_pos.mean())\n",
    "\n",
    "# ===================== 评估 =====================\n",
    "@torch.no_grad()\n",
    "def evaluate(model, dataloader, device):\n",
    "    model.eval()\n",
    "    total = 0\n",
    "    correct_ei = correct_ns = correct_tf = correct_jp = correct_all = 0\n",
    "    for batch in dataloader:\n",
    "        base_ids = batch[\"input_ids\"].to(device)\n",
    "        base_msk = batch[\"attention_mask\"].to(device)\n",
    "        cls = model.get_embedding(base_ids, base_msk)\n",
    "        out = model.forward_heads(cls)\n",
    "\n",
    "        labels = {k: batch[k].to(device) for k in [\"ei\",\"ns\",\"tf\",\"jp\"]}\n",
    "        pred_ei = torch.sigmoid(out[\"ei\"]).round().squeeze()\n",
    "        pred_ns = torch.sigmoid(out[\"ns\"]).round().squeeze()\n",
    "        pred_tf = torch.sigmoid(out[\"tf\"]).round().squeeze()\n",
    "        pred_jp = torch.sigmoid(out[\"jp\"]).round().squeeze()\n",
    "\n",
    "        correct_ei += (pred_ei == labels[\"ei\"]).sum().item()\n",
    "        correct_ns += (pred_ns == labels[\"ns\"]).sum().item()\n",
    "        correct_tf += (pred_tf == labels[\"tf\"]).sum().item()\n",
    "        correct_jp += (pred_jp == labels[\"jp\"]).sum().item()\n",
    "        correct_all += ((pred_ei == labels[\"ei\"]) &\n",
    "                        (pred_ns == labels[\"ns\"]) &\n",
    "                        (pred_tf == labels[\"tf\"]) &\n",
    "                        (pred_jp == labels[\"jp\"])).sum().item()\n",
    "        total += base_ids.size(0)\n",
    "\n",
    "    return {\n",
    "        \"acc_ei\": correct_ei / total,\n",
    "        \"acc_ns\": correct_ns / total,\n",
    "        \"acc_tf\": correct_tf / total,\n",
    "        \"acc_jp\": correct_jp / total,\n",
    "        \"acc_all\": correct_all / total\n",
    "    }\n",
    "\n",
    "# ===================== 训练 =====================\n",
    "def _optimizer_step(model, optimizer, scheduler, scaler):\n",
    "    \"\"\"Apply one accumulated update: clip grads to norm 1.0, step (through GradScaler if any), reset grads, advance LR.\"\"\"\n",
    "    if scaler is not None:\n",
    "        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        scaler.step(optimizer)\n",
    "        scaler.update()\n",
    "    else:\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        optimizer.step()\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    scheduler.step()\n",
    "\n",
    "def train_model(model, train_loader, val_loader, optimizer, device, epochs=5, save_path=\"best.pt\"):\n",
    "    \"\"\"Train with classification loss + LAMBDA_CL * supervised contrastive loss.\n",
    "\n",
    "    Each step encodes the base text and its three views in one forward pass of 4B rows; the\n",
    "    axis heads see only the base rows, while SupCon uses the 16-way class id of every row.\n",
    "    Gradients are accumulated over GRAD_ACCUM micro-batches. A trailing partial group at the\n",
    "    end of an epoch is also applied (previously it was silently discarded by the next epoch's\n",
    "    zero_grad); its gradient is still divided by GRAD_ACCUM, so it is slightly down-weighted.\n",
    "    After each epoch the model is evaluated and its state_dict is written to save_path when\n",
    "    strict 4-axis accuracy improves. The model is left in eval mode with the LAST epoch's weights.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    scaler = amp.GradScaler(\"cuda\") if device.type == \"cuda\" else None\n",
    "\n",
    "    num_batches = len(train_loader)\n",
    "    # ceil: the last, possibly partial, accumulation group of each epoch also performs an update.\n",
    "    updates_per_epoch = -(-num_batches // GRAD_ACCUM)\n",
    "    total_updates = max(1, updates_per_epoch * epochs)\n",
    "    warmup_steps  = max(1, int(total_updates * WARMUP_RATIO))\n",
    "    scheduler = get_linear_schedule_with_warmup(\n",
    "        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_updates\n",
    "    )\n",
    "\n",
    "    best_all = 0.0\n",
    "    for epoch in range(epochs):\n",
    "        model.train()\n",
    "        total_loss = 0.0\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "\n",
    "        for step, batch in enumerate(tqdm(train_loader, desc=f\"[Train] Epoch {epoch+1}/{epochs}\")):\n",
    "            # ---- merge the four views into a single forward pass ----\n",
    "            base_ids = batch[\"input_ids\"].to(device)\n",
    "            base_msk = batch[\"attention_mask\"].to(device)\n",
    "            sem_ids  = batch[\"semantic_ids\"].to(device)\n",
    "            sem_msk  = batch[\"semantic_mask\"].to(device)\n",
    "            sen_ids  = batch[\"sentiment_ids\"].to(device)\n",
    "            sen_msk  = batch[\"sentiment_mask\"].to(device)\n",
    "            lin_ids  = batch[\"linguistic_ids\"].to(device)\n",
    "            lin_msk  = batch[\"linguistic_mask\"].to(device)\n",
    "\n",
    "            all_ids = torch.cat([base_ids, sem_ids, sen_ids, lin_ids], dim=0)   # (4B, L)\n",
    "            all_msk = torch.cat([base_msk, sem_msk, sen_msk, lin_msk], dim=0)\n",
    "            B = base_ids.size(0)\n",
    "\n",
    "            labels_bin = {k: batch[k].to(device) for k in [\"ei\",\"ns\",\"tf\",\"jp\"]}\n",
    "            class_id   = batch[\"class_id\"].to(device)\n",
    "\n",
    "            ctx = amp.autocast(\"cuda\") if device.type == \"cuda\" else contextlib.nullcontext()\n",
    "            with ctx:\n",
    "                emb_all  = model.get_embedding(all_ids, all_msk)      # (4B, H)\n",
    "                emb_base = emb_all[:B]\n",
    "                heads    = model.forward_heads(emb_base)\n",
    "\n",
    "                loss_cls = compute_mbti_loss(heads, labels_bin)\n",
    "                # Rows are ordered [base; semantic; sentiment; linguistic], so repeat(4) aligns labels.\n",
    "                all_labels = class_id.repeat(4)                       # (4B,)\n",
    "                loss_cl  = supervised_contrastive_loss(emb_all, all_labels)\n",
    "\n",
    "                loss = loss_cls + LAMBDA_CL * loss_cl\n",
    "\n",
    "            loss = loss / GRAD_ACCUM\n",
    "            if scaler is not None:\n",
    "                scaler.scale(loss).backward()\n",
    "            else:\n",
    "                loss.backward()\n",
    "\n",
    "            if (step + 1) % GRAD_ACCUM == 0 or (step + 1) == num_batches:\n",
    "                _optimizer_step(model, optimizer, scheduler, scaler)\n",
    "\n",
    "            total_loss += loss.item() * GRAD_ACCUM  # undo the accumulation scaling for logging\n",
    "\n",
    "        avg_loss = total_loss / max(1, num_batches)\n",
    "        print(f\"\\nEpoch {epoch+1} - Train Loss: {avg_loss:.4f}\")\n",
    "\n",
    "        metrics = evaluate(model, val_loader, device)\n",
    "        print(f\"Validation | EI: {metrics['acc_ei']:.2%} NS: {metrics['acc_ns']:.2%} TF: {metrics['acc_tf']:.2%} JP: {metrics['acc_jp']:.2%} 4D: {metrics['acc_all']:.2%}\")\n",
    "\n",
    "        if metrics[\"acc_all\"] > best_all:\n",
    "            best_all = metrics[\"acc_all\"]\n",
    "            torch.save(model.state_dict(), save_path)\n",
    "            print(f\"✅ Best model saved to {save_path} (4D: {best_all:.2%})\")\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "# ===================== 主流程 =====================\n",
    "if __name__ == \"__main__\":\n",
    "    # Read data & split 90/10 into train / validation\n",
    "    with open(FILE_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "        all_data = json.load(f)\n",
    "    train_data, val_data = train_test_split(all_data, test_size=0.1, random_state=42)\n",
    "\n",
    "    # Tokenizer & custom collate (each of the four views is padded dynamically)\n",
    "    tokenizer  = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "    collate_fn = make_collate_fn(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # Word-dropout augmentation for training only (validation uses drop_prob=0.0).\n",
    "    train_ds = JointMBTIDataset(train_data, tokenizer, max_length=MAX_LEN, drop_prob=0.1)\n",
    "    val_ds   = JointMBTIDataset(val_data,   tokenizer, max_length=MAX_LEN, drop_prob=0.0)\n",
    "\n",
    "    # Pinned host memory only pays off for host-to-GPU copies.\n",
    "    pin = DEVICE.type == \"cuda\"\n",
    "    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,\n",
    "                              collate_fn=collate_fn, num_workers=0, pin_memory=pin)\n",
    "    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,\n",
    "                              collate_fn=collate_fn, num_workers=0, pin_memory=pin)\n",
    "\n",
    "    # Model & optimizer\n",
    "    model = JointMBTIModel(MODEL_NAME)\n",
    "    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)\n",
    "\n",
    "    # Train\n",
    "    train_model(model, train_loader, val_loader, optimizer, DEVICE, epochs=EPOCHS, save_path=SAVE_PATH)\n",
    "\n",
    "    # train_model leaves the last epoch's weights in `model`; reload the best-validation\n",
    "    # checkpoint (when one was written) so the example below reflects the saved model.\n",
    "    if os.path.exists(SAVE_PATH):\n",
    "        model.load_state_dict(torch.load(SAVE_PATH, map_location=DEVICE))\n",
    "    model.eval()\n",
    "\n",
    "    # Inference example on the main text of the first validation record\n",
    "    sample = val_data[0]\n",
    "    sample_text = sample.get(\"posts_cleaned\", sample.get(\"posts\", \"\")) or \"\"\n",
    "    enc = tokenizer(sample_text, truncation=True, max_length=MAX_LEN, return_tensors=\"pt\")\n",
    "    with torch.no_grad():\n",
    "        cls = model.get_embedding(enc[\"input_ids\"].to(DEVICE), enc[\"attention_mask\"].to(DEVICE))\n",
    "        out = model.forward_heads(cls)\n",
    "        pred_ei = torch.sigmoid(out[\"ei\"]).round().item()\n",
    "        pred_ns = torch.sigmoid(out[\"ns\"]).round().item()\n",
    "        pred_tf = torch.sigmoid(out[\"tf\"]).round().item()\n",
    "        pred_jp = torch.sigmoid(out[\"jp\"]).round().item()\n",
    "    print(\"原标签:\", sample[\"type\"])\n",
    "    print(\"预测四维:\", pred_ei, pred_ns, pred_tf, pred_jp)\n",
    "# ====================== 结束 ======================\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d765051a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.88s/it]\n",
      "Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_126962/3440939658.py:231: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n",
      "/home/hli962/miniconda3/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:838: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='31' max='5400' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [  31/5400 06:25 < 19:49:40, 0.08 it/s, Epoch 0.02/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Meta-Llama-3-8B-Instruct + QLoRA(4bit) + LoRA\n",
    "任务：MBTI 16类分类（同时统计4D严格准确率）\n",
    "适配 Transformers==4.55（使用 eval_strategy）\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any, List\n",
    "\n",
    "# -- Key: disable accelerate mixed precision so it never calls .to(dtype) on the quantized model -- #\n",
    "os.environ[\"ACCELERATE_MIXED_PRECISION\"] = \"no\"\n",
    "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"  # presumably suppresses the bitsandbytes startup banner\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    BitsAndBytesConfig,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# Optional: try to enable flash-attn-2 (silently ignored when unavailable).\n",
    "# NOTE(review): I am not aware of a top-level transformers.set_attn_implementation; if the\n",
    "# import raises, this block is a silent no-op. Passing attn_implementation=\"flash_attention_2\"\n",
    "# to from_pretrained is the documented route - confirm which one is actually in effect.\n",
    "try:\n",
    "    from transformers import set_attn_implementation\n",
    "    set_attn_implementation(\"flash_attention_2\")\n",
    "except Exception:\n",
    "    # Deliberately best-effort: any failure (missing symbol, missing flash-attn) is ignored.\n",
    "    pass\n",
    "\n",
    "# ============ Configuration ============\n",
    "MODEL_NAME   = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
    "FILE_PATH    = \"mbti_sample_with_all_views_pandora.json\"\n",
    "\n",
    "# Sequence budget (about 768 tokens in total). BUDGET holds per-field token caps that\n",
    "# truncate_to_budget applies before build_input joins the fields into one prompt.\n",
    "# NOTE(review): the four caps already sum to 768 and build_input adds section tags plus a\n",
    "# trailing [TASK] line, so if the dataset is built with max_len=MAX_LEN a full-length prompt\n",
    "# is presumably truncated at the end and loses the task line - consider leaving headroom.\n",
    "MAX_LEN      = 768\n",
    "BUDGET = {\"posts_cleaned\": 384, \"semantic_view\": 128, \"sentiment_view\": 128, \"linguistic_view\": 128}\n",
    "\n",
    "SEED         = 42\n",
    "EPOCHS       = 4\n",
    "LR           = 2e-4\n",
    "BSZ_TRN      = 16\n",
    "BSZ_EVAL     = 32\n",
    "GRAD_ACCUM   = 2\n",
    "WARMUP_RATIO = 0.06\n",
    "WEIGHT_DECAY = 0.01\n",
    "OUTPUT_DIR   = \"mbti_lora_llama8b_ckpt\"\n",
    "\n",
    "# QLoRA (4-bit quantized base weights) & LoRA adapter settings\n",
    "USE_4BIT     = True\n",
    "LORA_R       = 16\n",
    "LORA_ALPHA   = 32\n",
    "LORA_DROPOUT = 0.05\n",
    "TARGET_MODULES = [\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"]\n",
    "\n",
    "# HF token (needed for gated models); prefer `export HF_TOKEN=xxx` over hard-coding it.\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "# Extra kwargs forwarded to the from_pretrained calls; empty when no token is set.\n",
    "HF_KW = {\"token\": HF_TOKEN} if HF_TOKEN else {}\n",
    "\n",
    "# ============ MBTI 工具 ============\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    # I/E, S/N, F/T, P/J -> 0/1\n",
    "    return (\n",
    "        0 if m[0]==\"I\" else 1,\n",
    "        0 if m[1]==\"S\" else 1,\n",
    "        0 if m[2]==\"F\" else 1,\n",
    "        0 if m[3]==\"P\" else 1,\n",
    "    )\n",
    "\n",
    "def load_rows(path: str):\n",
    "    \"\"\"Load a JSON list of records and keep those whose 'type' is one of the 16 MBTI codes.\"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        records = json.load(fh)\n",
    "    valid = []\n",
    "    for record in records:\n",
    "        if record.get(\"type\") in MBTI2ID:\n",
    "            valid.append(record)\n",
    "    return valid\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Cut `text` to at most `budget` tokens (no special tokens added) and decode it back.\n",
    "\n",
    "    None is treated as an empty string. Note: decode/re-encode may not be token-exact, so a\n",
    "    later re-tokenization can come out slightly longer or shorter than `budget`.\n",
    "    \"\"\"\n",
    "    token_ids = tok(text or \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Assemble the classification prompt for one record.\n",
    "\n",
    "    Main text and the three views are each truncated to their BUDGET cap and placed under\n",
    "    [POSTS] / [SEMANTIC] / [SENTIMENT] / [LINGUISTIC] tags, followed by a [TASK] line that\n",
    "    lists the 16 candidate types. Missing or null fields become empty sections.\n",
    "    \"\"\"\n",
    "    p = truncate_to_budget(tok, item.get(\"posts_cleaned\", item.get(\"posts\",\"\")) or \"\", BUDGET[\"posts_cleaned\"])\n",
    "    views = {\n",
    "        key: truncate_to_budget(tok, item.get(key, \"\") or \"\", BUDGET[key])\n",
    "        for key in (\"semantic_view\", \"sentiment_view\", \"linguistic_view\")\n",
    "    }\n",
    "    sem, sen, lin = views[\"semantic_view\"], views[\"sentiment_view\"], views[\"linguistic_view\"]\n",
    "    return (\n",
    "        f\"[POSTS]\\n{p}\\n[SEMANTIC]\\n{sem}\\n[SENTIMENT]\\n{sen}\\n[LINGUISTIC]\\n{lin}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "\n",
    "# ============ Dataset ============\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Tokenized build_input prompts with 16-way integer labels (padding is left to the collator).\"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok = tokenizer\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.rows[idx]\n",
    "        prompt = build_input(row, self.tok)\n",
    "        label = MBTI2ID[row[\"type\"]]\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        return {\n",
    "            \"input_ids\": encoded[\"input_ids\"],\n",
    "            \"attention_mask\": encoded[\"attention_mask\"],\n",
    "            \"labels\": label,\n",
    "        }\n",
    "\n",
    "# ============ 指标 ============\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Trainer metric hook: 16-way accuracy plus per-axis and strict 4-axis accuracy.\n",
    "\n",
    "    Accepts either a (predictions, label_ids) tuple or an object exposing .predictions and\n",
    "    .label_ids; predictions may be wrapped in a list/tuple whose first element is the logits.\n",
    "    Per-axis scores compare the axis letters of the predicted and true 16-way types.\n",
    "    \"\"\"\n",
    "    if isinstance(eval_pred, tuple):\n",
    "        preds, labels = eval_pred\n",
    "    else:\n",
    "        preds, labels = eval_pred.predictions, eval_pred.label_ids\n",
    "\n",
    "    logits = np.asarray(preds[0] if isinstance(preds, (list, tuple)) else preds)\n",
    "    gold = np.asarray(labels)\n",
    "\n",
    "    pred_ids = logits.argmax(-1)\n",
    "    acc16 = float((pred_ids == gold).mean())\n",
    "\n",
    "    n = len(gold)\n",
    "    hits = [0, 0, 0, 0]\n",
    "    hits_all = 0\n",
    "    for p_idx, t_idx in zip(pred_ids, gold):\n",
    "        p_bits = mbti_to_4d(MBTI_16[p_idx])\n",
    "        t_bits = mbti_to_4d(MBTI_16[t_idx])\n",
    "        matches = [pb == tb for pb, tb in zip(p_bits, t_bits)]\n",
    "        hits = [h + m for h, m in zip(hits, matches)]\n",
    "        hits_all += all(matches)\n",
    "    c_ei, c_ns, c_tf, c_jp = hits\n",
    "    return {\"acc_16\": acc16, \"acc_ei\": c_ei/n, \"acc_ns\": c_ns/n, \"acc_tf\": c_tf/n, \"acc_jp\": c_jp/n, \"acc_4D\": hits_all/n}\n",
    "\n",
    "# ============ Main pipeline ============\n",
    "def main():\n",
    "    \"\"\"Train a (Q)LoRA 16-way MBTI classifier on FILE_PATH and report metrics.\n",
    "\n",
    "    Pipeline: stratified 90/10 split -> tokenizer -> model (4-bit NF4 when\n",
    "    USE_4BIT, bf16 otherwise) -> LoRA -> Trainer with per-epoch evaluation,\n",
    "    keeping the checkpoint with the best eval_acc_4D -> save the adapter to\n",
    "    OUTPUT_DIR -> predict one validation sample as a smoke test.\n",
    "    \"\"\"\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    # Data: stratify on the MBTI type so each type is represented in both splits.\n",
    "    rows = load_rows(FILE_PATH)\n",
    "    train_rows, val_rows = train_test_split(\n",
    "        rows, test_size=0.1, random_state=SEED, stratify=[r[\"type\"] for r in rows]\n",
    "    )\n",
    "\n",
    "    # Tokenizer: reuse EOS as PAD when the checkpoint defines no pad token.\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, **HF_KW)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # QLoRA quantization config (in 4-bit mode, do not also pass torch_dtype).\n",
    "    quant_cfg = BitsAndBytesConfig(\n",
    "        load_in_4bit=USE_4BIT,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type=\"nf4\",\n",
    "        bnb_4bit_compute_dtype=torch.float16,  # stability first; newer bnb versions can use torch.bfloat16\n",
    "    ) if USE_4BIT else None\n",
    "\n",
    "    # Model with a 16-class classification head.\n",
    "    model_kwargs = dict(\n",
    "        num_labels=16,\n",
    "        quantization_config=quant_cfg,\n",
    "        device_map=\"auto\",\n",
    "        **HF_KW,\n",
    "    )\n",
    "    if not USE_4BIT:\n",
    "        model_kwargs[\"torch_dtype\"] = torch.bfloat16\n",
    "\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, **model_kwargs)\n",
    "\n",
    "    # Safety net: make .to() a no-op so nothing moves the device_map-placed\n",
    "    # (possibly quantized) model after loading.\n",
    "    def _noop_to(self, *args, **kwargs):\n",
    "        return self\n",
    "    model.to = _noop_to.__get__(model, type(model))\n",
    "\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.config.use_cache = False\n",
    "    # Resize only if the tokenizer has MORE tokens than the embedding matrix.\n",
    "    # No tokens are added above (PAD reuses EOS); resizing unconditionally to\n",
    "    # len(tokenizer) can shrink an embedding that the checkpoint padded past\n",
    "    # len(tokenizer), after which PEFT stores the full embedding layer in every\n",
    "    # checkpoint (the \"save_embedding_layers\" warning).\n",
    "    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:\n",
    "        model.resize_token_embeddings(len(tokenizer))\n",
    "\n",
    "    # LoRA (with k-bit training preparation).\n",
    "    # NOTE(review): this preparation also runs when USE_4BIT is False; confirm\n",
    "    # that is intended for the unquantized bf16 path.\n",
    "    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training\n",
    "    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)\n",
    "    peft_cfg = LoraConfig(\n",
    "        task_type=TaskType.SEQ_CLS,\n",
    "        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,\n",
    "        target_modules=TARGET_MODULES, bias=\"none\"\n",
    "    )\n",
    "    model = get_peft_model(model, peft_cfg)\n",
    "    model.to = _noop_to.__get__(model, type(model))  # patch the PEFT wrapper too\n",
    "\n",
    "    # Datasets & collator (pads each batch dynamically to a multiple of 8).\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # Training arguments (transformers 4.55 uses eval_strategy).\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BSZ_TRN,\n",
    "        per_device_eval_batch_size=BSZ_EVAL,\n",
    "        gradient_accumulation_steps=GRAD_ACCUM,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        learning_rate=LR,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        eval_strategy=\"epoch\",      # argument name used by transformers 4.55\n",
    "        save_strategy=\"epoch\",\n",
    "        save_total_limit=2,\n",
    "        logging_steps=50,\n",
    "        bf16=False,                 # quantization + mixed precision tends to trigger dtype casts; keep both off\n",
    "        fp16=False,\n",
    "        report_to=\"none\",\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_4D\",\n",
    "        greater_is_better=True,\n",
    "    )\n",
    "\n",
    "    # processing_class replaces the deprecated `tokenizer=` argument, which\n",
    "    # transformers warns will be removed in version 5.0.\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=val_ds,\n",
    "        processing_class=tokenizer,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    trainer.train()\n",
    "    eval_metrics = trainer.evaluate()\n",
    "    print(\"\\n=== Final Eval ===\")\n",
    "    for k, v in eval_metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except (TypeError, ValueError):  # non-numeric metric values are printed raw\n",
    "            print(k, v)\n",
    "\n",
    "    trainer.save_model(OUTPUT_DIR)\n",
    "    print(f\"\\n✅ LoRA adapter saved to: {OUTPUT_DIR}\")\n",
    "\n",
    "    # Quick inference smoke test: only the inputs are moved, to the model's device.\n",
    "    model.eval()\n",
    "    sample = val_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(next(model.parameters()).device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        logits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(logits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n原标签:\", sample[\"type\"], \" | 预测:\", pred_mbti)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "81e3a67c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PYTHON: /home/hli962/miniconda3/bin/python\n",
      "transformers: 4.55.4 -> /home/hli962/miniconda3/lib/python3.12/site-packages/transformers\n",
      "accelerate: 1.10.1\n",
      "peft: 0.17.1\n"
     ]
    }
   ],
   "source": [
    "import sys, os\n",
    "import transformers, accelerate, peft\n",
    "\n",
    "# Print the interpreter path and the versions/locations of the training\n",
    "# libraries this kernel imports.\n",
    "print(\"PYTHON:\", sys.executable)\n",
    "tf_dir = os.path.dirname(transformers.__file__)\n",
    "print(\"transformers:\", transformers.__version__, \"->\", tf_dir)\n",
    "for lib in (accelerate, peft):\n",
    "    print(f\"{lib.__name__}:\", lib.__version__)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b82face",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/bin/python\n",
      "transformers: 4.55.4 -> /home/hli962/miniconda3/lib/python3.12/site-packages/transformers/__init__.py\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import transformers\n",
    "\n",
    "# Interpreter path plus the transformers version and the file it was imported from.\n",
    "print(sys.executable)\n",
    "print(f\"transformers: {transformers.__version__} -> {transformers.__file__}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dc58ca83",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 18,489,344 || all params: 1,561,811,968 || trainable%: 1.1838\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_126962/1371604356.py:196: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='17891' max='21600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [17891/21600 2:51:14 < 35:30, 1.74 it/s, Epoch 3.31/4]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Acc 16</th>\n",
       "      <th>Acc Ei</th>\n",
       "      <th>Acc Ns</th>\n",
       "      <th>Acc Tf</th>\n",
       "      <th>Acc Jp</th>\n",
       "      <th>Acc 4d</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>4.080800</td>\n",
       "      <td>2.028699</td>\n",
       "      <td>0.357083</td>\n",
       "      <td>0.667083</td>\n",
       "      <td>0.695417</td>\n",
       "      <td>0.705208</td>\n",
       "      <td>0.676458</td>\n",
       "      <td>0.357083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>3.039500</td>\n",
       "      <td>1.809326</td>\n",
       "      <td>0.453333</td>\n",
       "      <td>0.731667</td>\n",
       "      <td>0.749167</td>\n",
       "      <td>0.746458</td>\n",
       "      <td>0.709583</td>\n",
       "      <td>0.453333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>1.760800</td>\n",
       "      <td>2.001467</td>\n",
       "      <td>0.475833</td>\n",
       "      <td>0.743125</td>\n",
       "      <td>0.752292</td>\n",
       "      <td>0.760208</td>\n",
       "      <td>0.726458</td>\n",
       "      <td>0.475833</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
      "  warnings.warn(\n",
      "/home/hli962/miniconda3/lib/python3.12/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
      "  warnings.warn(\n",
      "/home/hli962/miniconda3/lib/python3.12/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 231\u001b[39m\n\u001b[32m    228\u001b[39m     \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m原标签:\u001b[39m\u001b[33m\"\u001b[39m, sample[\u001b[33m\"\u001b[39m\u001b[33mtype\u001b[39m\u001b[33m\"\u001b[39m], \u001b[33m\"\u001b[39m\u001b[33m| 预测:\u001b[39m\u001b[33m\"\u001b[39m, pred_mbti)\n\u001b[32m    230\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[34m__name__\u001b[39m == \u001b[33m\"\u001b[39m\u001b[33m__main__\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m231\u001b[39m     \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 206\u001b[39m, in \u001b[36mmain\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m    175\u001b[39m args = TrainingArguments(\n\u001b[32m    176\u001b[39m     output_dir=OUTPUT_DIR,\n\u001b[32m    177\u001b[39m     per_device_train_batch_size=BSZ_TRN,\n\u001b[32m   (...)\u001b[39m\u001b[32m    193\u001b[39m     fp16=\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[32m    194\u001b[39m )\n\u001b[32m    196\u001b[39m trainer = Trainer(\n\u001b[32m    197\u001b[39m     model=model,\n\u001b[32m    198\u001b[39m     args=args,\n\u001b[32m   (...)\u001b[39m\u001b[32m    203\u001b[39m     compute_metrics=compute_metrics,\n\u001b[32m    204\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m206\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    207\u001b[39m eval_metrics = trainer.evaluate()\n\u001b[32m    208\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m=== Final Eval ===\u001b[39m\u001b[33m\"\u001b[39m)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:2238\u001b[39m, in \u001b[36mTrainer.train\u001b[39m\u001b[34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[39m\n\u001b[32m   2236\u001b[39m         hf_hub_utils.enable_progress_bars()\n\u001b[32m   2237\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2238\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m   2239\u001b[39m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[43m=\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   2240\u001b[39m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m=\u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   2241\u001b[39m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   2242\u001b[39m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m=\u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m   2243\u001b[39m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:2582\u001b[39m, in \u001b[36mTrainer._inner_training_loop\u001b[39m\u001b[34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[39m\n\u001b[32m   2575\u001b[39m context = (\n\u001b[32m   2576\u001b[39m     functools.partial(\u001b[38;5;28mself\u001b[39m.accelerator.no_sync, model=model)\n\u001b[32m   2577\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m i != \u001b[38;5;28mlen\u001b[39m(batch_samples) - \u001b[32m1\u001b[39m\n\u001b[32m   2578\u001b[39m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m.accelerator.distributed_type != DistributedType.DEEPSPEED\n\u001b[32m   2579\u001b[39m     \u001b[38;5;28;01melse\u001b[39;00m contextlib.nullcontext\n\u001b[32m   2580\u001b[39m )\n\u001b[32m   2581\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m context():\n\u001b[32m-> \u001b[39m\u001b[32m2582\u001b[39m     tr_loss_step = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_items_in_batch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   2584\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m   2585\u001b[39m     args.logging_nan_inf_filter\n\u001b[32m   2586\u001b[39m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[32m   2587\u001b[39m     \u001b[38;5;129;01mand\u001b[39;00m (torch.isnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch.isinf(tr_loss_step))\n\u001b[32m   2588\u001b[39m ):\n\u001b[32m   2589\u001b[39m     \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[32m   2590\u001b[39m     tr_loss = tr_loss + tr_loss / (\u001b[32m1\u001b[39m + 
\u001b[38;5;28mself\u001b[39m.state.global_step - \u001b[38;5;28mself\u001b[39m._globalstep_last_logged)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:3845\u001b[39m, in \u001b[36mTrainer.training_step\u001b[39m\u001b[34m(***failed resolving arguments***)\u001b[39m\n\u001b[32m   3842\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.accelerator.distributed_type == DistributedType.DEEPSPEED:\n\u001b[32m   3843\u001b[39m     kwargs[\u001b[33m\"\u001b[39m\u001b[33mscale_wrt_gas\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m3845\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43maccelerator\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   3847\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m loss.detach()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/accelerate/accelerator.py:2734\u001b[39m, in \u001b[36mAccelerator.backward\u001b[39m\u001b[34m(self, loss, **kwargs)\u001b[39m\n\u001b[32m   2732\u001b[39m     \u001b[38;5;28mself\u001b[39m.lomo_backward(loss, learning_rate)\n\u001b[32m   2733\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2734\u001b[39m     \u001b[43mloss\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/torch/_tensor.py:648\u001b[39m, in \u001b[36mTensor.backward\u001b[39m\u001b[34m(self, gradient, retain_graph, create_graph, inputs)\u001b[39m\n\u001b[32m    638\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m    639\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[32m    640\u001b[39m         Tensor.backward,\n\u001b[32m    641\u001b[39m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[32m   (...)\u001b[39m\u001b[32m    646\u001b[39m         inputs=inputs,\n\u001b[32m    647\u001b[39m     )\n\u001b[32m--> \u001b[39m\u001b[32m648\u001b[39m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mautograd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    649\u001b[39m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs\u001b[49m\n\u001b[32m    650\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/torch/autograd/__init__.py:353\u001b[39m, in \u001b[36mbackward\u001b[39m\u001b[34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[39m\n\u001b[32m    348\u001b[39m     retain_graph = create_graph\n\u001b[32m    350\u001b[39m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[32m    351\u001b[39m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[32m    352\u001b[39m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m353\u001b[39m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    354\u001b[39m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    355\u001b[39m \u001b[43m    \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    356\u001b[39m \u001b[43m    \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    357\u001b[39m \u001b[43m    \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    358\u001b[39m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    359\u001b[39m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m    360\u001b[39m \u001b[43m    \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m    361\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/torch/autograd/graph.py:824\u001b[39m, in \u001b[36m_engine_run_backward\u001b[39m\u001b[34m(t_outputs, *args, **kwargs)\u001b[39m\n\u001b[32m    822\u001b[39m     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[32m    823\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m824\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_execution_engine\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[32m    825\u001b[39m \u001b[43m        \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\n\u001b[32m    826\u001b[39m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[32m    827\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m    828\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
      "\u001b[31mKeyboardInterrupt\u001b[39m: "
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "Qwen-2.5-1.5B + LoRA 训练 MBTI 16类，同时统计4D严格准确率\n",
    "（不做量化，避免 .to() / bitsandbytes 的兼容问题；适配 transformers 4.55+）\n",
    "\"\"\"\n",
    "\n",
    "import os, json\n",
    "from typing import Dict, Any, List\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    AutoModelForSequenceClassification,\n",
    "    DataCollatorWithPadding,\n",
    "    Trainer, TrainingArguments,\n",
    "    set_seed,\n",
    ")\n",
    "\n",
    "# ---------------- Basic configuration ----------------\n",
    "MODEL_NAME = \"Qwen/Qwen2.5-1.5B\"              # \"Qwen/Qwen2.5-1.5B-Instruct\" also works\n",
    "FILE_PATH  = \"mbti_sample_with_all_views_pandora.json\"\n",
    "\n",
    "# Per-field token budgets enforced by build_input() before the sections are joined.\n",
    "BUDGET = {\n",
    "    \"posts_cleaned\": 384,\n",
    "    \"semantic_view\": 128,\n",
    "    \"sentiment_view\": 128,\n",
    "    \"linguistic_view\": 128,\n",
    "}\n",
    "# The prompt also carries the [POSTS]/[SEMANTIC]/[SENTIMENT]/[LINGUISTIC] tags\n",
    "# and a closing \"[TASK] Predict MBTI type among ...\" line listing all 16 types.\n",
    "# With MAX_LEN equal to the budget sum (768), fully used budgets could overflow\n",
    "# the limit, and truncation at MAX_LEN then cut off the task line (and the end\n",
    "# of the linguistic view). Reserve headroom for that template text.\n",
    "PROMPT_OVERHEAD = 128\n",
    "MAX_LEN = sum(BUDGET.values()) + PROMPT_OVERHEAD\n",
    "\n",
    "SEED         = 42\n",
    "EPOCHS       = 4\n",
    "LR           = 2e-4\n",
    "BSZ_TRN      = 4           # lower further if GPU memory is tight\n",
    "BSZ_EVAL     = 8\n",
    "GRAD_ACCUM   = 2\n",
    "WARMUP_RATIO = 0.06\n",
    "WEIGHT_DECAY = 0.01\n",
    "OUTPUT_DIR   = \"mbti_qwen_lora_ckpt\"\n",
    "\n",
    "# ---------------- MBTI helpers ----------------\n",
    "# Canonical order of the 16 types; a type's position in this list is its class id.\n",
    "MBTI_16 = [\n",
    "    \"INTJ\", \"INTP\", \"ENTJ\", \"ENTP\",\n",
    "    \"INFJ\", \"INFP\", \"ENFJ\", \"ENFP\",\n",
    "    \"ISTJ\", \"ISFJ\", \"ESTJ\", \"ESFJ\",\n",
    "    \"ISTP\", \"ISFP\", \"ESTP\", \"ESFP\",\n",
    "]\n",
    "MBTI2ID = dict(zip(MBTI_16, range(len(MBTI_16))))\n",
    "\n",
    "def mbti_to_4d(m: str):\n",
    "    \"\"\"Encode an MBTI code as four 0/1 flags, one per axis.\n",
    "\n",
    "    Position k is 0 when the k-th letter is that axis's zero letter\n",
    "    (I, S, F, P) and 1 for any other letter, e.g. \"INTJ\" -> (0, 1, 1, 1).\n",
    "    Raises IndexError when `m` has fewer than four characters.\n",
    "    \"\"\"\n",
    "    return tuple(int(m[pos] != zero_letter) for pos, zero_letter in enumerate(\"ISFP\"))\n",
    "\n",
    "def truncate_to_budget(tok: AutoTokenizer, text: str, budget: int) -> str:\n",
    "    \"\"\"Clip `text` to its first `budget` tokens and decode it back to a string.\n",
    "\n",
    "    Encoding skips special tokens, so the budget counts content tokens only.\n",
    "    None is treated as the empty string.\n",
    "    \"\"\"\n",
    "    token_ids = tok(text or \"\", add_special_tokens=False)[\"input_ids\"]\n",
    "    return tok.decode(token_ids[:budget])\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok: AutoTokenizer) -> str:\n",
    "    \"\"\"Render one data row into the classifier prompt.\n",
    "\n",
    "    The post text and the three auxiliary views are each clipped to their\n",
    "    BUDGET entry, placed under a [TAG] header, and followed by a task line\n",
    "    listing all 16 MBTI types.\n",
    "    \"\"\"\n",
    "    def clip(value, budget_key):\n",
    "        return truncate_to_budget(tok, value or \"\", BUDGET[budget_key])\n",
    "\n",
    "    # Prefer the cleaned posts; fall back to the raw \"posts\" field when the key is absent.\n",
    "    posts = item.get(\"posts_cleaned\", item.get(\"posts\",\"\"))\n",
    "    sections = [\n",
    "        (\"[POSTS]\", clip(posts, \"posts_cleaned\")),\n",
    "        (\"[SEMANTIC]\", clip(item.get(\"semantic_view\",\"\"), \"semantic_view\")),\n",
    "        (\"[SENTIMENT]\", clip(item.get(\"sentiment_view\",\"\"), \"sentiment_view\")),\n",
    "        (\"[LINGUISTIC]\", clip(item.get(\"linguistic_view\",\"\"), \"linguistic_view\")),\n",
    "    ]\n",
    "    prompt = \"\".join(f\"{tag}\\n{body}\\n\" for tag, body in sections)\n",
    "    return prompt + f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "\n",
    "def load_rows(path: str) -> List[Dict[str, Any]]:\n",
    "    \"\"\"Load a JSON list of records and keep only rows with a known MBTI type.\n",
    "\n",
    "    Records whose \"type\" is missing or not one of MBTI_16 are dropped.\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        records = json.load(fh)\n",
    "    kept = []\n",
    "    for rec in records:\n",
    "        if rec.get(\"type\") in MBTI2ID:\n",
    "            kept.append(rec)\n",
    "    return kept\n",
    "\n",
    "# ---------------- Dataset ----------------\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset that turns raw MBTI rows into tokenized classifier examples.\n",
    "\n",
    "    Items are built lazily: the row is rendered into a prompt by build_input(),\n",
    "    tokenized without padding (the collator pads each batch), and paired with\n",
    "    the row's 16-way class id from MBTI2ID.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows            # sequence of row dicts; each needs a \"type\" key\n",
    "        self.tok = tokenizer\n",
    "        self.max_len = max_len      # truncation cap for the tokenized prompt\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.rows)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.rows[idx]\n",
    "        prompt = build_input(row, self.tok)\n",
    "        label_id = MBTI2ID[row[\"type\"]]\n",
    "        encoded = self.tok(prompt, truncation=True, max_length=self.max_len)\n",
    "        return {\n",
    "            \"input_ids\": encoded[\"input_ids\"],\n",
    "            \"attention_mask\": encoded[\"attention_mask\"],\n",
    "            \"labels\": label_id,\n",
    "        }\n",
    "\n",
    "# ---------------- Metrics ----------------\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Trainer metric hook: 16-way accuracy plus per-axis MBTI accuracies.\n",
    "\n",
    "    Accepts an EvalPrediction, a plain (predictions, labels) pair, and\n",
    "    predictions given as a (logits, ...) tuple or as torch tensors.\n",
    "\n",
    "    Returns a dict with acc_16 (exact type accuracy), acc_ei / acc_ns /\n",
    "    acc_tf / acc_jp (accuracy on each binary axis) and acc_4D (all four axes\n",
    "    correct). Because the four axes fully determine the type, acc_4D always\n",
    "    equals acc_16. Every value is 0.0 for an empty evaluation set.\n",
    "    \"\"\"\n",
    "    if hasattr(eval_pred, \"predictions\"):\n",
    "        preds, labels = eval_pred.predictions, eval_pred.label_ids\n",
    "    else:\n",
    "        preds, labels = eval_pred\n",
    "    if isinstance(preds, tuple):          # some models return (logits, ...)\n",
    "        preds = preds[0]\n",
    "    if isinstance(preds, torch.Tensor):\n",
    "        preds = preds.detach().cpu().numpy()\n",
    "    if isinstance(labels, torch.Tensor):\n",
    "        labels = labels.detach().cpu().numpy()\n",
    "    preds = np.asarray(preds)\n",
    "    labels = np.asarray(labels)\n",
    "\n",
    "    n = len(labels)\n",
    "    if n == 0:\n",
    "        # Avoid dividing by zero (and NaN means) on an empty evaluation set.\n",
    "        return {k: 0.0 for k in (\"acc_16\", \"acc_ei\", \"acc_ns\", \"acc_tf\", \"acc_jp\", \"acc_4D\")}\n",
    "\n",
    "    pred_ids = preds.argmax(-1)\n",
    "    acc16 = float((pred_ids == labels).mean())\n",
    "\n",
    "    # Row i holds the 0/1 axis encoding of MBTI_16[i]; indexing it with the\n",
    "    # predicted and true class ids replaces the per-sample Python loop.\n",
    "    axis_table = np.array([mbti_to_4d(t) for t in MBTI_16])\n",
    "    axis_match = axis_table[pred_ids] == axis_table[labels]  # (n, 4) booleans\n",
    "    per_axis = axis_match.mean(axis=0)\n",
    "\n",
    "    return {\n",
    "        \"acc_16\": acc16,\n",
    "        \"acc_ei\": float(per_axis[0]), \"acc_ns\": float(per_axis[1]),\n",
    "        \"acc_tf\": float(per_axis[2]), \"acc_jp\": float(per_axis[3]),\n",
    "        \"acc_4D\": float(axis_match.all(axis=1).mean()),\n",
    "    }\n",
    "\n",
    "# ---------------- Training pipeline ----------------\n",
    "def main():\n",
    "    \"\"\"Fine-tune MODEL_NAME with LoRA (no quantization) as a 16-way MBTI classifier.\n",
    "\n",
    "    Pipeline: stratified 90/10 split -> tokenizer -> model (bf16 when CUDA is\n",
    "    available) with gradient checkpointing -> LoRA on the modules listed in\n",
    "    target_modules -> Trainer with per-epoch evaluation, keeping the checkpoint\n",
    "    with the best eval_acc_4D -> save the adapter to OUTPUT_DIR -> predict one\n",
    "    validation sample as a smoke test.\n",
    "    \"\"\"\n",
    "    set_seed(SEED)\n",
    "    torch.backends.cuda.matmul.allow_tf32 = True\n",
    "    torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "    # Stratify on the MBTI type so each type is represented in both splits.\n",
    "    rows = load_rows(FILE_PATH)\n",
    "    train_rows, val_rows = train_test_split(\n",
    "        rows, test_size=0.1, random_state=SEED, stratify=[r[\"type\"] for r in rows]\n",
    "    )\n",
    "\n",
    "    # Tokenizer: reuse EOS as PAD when the checkpoint defines no pad token.\n",
    "    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)\n",
    "    if tokenizer.pad_token is None:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "    tokenizer.padding_side = \"right\"\n",
    "\n",
    "    # Unquantized model plus a 16-class classification head.\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        MODEL_NAME,\n",
    "        num_labels=16,\n",
    "        torch_dtype=(torch.bfloat16 if torch.cuda.is_available() else None),\n",
    "        device_map=\"auto\",\n",
    "    )\n",
    "    model.config.pad_token_id = tokenizer.pad_token_id\n",
    "    model.config.use_cache = False\n",
    "    # Resize only if the tokenizer has MORE tokens than the embedding matrix.\n",
    "    # No tokens are added above (PAD reuses EOS); resizing unconditionally to\n",
    "    # len(tokenizer) can shrink an embedding that the checkpoint padded past\n",
    "    # len(tokenizer), after which PEFT stores the full embedding layer in every\n",
    "    # checkpoint (the \"save_embedding_layers\" warnings in this cell's output).\n",
    "    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:\n",
    "        model.resize_token_embeddings(len(tokenizer))\n",
    "    if hasattr(model, \"gradient_checkpointing_enable\"):\n",
    "        model.gradient_checkpointing_enable()\n",
    "\n",
    "    # Plain LoRA (no k-bit preparation).\n",
    "    from peft import LoraConfig, TaskType, get_peft_model\n",
    "    peft_cfg = LoraConfig(\n",
    "        task_type=TaskType.SEQ_CLS,\n",
    "        r=16, lora_alpha=32, lora_dropout=0.05,\n",
    "        target_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"],\n",
    "        bias=\"none\"\n",
    "    )\n",
    "    model = get_peft_model(model, peft_cfg)\n",
    "    model.print_trainable_parameters()\n",
    "\n",
    "    # Datasets & collator (pads each batch dynamically to a multiple of 8).\n",
    "    train_ds = MBTIDataset(train_rows, tokenizer, max_len=MAX_LEN)\n",
    "    val_ds   = MBTIDataset(val_rows,   tokenizer, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)\n",
    "\n",
    "    # transformers 4.55+ uses eval_strategy.\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_train_batch_size=BSZ_TRN,\n",
    "        per_device_eval_batch_size=BSZ_EVAL,\n",
    "        gradient_accumulation_steps=GRAD_ACCUM,\n",
    "        num_train_epochs=EPOCHS,\n",
    "        learning_rate=LR,\n",
    "        warmup_ratio=WARMUP_RATIO,\n",
    "        weight_decay=WEIGHT_DECAY,\n",
    "        logging_steps=50,\n",
    "        save_total_limit=2,\n",
    "        report_to=\"none\",\n",
    "        load_best_model_at_end=True,\n",
    "        metric_for_best_model=\"eval_acc_4D\",\n",
    "        greater_is_better=True,\n",
    "        eval_strategy=\"epoch\",      # argument name used by transformers 4.55+\n",
    "        save_strategy=\"epoch\",\n",
    "        # bf16 autocast only on GPUs with compute capability >= 8 (Ampere or newer).\n",
    "        bf16=(torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8),\n",
    "        fp16=False,\n",
    "    )\n",
    "\n",
    "    # processing_class replaces the deprecated `tokenizer=` argument, which\n",
    "    # transformers warns will be removed in version 5.0.\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_ds,\n",
    "        eval_dataset=val_ds,\n",
    "        processing_class=tokenizer,\n",
    "        data_collator=collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    trainer.train()\n",
    "    eval_metrics = trainer.evaluate()\n",
    "    print(\"\\n=== Final Eval ===\")\n",
    "    for k, v in eval_metrics.items():\n",
    "        try:\n",
    "            print(f\"{k}: {float(v):.4f}\")\n",
    "        except (TypeError, ValueError):  # non-numeric metric values are printed raw\n",
    "            print(k, v)\n",
    "\n",
    "    trainer.save_model(OUTPUT_DIR)\n",
    "    print(f\"\\n✅ LoRA adapter saved to: {OUTPUT_DIR}\")\n",
    "\n",
    "    # Quick inference smoke test on one validation row.\n",
    "    model.eval()\n",
    "    sample = val_rows[0]\n",
    "    text = build_input(sample, tokenizer)\n",
    "    batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n",
    "    batch = {k: v.to(model.device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        logits = model(**batch).logits\n",
    "        pred_id = int(torch.argmax(logits, dim=-1))\n",
    "        pred_mbti = MBTI_16[pred_id]\n",
    "    print(\"\\n原标签:\", sample[\"type\"], \"| 预测:\", pred_mbti)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "17148ad0",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'JointMBTIDataset' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mNameError\u001b[39m                                 Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 17\u001b[39m\n\u001b[32m     14\u001b[39m \u001b[38;5;66;03m# ===== 3. 初始化Tokenizer和Dataset =====\u001b[39;00m\n\u001b[32m     15\u001b[39m tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n\u001b[32m---> \u001b[39m\u001b[32m17\u001b[39m train_dataset = \u001b[43mJointMBTIDataset\u001b[49m(train_data, tokenizer, max_length=MAX_LEN, drop_prob=\u001b[32m0.1\u001b[39m)\n\u001b[32m     18\u001b[39m val_dataset = JointMBTIDataset(val_data, tokenizer, max_length=MAX_LEN, drop_prob=\u001b[32m0.0\u001b[39m)  \u001b[38;5;66;03m# 验证集不drop\u001b[39;00m\n\u001b[32m     20\u001b[39m train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "\u001b[31mNameError\u001b[39m: name 'JointMBTIDataset' is not defined"
     ]
    }
   ],
   "source": [
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# ===== 1. 读取数据 =====\n",
    "with open(\"mbti_sample_with_all_views.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# ===== 2. 划分训练/验证集 =====\n",
    "train_data, val_data = train_test_split(data, test_size=0.1, random_state=42)\n",
    "\n",
    "# ===== 3. 初始化Tokenizer和Dataset =====\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "train_dataset = JointMBTIDataset(train_data, tokenizer, max_length=MAX_LEN, drop_prob=0.1)\n",
    "val_dataset = JointMBTIDataset(val_data, tokenizer, max_length=MAX_LEN, drop_prob=0.0)  # 验证集不drop\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "# ===== 4. 初始化模型和优化器 =====\n",
    "model = JointMBTIModel(encoder_name=MODEL_NAME)\n",
    "optimizer = AdamW(model.parameters(), lr=LR)\n",
    "\n",
    "# ===== 5. 开始训练 =====\n",
    "train_model(model, train_loader, val_loader, optimizer, DEVICE, epochs=EPOCHS, save_path=SAVE_PATH)\n",
    "\n",
    "# ===== 6. 加载最佳模型做推理 =====\n",
    "model.load_state_dict(torch.load(SAVE_PATH))\n",
    "model.eval()\n",
    "\n",
    "# 推理示例\n",
    "sample = val_data[0]\n",
    "encoding = tokenizer(\n",
    "    sample[\"posts_cleaned\"],\n",
    "    padding=\"max_length\",\n",
    "    truncation=True,\n",
    "    max_length=MAX_LEN,\n",
    "    return_tensors=\"pt\"\n",
    ")\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(encoding[\"input_ids\"].to(DEVICE), encoding[\"attention_mask\"].to(DEVICE))\n",
    "    pred_ei = torch.sigmoid(outputs[\"ei\"]).round().item()\n",
    "    pred_ns = torch.sigmoid(outputs[\"ns\"]).round().item()\n",
    "    pred_tf = torch.sigmoid(outputs[\"tf\"]).round().item()\n",
    "    pred_jp = torch.sigmoid(outputs[\"jp\"]).round().item()\n",
    "\n",
    "print(\"原标签:\", sample[\"type\"])\n",
    "print(\"预测四维:\", pred_ei, pred_ns, pred_tf, pred_jp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "420da14c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/home/hli962/miniconda3/lib/python3.12/site-packages/transformers/convert_slow_tokenizer.py:564: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "—— 前 10 条 posts_cleaned 的 token 数 ——\n",
      "样本 01: 244 tokens\n",
      "样本 02: 233 tokens\n",
      "样本 03: 276 tokens\n",
      "样本 04: 243 tokens\n",
      "样本 05: 271 tokens\n",
      "样本 06: 269 tokens\n",
      "样本 07: 207 tokens\n",
      "样本 08: 299 tokens\n",
      "样本 09: 239 tokens\n",
      "样本 10: 261 tokens\n",
      "\n",
      "—— 前 10 条统计 ——\n",
      "均值: 254.2\n",
      "中位数: 252.5\n",
      "P90: 278.3\n",
      "最大值: 299\n",
      "最小值: 207\n"
     ]
    }
   ],
   "source": [
    "# pip install transformers==4.43.0  (若未安装)\n",
    "from transformers import AutoTokenizer\n",
    "import json, os, math\n",
    "from statistics import mean, median\n",
    "\n",
    "# === 配置区 ===\n",
    "FILE_PATH = r\"mbti_sample_with_all_views_pandora.json\"  # ← 改成你的文件路径\n",
    "MODEL_NAME = \"microsoft/deberta-v3-base\"  # 你在用哪种模型就换成对应 tokenizer\n",
    "\n",
    "# === 载入 tokenizer ===\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# === 读入 JSON ===\n",
    "if not os.path.exists(FILE_PATH):\n",
    "    raise FileNotFoundError(f\"文件不存在：{FILE_PATH}\")\n",
    "\n",
    "with open(FILE_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# data 预期是一个 list，每个元素是 dict，包含 \"posts_cleaned\"\n",
    "if not isinstance(data, list):\n",
    "    raise ValueError(\"JSON 顶层应为 list，每个元素为一条样本。\")\n",
    "\n",
    "def count_tokens(text: str) -> int:\n",
    "    # add_special_tokens=True 表示会包含 [CLS]/[SEP] 等特殊符号（和训练/推理一致）\n",
    "    enc = tokenizer(text, add_special_tokens=True, truncation=False)\n",
    "    return len(enc[\"input_ids\"])\n",
    "\n",
    "# === 计算前 10 条 ===\n",
    "first10 = data[:10]\n",
    "lens = []\n",
    "print(\"—— 前 10 条 posts_cleaned 的 token 数 ——\")\n",
    "for i, item in enumerate(first10, 1):\n",
    "    text = item.get(\"posts_cleaned\", \"\")\n",
    "    tok_len = count_tokens(text)\n",
    "    lens.append(tok_len)\n",
    "    print(f\"样本 {i:02d}: {tok_len} tokens\")\n",
    "\n",
    "# === 给出整体统计（可选：只对前10统计，或你也可以改成对全量统计）===\n",
    "def pctl(arr, p):\n",
    "    arr = sorted(arr)\n",
    "    if not arr:\n",
    "        return 0\n",
    "    k = (len(arr)-1) * p\n",
    "    f = math.floor(k)\n",
    "    c = math.ceil(k)\n",
    "    if f == c:\n",
    "        return arr[int(k)]\n",
    "    return arr[f] + (arr[c] - arr[f]) * (k - f)\n",
    "\n",
    "print(\"\\n—— 前 10 条统计 ——\")\n",
    "print(f\"均值: {mean(lens):.1f}\")\n",
    "print(f\"中位数: {median(lens):.1f}\")\n",
    "print(f\"P90: {pctl(lens, 0.90):.1f}\")\n",
    "print(f\"最大值: {max(lens)}\")\n",
    "print(f\"最小值: {min(lens)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6f7e341c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/home/hli962/miniconda3/lib/python3.12/site-packages/transformers/convert_slow_tokenizer.py:564: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "—— 全量 posts_cleaned token 统计 ——\n",
      "样本数: 48000\n",
      "均值: 237.1, 中位数: 239.0, P90: 284.0, P95: 307.0, 最小: 3, 最大: 1781\n",
      "\n",
      "建议 MAX_LEN: 320  (依据 P90≈284 推出)\n",
      "建议 batch_size: 12  (保持每步 token≈4096 → 当前≈3840)\n",
      "可选：设置梯度累积 grad_accum=2，使有效 tokens/step ≥ 之前水平。\n",
      "\n",
      "【对照】max_len=288: 建议 batch_size=14, tokens/step≈4032, 截断比例估计≈8.5%\n",
      "\n",
      "【对照】max_len=320: 建议 batch_size=12, tokens/step≈3840, 截断比例估计≈3.7%\n"
     ]
    }
   ],
   "source": [
    "import json, os, math, numpy as np\n",
    "from statistics import mean, median\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# ===== 配置（按你当前训练的真实值填写）=====\n",
    "MODEL_NAME = \"microsoft/deberta-v3-base\"\n",
    "FILE_PATH  = \"mbti_sample_with_all_views_pandora.json\"\n",
    "\n",
    "BASE_MAX_LEN   = 512   # 你当前/之前用的 max_length（若原来512就填512）\n",
    "BASE_BATCHSIZE = 8     # 你当前/之前用的 batch_size\n",
    "TARGET_TOKENS_PER_STEP = BASE_MAX_LEN * BASE_BATCHSIZE  # 保持每步 token 数不变\n",
    "\n",
    "# ===== 读取数据 & tokenizer =====\n",
    "assert os.path.exists(FILE_PATH), f\"文件不存在: {FILE_PATH}\"\n",
    "with open(FILE_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "def count_tokens(text: str) -> int:\n",
    "    enc = tokenizer(text, add_special_tokens=True, truncation=False)\n",
    "    return len(enc[\"input_ids\"])\n",
    "\n",
    "# ===== 全量长度统计 =====\n",
    "lengths = [count_tokens(x.get(\"posts_cleaned\", \"\")) for x in data]\n",
    "p90 = float(np.percentile(lengths, 90))\n",
    "p95 = float(np.percentile(lengths, 95))\n",
    "avg = mean(lengths)\n",
    "med = median(lengths)\n",
    "mx  = max(lengths)\n",
    "mn  = min(lengths)\n",
    "\n",
    "print(\"—— 全量 posts_cleaned token 统计 ——\")\n",
    "print(f\"样本数: {len(lengths)}\")\n",
    "print(f\"均值: {avg:.1f}, 中位数: {med:.1f}, P90: {p90:.1f}, P95: {p95:.1f}, 最小: {mn}, 最大: {mx}\")\n",
    "\n",
    "# ===== 建议的 MAX_LEN（P90 + 余量，并对齐到32）=====\n",
    "def round_up_to_multiple(x, base=32):\n",
    "    return int(base * math.ceil(x / base))\n",
    "\n",
    "suggest_len = round_up_to_multiple(p90 + 16, 32)  # 16 作为小余量\n",
    "# 你也可以限制范围避免过长/过短，比如：\n",
    "suggest_len = min(max(suggest_len, 224), 448)\n",
    "\n",
    "print(f\"\\n建议 MAX_LEN: {suggest_len}  (依据 P90≈{p90:.0f} 推出)\")\n",
    "\n",
    "# ===== 建议的 batch_size / 梯度累积 =====\n",
    "new_bs = max(1, TARGET_TOKENS_PER_STEP // suggest_len)\n",
    "tokens_per_step = new_bs * suggest_len\n",
    "print(f\"建议 batch_size: {new_bs}  (保持每步 token≈{TARGET_TOKENS_PER_STEP} → 当前≈{tokens_per_step})\")\n",
    "\n",
    "# 如果你希望“至少达到”目标 token/step，可以给出梯度累积建议：\n",
    "if tokens_per_step < TARGET_TOKENS_PER_STEP:\n",
    "    need = math.ceil(TARGET_TOKENS_PER_STEP / tokens_per_step)\n",
    "    print(f\"可选：设置梯度累积 grad_accum={need}，使有效 tokens/step ≥ 之前水平。\")\n",
    "\n",
    "# =====（可选）看看如果更激进用 288/320 会怎样 =====\n",
    "for test_len in [288, 320]:\n",
    "    test_bs = max(1, TARGET_TOKENS_PER_STEP // test_len)\n",
    "    print(f\"\\n【对照】max_len={test_len}: 建议 batch_size={test_bs}, \"\n",
    "          f\"tokens/step≈{test_bs*test_len}, 截断比例估计≈{sum(l>test_len for l in lengths)/len(lengths):.1%}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e931a44",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/transformers/convert_slow_tokenizer.py:564: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n",
      "[Train] Epoch 1/30: 100%|██████████| 10800/10800 [28:42<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1 - Train Loss: 0.8944\n",
      "Validation | EI: 64.12% NS: 63.40% TF: 68.94% JP: 61.23% 4D: 20.48%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 20.48%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 2/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 2 - Train Loss: 0.8023\n",
      "Validation | EI: 68.58% NS: 66.88% TF: 71.19% JP: 65.77% 4D: 27.15%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 27.15%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 3/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 3 - Train Loss: 0.7292\n",
      "Validation | EI: 70.44% NS: 71.46% TF: 73.04% JP: 67.88% 4D: 31.92%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 31.92%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 4/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 4 - Train Loss: 0.6559\n",
      "Validation | EI: 71.69% NS: 72.08% TF: 73.77% JP: 68.81% 4D: 34.77%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 34.77%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 5/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 5 - Train Loss: 0.5815\n",
      "Validation | EI: 71.60% NS: 71.46% TF: 75.04% JP: 68.69% 4D: 35.25%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 35.25%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 6/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 6 - Train Loss: 0.5095\n",
      "Validation | EI: 72.73% NS: 73.75% TF: 72.56% JP: 69.04% 4D: 36.50%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 36.50%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 7/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 7 - Train Loss: 0.4485\n",
      "Validation | EI: 71.88% NS: 73.25% TF: 73.27% JP: 69.46% 4D: 36.31%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 8/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 8 - Train Loss: 0.4012\n",
      "Validation | EI: 71.56% NS: 73.15% TF: 71.44% JP: 70.79% 4D: 36.00%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 9/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 9 - Train Loss: 0.3647\n",
      "Validation | EI: 70.83% NS: 73.08% TF: 74.54% JP: 69.60% 4D: 35.38%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 10/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 10 - Train Loss: 0.3376\n",
      "Validation | EI: 70.75% NS: 73.67% TF: 74.44% JP: 69.19% 4D: 35.12%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 11/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 11 - Train Loss: 0.3187\n",
      "Validation | EI: 71.42% NS: 73.85% TF: 74.73% JP: 70.48% 4D: 37.54%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 37.54%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 12/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 12 - Train Loss: 0.3019\n",
      "Validation | EI: 71.58% NS: 73.19% TF: 73.85% JP: 69.62% 4D: 36.10%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 13/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 13 - Train Loss: 0.2909\n",
      "Validation | EI: 71.42% NS: 74.04% TF: 73.15% JP: 71.38% 4D: 36.48%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 14/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 14 - Train Loss: 0.2790\n",
      "Validation | EI: 70.71% NS: 73.71% TF: 73.92% JP: 70.25% 4D: 36.85%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 15/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 15 - Train Loss: 0.2714\n",
      "Validation | EI: 71.67% NS: 73.48% TF: 74.48% JP: 70.92% 4D: 36.98%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 16/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 16 - Train Loss: 0.2612\n",
      "Validation | EI: 72.96% NS: 74.54% TF: 74.77% JP: 71.23% 4D: 38.92%\n",
      "✅ Best model saved to best_deberta_supcon.pt (4D: 38.92%)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 17/30: 100%|██████████| 10800/10800 [28:41<00:00,  6.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 17 - Train Loss: 0.2581\n",
      "Validation | EI: 72.15% NS: 73.85% TF: 74.33% JP: 70.44% 4D: 38.38%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Train] Epoch 18/30:  38%|███▊      | 4112/10800 [10:55<17:44,  6.28it/s]"
     ]
    }
   ],
   "source": [
    "import json\n",
    "from transformers import AutoTokenizer\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# ===== 1. 读取数据 =====\n",
    "with open(\"mbti_sample_with_all_views_pandora.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# ===== 2. 划分训练/验证集 =====\n",
    "train_data, val_data = train_test_split(data, test_size=0.1, random_state=42)\n",
    "\n",
    "# ===== 3. 初始化Tokenizer和Dataset =====\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "train_dataset = JointMBTIDataset(train_data, tokenizer, max_length=MAX_LEN, drop_prob=0.1)\n",
    "val_dataset = JointMBTIDataset(val_data, tokenizer, max_length=MAX_LEN, drop_prob=0.0)  # 验证集不drop\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "# ===== 4. 初始化模型和优化器 =====\n",
    "model = JointMBTIModel(encoder_name=MODEL_NAME)\n",
    "optimizer = AdamW(model.parameters(), lr=LR)\n",
    "\n",
    "# ===== 5. 开始训练 =====\n",
    "train_model(model, train_loader, val_loader, optimizer, DEVICE, epochs=EPOCHS, save_path=SAVE_PATH)\n",
    "\n",
    "# ===== 6. 加载最佳模型做推理 =====\n",
    "model.load_state_dict(torch.load(SAVE_PATH))\n",
    "model.eval()\n",
    "\n",
    "# 推理示例\n",
    "sample = val_data[0]\n",
    "encoding = tokenizer(\n",
    "    sample[\"posts_cleaned\"],\n",
    "    padding=\"max_length\",\n",
    "    truncation=True,\n",
    "    max_length=MAX_LEN,\n",
    "    return_tensors=\"pt\"\n",
    ")\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs = model(encoding[\"input_ids\"].to(DEVICE), encoding[\"attention_mask\"].to(DEVICE))\n",
    "    pred_ei = torch.sigmoid(outputs[\"ei\"]).round().item()\n",
    "    pred_ns = torch.sigmoid(outputs[\"ns\"]).round().item()\n",
    "    pred_tf = torch.sigmoid(outputs[\"tf\"]).round().item()\n",
    "    pred_jp = torch.sigmoid(outputs[\"jp\"]).round().item()\n",
    "\n",
    "print(\"原标签:\", sample[\"type\"])\n",
    "print(\"预测四维:\", pred_ei, pred_ns, pred_tf, pred_jp)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
