{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a3af157-1d42-46dc-aff9-223a066ea1ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "import nltk\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "\n",
    "# 下载必要资源\n",
    "# nltk.download(\"punkt\")\n",
    "# nltk.download(\"stopwords\")\n",
    "# nltk.download(\"wordnet\")\n",
    "\n",
    "lemmatizer = WordNetLemmatizer()\n",
    "stop_words = set(stopwords.words(\"english\"))\n",
    "\n",
    "# MBTI 类型关键词（大写匹配）\n",
    "mbti_types = {\n",
    "    'INTJ', 'INTP', 'ENTJ', 'ENTP',\n",
    "    'INFJ', 'INFP', 'ENFJ', 'ENFP',\n",
    "    'ISTJ', 'ISFJ', 'ESTJ', 'ESFJ',\n",
    "    'ISTP', 'ISFP', 'ESTP', 'ESFP'\n",
    "}\n",
    "\n",
    "# 深度清洗：用于 posts_cleaned\n",
    "def clean_text(text):\n",
    "    text = text.lower()\n",
    "    text = re.sub(r'(https?://\\S+|www\\.\\S+)', ' ', text)\n",
    "    text = text.replace('|', ' ')\n",
    "    text = re.sub(r'[^a-z\\s]', ' ', text)\n",
    "    tokens = nltk.word_tokenize(text)\n",
    "    cleaned = [\n",
    "        lemmatizer.lemmatize(token)\n",
    "        for token in tokens\n",
    "        if token not in stop_words and len(token) > 2\n",
    "    ]\n",
    "    return ' '.join(cleaned)\n",
    "\n",
    "# 轻度清洗：用于 post_casebank（仅去链接）\n",
    "def clean_casebank_keep_format(text):\n",
    "    text = re.sub(r'(https?://\\S+|www\\.\\S+)', ' ', text)\n",
    "    return re.sub(r'\\s+', ' ', text).strip()\n",
    "\n",
    "# 主清洗逻辑\n",
    "def preprocess_mbti_json(input_json_path, output_json_path=None, output_csv_path=None):\n",
    "    with open(input_json_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    cleaned_data = []\n",
    "    MIN_TOKEN_LEN = 10  # 词数下限\n",
    "\n",
    "    for item in tqdm(data, desc=\"🧹 Cleaning MBTI posts\"):\n",
    "        raw_posts = item.get(\"posts\", \"\")\n",
    "        segments = raw_posts.strip(\"'\").split(\"|||\")\n",
    "\n",
    "        cleaned_segments = []\n",
    "        casebank_segments = []\n",
    "\n",
    "        for seg in segments:\n",
    "            seg = seg.strip()\n",
    "            has_url = 'http' in seg or 'www.' in seg\n",
    "            token_len = len(nltk.word_tokenize(seg))\n",
    "            contains_mbti = any(t in seg.upper() for t in mbti_types)\n",
    "\n",
    "            # 丢弃规则统一：短 + URL + MBTI 标签 → 丢\n",
    "            keep = not has_url and token_len > MIN_TOKEN_LEN and not contains_mbti\n",
    "\n",
    "            if keep:\n",
    "                cleaned = clean_text(seg)\n",
    "                casebank_cleaned = clean_casebank_keep_format(seg)\n",
    "\n",
    "                if cleaned:\n",
    "                    cleaned_segments.append(cleaned)\n",
    "                if casebank_cleaned:\n",
    "                    casebank_segments.append(casebank_cleaned)\n",
    "\n",
    "        cleaned_data.append({\n",
    "            \"type\": item.get(\"type\"),\n",
    "            \"posts\": raw_posts,\n",
    "            \"posts_cleaned\":item.get(\"posts_cleaned\"),\n",
    "            \"post_casebank\": \" \".join(casebank_segments)\n",
    "        })\n",
    "\n",
    "    # 保存 JSON\n",
    "    if output_json_path:\n",
    "        os.makedirs(os.path.dirname(output_json_path) or \".\", exist_ok=True)\n",
    "        with open(output_json_path, \"w\", encoding=\"utf-8\") as f:\n",
    "            json.dump(cleaned_data, f, ensure_ascii=False, indent=2)\n",
    "        print(f\"✅ JSON saved to: {output_json_path}\")\n",
    "\n",
    "    # 提示 CSV 输出\n",
    "    if output_csv_path:\n",
    "        print(\"⚠️ CSV建议只导出字段：type, posts_cleaned, post_casebank\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c23163b-6155-4418-9ff9-2b58e97b98d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocess_mbti_json(\n",
    "    input_json_path=\"extended_mbti_dataset_v17.json\",\n",
    "    output_csv_path=\"extended_mbti_v17_cleaned.csv\",\n",
    "    output_json_path=\"casebank_mbti.json\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b017ebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#lora 清理部分\n",
    "import json, re, hashlib\n",
    "from collections import Counter\n",
    "# from typing import Optional  # 如果是 py3.8/3.9，打开并把下面的返回类型改成 Optional[dict]\n",
    "\n",
    "MBTI_16 = {\"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "           \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"}\n",
    "\n",
    "def light_clean(x: str) -> str:\n",
    "    if not x: return \"\"\n",
    "    x = re.sub(r'(https?://\\S+|www\\.\\S+)', ' ', x)\n",
    "    x = x.replace('|',' ')\n",
    "    x = re.sub(r'\\s+', ' ', x).strip()\n",
    "    return x\n",
    "\n",
    "def build_case_row(r: dict, max_chars: int = 1200) -> dict:  # 如果低版本，写 Optional[dict]\n",
    "    t = str(r.get(\"type\",\"\")).upper().strip()\n",
    "    if t not in MBTI_16: \n",
    "        return None\n",
    "\n",
    "    base = light_clean(r.get(\"posts_cleaned\") or r.get(\"posts\") or \"\")\n",
    "    sem  = light_clean(r.get(\"semantic_view\",\"\"))\n",
    "    sen  = light_clean(r.get(\"sentiment_view\",\"\"))\n",
    "    lin  = light_clean(r.get(\"linguistic_view\",\"\"))\n",
    "\n",
    "    if not base:\n",
    "        return None\n",
    "\n",
    "    meta = []\n",
    "    if sem: meta.append(f\"[Semantic] {sem}\")\n",
    "    if sen: meta.append(f\"[Sentiment] {sen}\")\n",
    "    if lin: meta.append(f\"[Linguistic] {lin}\")\n",
    "    post_casebank = base if not meta else f\"{base} \" + \" \".join(meta)\n",
    "\n",
    "    if len(post_casebank) > max_chars:\n",
    "        post_casebank = post_casebank[:max_chars].rstrip() + \"…\"\n",
    "\n",
    "    return {\"type\": t, \"post_casebank\": post_casebank, \"embed_text\": base}\n",
    "\n",
    "def make_casebank(in_file: str, out_file: str, max_chars: int = 1200, dedup: bool = True):\n",
    "    data = json.load(open(in_file, \"r\", encoding=\"utf-8\"))\n",
    "    out, seen = [], set()\n",
    "\n",
    "    for r in data:\n",
    "        row = build_case_row(r, max_chars=max_chars)\n",
    "        if not row: \n",
    "            continue\n",
    "        if dedup:\n",
    "            h = hashlib.md5(row[\"embed_text\"].encode(\"utf-8\")).hexdigest()\n",
    "            if h in seen:\n",
    "                continue\n",
    "            seen.add(h)\n",
    "        out.append(row)\n",
    "\n",
    "    json.dump(out, open(out_file, \"w\", encoding=\"utf-8\"), ensure_ascii=False, indent=2)\n",
    "\n",
    "    cnt = Counter([x[\"type\"] for x in out])\n",
    "    total = len(out)\n",
    "    print(f\"✅ Casebank 完成：{total} 条 → {out_file}\")\n",
    "    print(\"📊 类型分布：\")\n",
    "    for k in sorted(cnt):\n",
    "        print(f\"  {k}: {cnt[k]} ({cnt[k]/total:.2%})\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecb185ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_PATH  = \"pandora_eval_80.json\"              # 你的文件\n",
    "OUTPUT_PATH = \"casebank_pandora_eval_80.json\"     # 输出\n",
    "\n",
    "make_casebank(INPUT_PATH, OUTPUT_PATH, max_chars=1200, dedup=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd3d39bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification\n",
    "from peft import PeftModel\n",
    "\n",
    "BASE_MODEL  = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "ADAPTER_DIR = \"mbti_lora_qwen1.5b-split_kaggle_ckpt\"   # ✅ 指向包含 adapter_config.json 的目录\n",
    "\n",
    "assert os.path.exists(os.path.join(ADAPTER_DIR, \"adapter_config.json\")), \"路径不对，找不到 adapter_config.json\"\n",
    "\n",
    "tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True, trust_remote_code=True)\n",
    "if tok.pad_token is None:\n",
    "    tok.pad_token = tok.eos_token\n",
    "\n",
    "cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "cfg.num_labels = 16\n",
    "\n",
    "base = AutoModelForSequenceClassification.from_pretrained(\n",
    "    BASE_MODEL, config=cfg, trust_remote_code=True,\n",
    "    device_map={\"\": \"cuda:0\"}, low_cpu_mem_usage=True\n",
    ")\n",
    "\n",
    "model = PeftModel.from_pretrained(base, ADAPTER_DIR, is_trainable=False).eval()\n",
    "print(\"✅ LoRA 适配器加载成功\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcc4e1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# extract_embeddings_qwen_meanpool.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig\n",
    "from peft import PeftModel\n",
    "\n",
    "# ===== 路径配置 =====\n",
    "BASE_MODEL  = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "ADAPTER_DIR = \"mbti_lora_qwen1.5b-split_kaggle_ckpt\"\n",
    "INPUT_FILE  = \"train.json\"\n",
    "OUTPUT_FILE = \"casebank_A_train_80_with_embeddings.json\"\n",
    "\n",
    "# 设备\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device   = torch.device(\"cuda:0\" if use_cuda else \"cpu\")\n",
    "device_map = {\"\": \"cuda:0\"} if use_cuda else None\n",
    "\n",
    "# ===== tokenizer & model =====\n",
    "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = \"right\"\n",
    "\n",
    "config = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "config.num_labels = 16\n",
    "\n",
    "base = AutoModelForSequenceClassification.from_pretrained(\n",
    "    BASE_MODEL,\n",
    "    config=config,\n",
    "    trust_remote_code=True,\n",
    "    device_map=device_map,\n",
    "    low_cpu_mem_usage=True,\n",
    ")\n",
    "model = PeftModel.from_pretrained(base, ADAPTER_DIR, is_trainable=False)\n",
    "model.eval().to(device)\n",
    "\n",
    "# ===== 句向量提取（masked mean pooling；推荐） =====\n",
    "@torch.no_grad()\n",
    "def get_text_embedding(text: str, max_len: int = 512) -> list[float]:\n",
    "    ins = tokenizer(\n",
    "        text,\n",
    "        return_tensors=\"pt\",\n",
    "        truncation=True,\n",
    "        padding=True,\n",
    "        max_length=max_len\n",
    "    ).to(device)\n",
    "\n",
    "    out = model.base_model.model(\n",
    "        **ins, output_hidden_states=True, use_cache=False, return_dict=True\n",
    "    )\n",
    "    last_hidden = out.hidden_states[-1]          # (B, T, H)\n",
    "    mask = ins[\"attention_mask\"].unsqueeze(-1)   # (B, T, 1)\n",
    "\n",
    "    # masked mean pooling\n",
    "    masked = last_hidden * mask                  # zero out pads\n",
    "    sum_hidden = masked.sum(dim=1)               # (B, H)\n",
    "    lengths = mask.sum(dim=1).clamp(min=1)       # (B, 1)\n",
    "    emb = sum_hidden / lengths                   # (B, H)\n",
    "\n",
    "    # L2 归一化（和 FAISS 里一致：余弦=内积）\n",
    "    emb = F.normalize(emb, p=2, dim=1)           # (B, H)\n",
    "    return emb.squeeze(0).cpu().tolist()\n",
    "\n",
    "# （可选）最后 token pooling 版本：\n",
    "# def get_text_embedding(text: str, max_len: int = 512) -> list[float]:\n",
    "#     ins = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=True, max_length=max_len).to(device)\n",
    "#     out = model.base_model.model(**ins, output_hidden_states=True, use_cache=False, return_dict=True)\n",
    "#     last_hidden = out.hidden_states[-1]                # (B, T, H)\n",
    "#     idx = ins[\"attention_mask\"].sum(dim=1) - 1         # (B,)\n",
    "#     emb = last_hidden[torch.arange(last_hidden.size(0)), idx]  # (B, H)\n",
    "#     emb = F.normalize(emb, p=2, dim=1)\n",
    "#     return emb.squeeze(0).cpu().tolist()\n",
    "\n",
    "# ===== 处理 casebank =====\n",
    "with open(INPUT_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "output = []\n",
    "for item in tqdm(data, desc=\"Extracting sentence embeddings\"):\n",
    "    # 仍使用你构建 casebank 时的展示字段\n",
    "    text = item.get(\"post_casebank\") or item.get(\"embed_text\") or item.get(\"posts_cleaned\") or item.get(\"posts\") or \"\"\n",
    "    if not text.strip():\n",
    "        continue\n",
    "    emb = get_text_embedding(text)\n",
    "    item[\"embedding\"] = emb\n",
    "    output.append(item)\n",
    "\n",
    "with open(OUTPUT_FILE, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(output, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"✅ 完成！结果已保存到 {OUTPUT_FILE}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45305b68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build_faiss_from_casebank.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, re\n",
    "import faiss\n",
    "import numpy as np\n",
    "\n",
    "DATA_PATH       = \"casebank_A_train_80_with_embeddings.json\"\n",
    "FAISS_INDEX_PATH= \"casebank_A_qwen_cls_cosine.index\"\n",
    "ID_MAP_PATH     = \"casebank_A_qwen_cls_idmap.json\"\n",
    "\n",
    "def light_clean(x: str) -> str:\n",
    "    if not x: return \"\"\n",
    "    x = re.sub(r'(https?://\\S+|www\\.\\S+)', ' ', x)\n",
    "    x = x.replace('|',' ')\n",
    "    x = re.sub(r'\\s+', ' ', x).strip()\n",
    "    return x\n",
    "\n",
    "def ensure_post_casebank(item: dict, max_chars: int = 1200) -> str:\n",
    "    \"\"\"万一上一步没写入 post_casebank，这里兜底再构建一次。\"\"\"\n",
    "    if \"post_casebank\" in item and item[\"post_casebank\"].strip():\n",
    "        return item[\"post_casebank\"]\n",
    "\n",
    "    base = light_clean(item.get(\"embed_text\") or item.get(\"posts_cleaned\") or item.get(\"posts\") or \"\")\n",
    "    sem  = light_clean(item.get(\"semantic_view\",\"\"))\n",
    "    sen  = light_clean(item.get(\"sentiment_view\",\"\"))\n",
    "    lin  = light_clean(item.get(\"linguistic_view\",\"\"))\n",
    "    meta = []\n",
    "    if sem: meta.append(f\"[Semantic] {sem}\")\n",
    "    if sen: meta.append(f\"[Sentiment] {sen}\")\n",
    "    if lin: meta.append(f\"[Linguistic] {lin}\")\n",
    "    post_casebank = base if not meta else f\"{base} \" + \" \".join(meta)\n",
    "    if len(post_casebank) > max_chars:\n",
    "        post_casebank = post_casebank[:max_chars].rstrip() + \"…\"\n",
    "    return post_casebank\n",
    "\n",
    "# === 加载 ===\n",
    "with open(DATA_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "    raw = json.load(f)\n",
    "\n",
    "embs, id_map = [], {}\n",
    "for idx, item in enumerate(raw):\n",
    "    emb = np.array(item[\"embedding\"], dtype=np.float32)\n",
    "    embs.append(emb)\n",
    "    id_map[idx] = {\n",
    "        \"post_casebank\": ensure_post_casebank(item),\n",
    "        \"type\": item.get(\"type\", \"\")\n",
    "    }\n",
    "\n",
    "embs = np.vstack(embs).astype(\"float32\")\n",
    "faiss.normalize_L2(embs)                # 归一化使内积=余弦\n",
    "\n",
    "dim = embs.shape[1]\n",
    "index = faiss.IndexFlatIP(dim)\n",
    "index.add(embs)\n",
    "\n",
    "faiss.write_index(index, FAISS_INDEX_PATH)\n",
    "with open(ID_MAP_PATH, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(id_map, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(\"✅ 向量库已构建:\", FAISS_INDEX_PATH)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4b54b5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity_check_and_search.py\n",
    "import json, faiss, numpy as np\n",
    "\n",
    "FAISS_INDEX_PATH = \"casebank_A_qwen_cls_cosine.index\"\n",
    "ID_MAP_PATH      = \"casebank_A_qwen_cls_idmap.json\"\n",
    "EMB_DATA_PATH    = \"casebank_A_train_80_with_embeddings.json\"  # 就是你建库时的那个 JSON\n",
    "\n",
    "# 1) 加载索引和 id_map\n",
    "index  = faiss.read_index(FAISS_INDEX_PATH)\n",
    "with open(ID_MAP_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "    id_map = json.load(f)\n",
    "\n",
    "print(\"index.ntotal =\", index.ntotal, \" | id_map size =\", len(id_map))\n",
    "\n",
    "# 2) 读取同一个 embeddings 数据文件，拿一个向量当 query\n",
    "with open(EMB_DATA_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "assert len(data) > 0, \"embeddings 数据为空\"\n",
    "k = 10 if len(data) > 10 else 0                    # 选第 k 个样本\n",
    "q = np.array(data[k][\"embedding\"], dtype=np.float32).reshape(1, -1)\n",
    "\n",
    "# 如果你构建索引前做了 L2 归一化，这里最好再归一化一次（保险）\n",
    "faiss.normalize_L2(q)\n",
    "\n",
    "# 3) 检索 top-5\n",
    "D, I = index.search(q, 5)\n",
    "\n",
    "# 4) 打印结果\n",
    "for s, idx in zip(D[0], I[0]):\n",
    "    info = id_map[str(int(idx))]                   # id_map 的 key 是字符串\n",
    "    print(f\"{s:.4f}\", info[\"type\"], \"|\", info[\"post_casebank\"][:80], \"...\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119dd2ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# encode_and_retrieve_topk.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, faiss, numpy as np, torch\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig\n",
    "from peft import PeftModel\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# ========= 路径配置 =========\n",
    "# 已建好的向量库（由 casebank_*_with_embeddings.json 构建）\n",
    "FAISS_INDEX_PATH = \"casebank_A_qwen_cls_cosine.index\"\n",
    "ID_MAP_PATH      = \"casebank_A_qwen_cls_idmap.json\"\n",
    "\n",
    "# 待向量化并检索的测试集\n",
    "TEST_INPUT          = \"picked_balanced_around30.json\"\n",
    "TEST_WITH_EMB_OUT   = \"A_test_with_embeddings_final1.json\"\n",
    "TOPK_RESULTS_OUT    = \"A_test_top1_final1.json\"\n",
    "\n",
    "# Qwen + LoRA（与建库时完全一致）\n",
    "BASE_MODEL  = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "ADAPTER_DIR = \"mbti_lora_qwen1.5b-split_kaggle_ckpt\"\n",
    "\n",
    "TOPK = 1\n",
    "MAX_LEN = 512\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# ========= 加载索引 =========\n",
    "index = faiss.read_index(FAISS_INDEX_PATH)\n",
    "id_map = json.load(open(ID_MAP_PATH, \"r\", encoding=\"utf-8\"))\n",
    "\n",
    "# ========= 加载模型（与建库时一致）=========\n",
    "tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "if tok.pad_token is None:\n",
    "    tok.pad_token = tok.eos_token\n",
    "tok.padding_side = \"right\"\n",
    "\n",
    "cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "cfg.num_labels = 16\n",
    "\n",
    "base = AutoModelForSequenceClassification.from_pretrained(\n",
    "    BASE_MODEL, config=cfg, trust_remote_code=True,\n",
    "    device_map={\"\": device}, low_cpu_mem_usage=True\n",
    ")\n",
    "model = PeftModel.from_pretrained(base, ADAPTER_DIR, is_trainable=False).eval().to(device)\n",
    "\n",
    "# ========= 工具函数 =========\n",
    "def pick_text(item: dict) -> str:\n",
    "    \"\"\"按优先级选一个可用于编码的文本字段。按你的数据结构需要可调整顺序。\"\"\"\n",
    "    for key in [\"posts_cleaned_for_embedding\", \"embed_text\", \"post_casebank\",\n",
    "                \"posts_cleaned\", \"posts\", \"text\", \"query_text\"]:\n",
    "        v = item.get(key)\n",
    "        if isinstance(v, str) and v.strip():\n",
    "            return v\n",
    "    return \"\"\n",
    "\n",
    "@torch.no_grad()\n",
    "def encode_text_to_vec(text: str) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    用 masked mean pooling 提句向量（整段文本表征）并做 L2 归一化，\n",
    "    与索引端 IndexFlatIP（内积≈余弦）一致。\n",
    "    \"\"\"\n",
    "    ins = tok(text, return_tensors=\"pt\", truncation=True, padding=True, max_length=MAX_LEN).to(device)\n",
    "    out = model.base_model.model(**ins, output_hidden_states=True, use_cache=False, return_dict=True)\n",
    "    last_hidden = out.hidden_states[-1]                 # (B, T, H)\n",
    "    mask = ins[\"attention_mask\"].unsqueeze(-1).float()  # (B, T, 1)\n",
    "\n",
    "    masked = last_hidden * mask                         # zero out pads\n",
    "    sum_hidden = masked.sum(dim=1)                      # (B, H)\n",
    "    lengths = mask.sum(dim=1).clamp(min=1.0)            # (B, 1)\n",
    "    vec = sum_hidden / lengths                          # (B, H)\n",
    "\n",
    "    vec = F.normalize(vec, p=2, dim=1)                  # L2 归一化\n",
    "    return vec.squeeze(0).float().cpu().numpy()         # (H,)\n",
    "\n",
    "def search_topk(vec: np.ndarray, k=TOPK):\n",
    "    D, I = index.search(vec.reshape(1, -1), k)\n",
    "    return [{\"score\": float(s), **id_map[str(int(i))]} for s, i in zip(D[0], I[0])]\n",
    "\n",
    "def filter_self_hits(query_text: str, hits: list, k: int):\n",
    "    \"\"\"可选：过滤与查询文本完全相同的返回，避免“命中自己”\"\"\"\n",
    "    def norm(s): return \" \".join((s or \"\").lower().split())\n",
    "    qt = norm(query_text)\n",
    "    keep = [h for h in hits if norm(h.get(\"post_casebank\", \"\")) != qt]\n",
    "    return keep[:k]\n",
    "\n",
    "# ========= 读取测试集、编码、检索 =========\n",
    "test = json.load(open(TEST_INPUT, \"r\", encoding=\"utf-8\"))\n",
    "\n",
    "with_emb, results = [], []\n",
    "for it in tqdm(test, desc=\"Encode & retrieve\"):\n",
    "    text = pick_text(it)\n",
    "    if not text:\n",
    "        # 没有可用文本就记空\n",
    "        it_out = dict(it); it_out[\"embedding\"] = None\n",
    "        with_emb.append(it_out)\n",
    "        results.append({\"type\": it.get(\"type\",\"\"), \"query_text\": \"\", \"topk_cases\": []})\n",
    "        continue\n",
    "\n",
    "    emb = encode_text_to_vec(text)              # (H,)\n",
    "    it_out = dict(it); it_out[\"embedding\"] = emb.tolist()\n",
    "    with_emb.append(it_out)\n",
    "\n",
    "    topk = search_topk(emb, k=TOPK)\n",
    "    topk = filter_self_hits(text, topk, TOPK)   # 可关掉：如果不担心命中自己\n",
    "    results.append({\n",
    "        \"type\": it.get(\"type\",\"\"),\n",
    "        \"query_text\": text,\n",
    "        \"topk_cases\": topk\n",
    "    })\n",
    "\n",
    "# ========= 落盘 =========\n",
    "json.dump(with_emb, open(TEST_WITH_EMB_OUT, \"w\", encoding=\"utf-8\"), ensure_ascii=False, indent=2)\n",
    "json.dump(results, open(TOPK_RESULTS_OUT, \"w\", encoding=\"utf-8\"), ensure_ascii=False, indent=2)\n",
    "\n",
    "print(\"✅ 向量化完成：\", TEST_WITH_EMB_OUT)\n",
    "print(\"✅ TopK 检索完成：\", TOPK_RESULTS_OUT)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84a7e483-c975-471a-b6aa-ee0725a7c36e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "# ===== 模型配置 =====\n",
    "MODEL_NAME = \"microsoft/deberta-v3-base\"\n",
    "MODEL_PATH = \"best_fem_deberta.pt\"\n",
    "USE_CLS = True\n",
    "MIN_WORDS = 5 # 最小词数要求\n",
    "\n",
    "# ===== 加载模型和 tokenizer =====\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = AutoModel.from_pretrained(MODEL_NAME)\n",
    "state_dict = torch.load(MODEL_PATH, map_location=torch.device(\"cpu\"))\n",
    "model.load_state_dict(state_dict, strict=False)\n",
    "model.eval()\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# ===== 提取嵌入向量函数 =====\n",
    "def get_embedding(text):\n",
    "    inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=True, max_length=512)\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "        if USE_CLS:\n",
    "            return outputs.last_hidden_state[:, 0, :].squeeze().cpu().tolist()\n",
    "        else:\n",
    "            return torch.mean(outputs.last_hidden_state, dim=1).squeeze().cpu().tolist()\n",
    "\n",
    "# ===== 读取原始 casebank 数据 =====\n",
    "with open(\"casebank_mbti.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "output = []\n",
    "kept_count = 0\n",
    "filtered_count = 0\n",
    "\n",
    "for item in tqdm(data, desc=\"🔍 Extracting embeddings\"):\n",
    "    post_text = item.get(\"post_casebank\", \"\").strip()\n",
    "    mbti_type = item.get(\"type\", \"\")\n",
    "\n",
    "    # ===== 过滤短文本 =====\n",
    "    word_count = len(post_text.split())\n",
    "    if word_count < MIN_WORDS:\n",
    "        filtered_count += 1\n",
    "        continue\n",
    "\n",
    "    emb = get_embedding(post_text)\n",
    "    output.append({\n",
    "        \"embedding\": emb,\n",
    "        \"post_casebank\": post_text,\n",
    "        \"type\": mbti_type\n",
    "    })\n",
    "    kept_count += 1\n",
    "\n",
    "# ===== 保存 =====\n",
    "output_file = \"mbti_embeddings_from_custom_model.json\"\n",
    "with open(output_file, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(output, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "# ===== 打印统计结果 =====\n",
    "print(\"✅ 完成！结果已保存至：\", output_file)\n",
    "print(f\"📊 共处理样本数：{len(data)}\")\n",
    "print(f\"✅ 保留样本数：{kept_count}\")\n",
    "print(f\"🚫 筛除过短样本数：{filtered_count}（<{MIN_WORDS}词）\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10bca73f-5a51-4e12-823a-daa20b58e943",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import faiss\n",
    "import numpy as np\n",
    "\n",
    "# ========== 配置 ==========\n",
    "DATA_PATH = \"mbti_embeddings_from_custom_model.json\"\n",
    "FAISS_INDEX_PATH = \"mbti_faiss_cosine.index\"\n",
    "ID_MAP_PATH = \"mbti_idmap_cosine.json\"\n",
    "\n",
    "# ========== 加载嵌入 ==========\n",
    "with open(DATA_PATH, \"r\", encoding=\"utf-8\") as f:\n",
    "    raw_data = json.load(f)\n",
    "\n",
    "embeddings = []\n",
    "id_map = {}\n",
    "\n",
    "for idx, item in enumerate(raw_data):\n",
    "    emb = np.array(item[\"embedding\"], dtype=np.float32)\n",
    "    embeddings.append(emb)\n",
    "    id_map[idx] = {\n",
    "        \"post_casebank\": item[\"post_casebank\"],\n",
    "        \"type\": item[\"type\"]\n",
    "    }\n",
    "\n",
    "# ========== 归一化 + 构建余弦相似度索引 ==========\n",
    "embeddings = np.vstack(embeddings)\n",
    "faiss.normalize_L2(embeddings)  # 👈 归一化后内积就等效于余弦相似度\n",
    "\n",
    "dimension = embeddings.shape[1]\n",
    "index = faiss.IndexFlatIP(dimension)  # 使用内积（归一化后等价于余弦）\n",
    "index.add(embeddings)\n",
    "\n",
    "# ========== 保存 ==========\n",
    "faiss.write_index(index, FAISS_INDEX_PATH)\n",
    "with open(ID_MAP_PATH, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(id_map, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(\"✅ 向量数据库（Cosine 相似度）构建完毕！\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce251e6e-f08f-4f89-bcda-ee33240923c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "import nltk\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "from tqdm import tqdm\n",
    "\n",
    "# 如果第一次运行，需要取消注释：\n",
    "# nltk.download(\"punkt\")\n",
    "# nltk.download(\"stopwords\")\n",
    "# nltk.download(\"wordnet\")\n",
    "\n",
    "stop_words = set(stopwords.words(\"english\"))\n",
    "lemmatizer = WordNetLemmatizer()\n",
    "\n",
    "# ✅ 嵌入模型清洗：标准 NLP 清洗\n",
    "def clean_for_embedding(text):\n",
    "    text = text.lower()\n",
    "    text = re.sub(r'(https?://\\S+|www\\.\\S+)', ' ', text)\n",
    "    text = text.replace('|', ' ')\n",
    "    text = re.sub(r'[^a-z\\s]', ' ', text)\n",
    "    tokens = nltk.word_tokenize(text)\n",
    "    cleaned = [\n",
    "        lemmatizer.lemmatize(token)\n",
    "        for token in tokens\n",
    "        if token not in stop_words and len(token) > 2\n",
    "    ]\n",
    "    return ' '.join(cleaned)\n",
    "\n",
    "# ✅ 大模型清洗：只去 URL 和 |\n",
    "def clean_for_llm(text):\n",
    "    text = re.sub(r'(https?://\\S+|www\\.\\S+)', ' ', text)\n",
    "    text = text.replace('|', ' ')\n",
    "    return text.strip()\n",
    "\n",
    "# 主处理函数\n",
    "def preprocess_dual_clean(input_file, output_file):\n",
    "    with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        raw_data = json.load(f)\n",
    "\n",
    "    cleaned_data = []\n",
    "    for item in tqdm(raw_data, desc=\"Cleaning Dual Versions\"):\n",
    "        post = item[\"posts\"]\n",
    "        cleaned_data.append({\n",
    "            \"type\": item[\"type\"],\n",
    "            \"posts\": post,\n",
    "            \"posts_cleaned_for_embedding\": clean_for_embedding(post),\n",
    "            \"posts_cleaned_for_llm\": clean_for_llm(post)\n",
    "        })\n",
    "\n",
    "    with open(output_file, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(cleaned_data, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "    print(f\"✅ 双重清洗完成，已保存至 {output_file}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "265a3b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocess_dual_clean(\"filtered_processed_comments_cleaned_random50.json\", \"pandora_cleaned.json\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f42a84-f3b0-46ab-af13-c954bf8dcf54",
   "metadata": {},
   "outputs": [],
   "source": [
    "#验证数据做embedding\n",
    "import json\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from tqdm import tqdm\n",
    "\n",
    "MODEL_NAME = \"microsoft/deberta-v3-base\"\n",
    "MODEL_PATH = \"best_fem_deberta.pt\"\n",
    "USE_CLS = True\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = AutoModel.from_pretrained(MODEL_NAME)\n",
    "state_dict = torch.load(MODEL_PATH, map_location=torch.device(\"cpu\"))\n",
    "model.load_state_dict(state_dict, strict=False)\n",
    "model.eval()\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "def get_embedding(text):\n",
    "    inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=True, max_length=512)\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "        return outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()\n",
    "\n",
    "# === 读取清洗后的验证集 ===\n",
    "with open(\"pandora_cleaned.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "\n",
    "results = []\n",
    "for item in tqdm(data, desc=\"Embedding\"):\n",
    "    emb = get_embedding(item[\"posts_cleaned_for_embedding\"])\n",
    "    results.append({\n",
    "        \"type\": item[\"type\"],\n",
    "        \"embedding\": emb.tolist(),\n",
    "        \"posts_cleaned_for_llm\": item[\"posts_cleaned_for_llm\"],\n",
    "        \"posts_cleaned_for_embedding\": item[\"posts_cleaned_for_embedding\"]\n",
    "    })\n",
    "\n",
    "with open(\"pandora_with_embeddings.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(results, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(\"✅ 向量化完成！\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e19f16-40a5-49bb-a7bd-02e8a5f2e7bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "#对验证数据进行向量相似度检索（Top-K）并准备给 LLM 使用的对话上下文结构\n",
    "\n",
    "import json\n",
    "import faiss\n",
    "import numpy as np\n",
    "\n",
    "# Load the FAISS index and its id -> casebank-record mapping\n",
    "index = faiss.read_index(\"mbti_faiss.index\")\n",
    "with open(\"mbti_idmap.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    id_map = json.load(f)\n",
    "\n",
    "def retrieve_topk(query_embedding, k=5):\n",
    "    \"\"\"Return the casebank records (values of id_map) nearest to one query vector.\n",
    "\n",
    "    FAISS pads the result with id -1 when the index holds fewer than `k` vectors;\n",
    "    those slots are skipped (instead of raising KeyError on id_map[\"-1\"]), so fewer\n",
    "    than `k` records may be returned.\n",
    "    \"\"\"\n",
    "    query = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)\n",
    "    _, ids = index.search(query, k)\n",
    "    return [id_map[str(i)] for i in ids[0] if i >= 0]\n",
    "\n",
    "# Load the validation set that already carries embeddings\n",
    "with open(\"pandora_with_embeddings.json\", \"r\", encoding=\"utf-8\") as fin:\n",
    "    verification_data = json.load(fin)\n",
    "\n",
    "# For each query keep its two text variants plus its top-3 retrieved cases for the LLM prompt.\n",
    "retrieved_results = [\n",
    "    {\n",
    "        \"type\": rec[\"type\"],\n",
    "        \"posts_cleaned_for_llm\": rec[\"posts_cleaned_for_llm\"],\n",
    "        \"posts_cleaned_for_embedding\": rec[\"posts_cleaned_for_embedding\"],\n",
    "        \"topk_cases\": retrieve_topk(rec[\"embedding\"], k=3),\n",
    "    }\n",
    "    for rec in verification_data\n",
    "]\n",
    "\n",
    "with open(\"retrieved_results_for_llm_pandora.json\", \"w\", encoding=\"utf-8\") as fout:\n",
    "    json.dump(retrieved_results, fout, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(\"✅ 检索完成！结果保存在 retrieved_results_for_llm_pandora.json\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dec4344f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "path = \"retrieved_results_for_llm_pandora.json\"\n",
    "\n",
    "# Loads the whole file into memory (careful with very large files).\n",
    "with open(path, \"r\", encoding=\"utf-8\") as fin:\n",
    "    data = json.load(fin)\n",
    "\n",
    "# Preview the first 5 samples: label, start of the query text, and neighbour labels\n",
    "for n, item in enumerate(data[:5], start=1):\n",
    "    neighbour_types = [case[\"type\"] for case in item[\"topk_cases\"]]\n",
    "    print(f\"\\n=== 第 {n} 条样本 ===\")\n",
    "    print(\"Type:\", item[\"type\"])\n",
    "    print(\"Query Post:\", item[\"posts_cleaned_for_llm\"][:300], \"...\")\n",
    "    print(\"Top-K types:\", neighbour_types)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fddac7c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict, Counter\n",
    "import json, re, requests\n",
    "from tqdm import tqdm\n",
    "\n",
    "# === Extract final MBTI type from LLM output ===\n",
    "def extract_final_type(text):\n",
    "    \"\"\"Return the upper-cased MBTI type that follows \"Final Type\"/\"最终类型\", else \"UNKNOWN\".\n",
    "\n",
    "    The label is matched case-insensitively and may be wrapped in markdown emphasis\n",
    "    (e.g. \"**Final Type:** INTJ\"), which models often emit. Empty/None text -> \"UNKNOWN\".\n",
    "    \"\"\"\n",
    "    if not text:\n",
    "        return \"UNKNOWN\"\n",
    "    match = re.search(r'(?:Final\\s*Type|最终类型)[\\s:：*_]*([IE][NS][FT][JP])\\b', text, re.IGNORECASE)\n",
    "    if match:\n",
    "        return match.group(1).upper()\n",
    "    return \"UNKNOWN\"\n",
    "\n",
    "# === Build the prompt ===\n",
    "def build_prompt(user_post, topk_cases):\n",
    "    \"\"\"Render the few-shot prompt: the user post, numbered reference cases, then the answer scaffold.\"\"\"\n",
    "    example_blocks = []\n",
    "    for idx, case in enumerate(topk_cases, start=1):\n",
    "        example_blocks.append(\n",
    "            f\"[Reference Example {idx}]\\nPost Content: {case['post_casebank']}\\nMBTI Type: {case['type']}\"\n",
    "        )\n",
    "    examples_text = \"\\n\\n\".join(example_blocks)\n",
    "    return f\"\"\"\n",
    "You are an expert in MBTI personality typing and linguistic style analysis.\n",
    "\n",
    "[User Post]\n",
    "{user_post}\n",
    "\n",
    "[Reference Examples]\n",
    "{examples_text}\n",
    "\n",
    "---\n",
    "1. Final Type: ______\n",
    "2. Analyze the writing style, tone, logicality, and emotionality.\n",
    "3. Compare it with each reference example and explain similarities.\n",
    "4. Conclude with the most likely MBTI type.\n",
    "\"\"\"\n",
    "\n",
    "# === LLM API wrapper ===\n",
    "class LLM_API_Wrapper:\n",
    "    \"\"\"Minimal client for an OpenAI-style chat-completions endpoint.\n",
    "\n",
    "    Args:\n",
    "        model: model name sent in the payload.\n",
    "        api_key: bearer token; read it from an environment variable rather than\n",
    "            hard-coding it in the notebook.\n",
    "        url: endpoint URL. The default \"\" is a placeholder: every request then\n",
    "            fails and call_api returns [].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, api_key, url=\"\"):\n",
    "        self.model = model\n",
    "        self.api_key = api_key\n",
    "        self.url = url\n",
    "\n",
    "    def call_api(self, prompt_text: str, n: int = 1):\n",
    "        \"\"\"Request `n` sampled completions for `prompt_text` and return their texts.\n",
    "\n",
    "        Returns [] on any request/HTTP/JSON failure. Choices without a string\n",
    "        message content are dropped, so callers can regex every returned item.\n",
    "        \"\"\"\n",
    "        payload = {\n",
    "            \"model\": self.model,\n",
    "            \"temperature\": 0.8,\n",
    "            \"n\": n,\n",
    "            \"messages\": [{\"role\": \"user\", \"content\": prompt_text}],\n",
    "            \"max_completion_tokens\": 512,\n",
    "            \"stream\": False\n",
    "        }\n",
    "        headers = {\"Content-Type\": \"application/json\",\"Authorization\": f\"Bearer {self.api_key}\"}\n",
    "        try:\n",
    "            r = requests.post(self.url, headers=headers, json=payload, timeout=60)\n",
    "            r.raise_for_status()\n",
    "            result = r.json()\n",
    "        except Exception as e:\n",
    "            print(\"❌ Request failed:\", e)\n",
    "            return []\n",
    "        choices = result.get(\"choices\", []) if isinstance(result, dict) else []\n",
    "        texts = []\n",
    "        for choice in choices or []:\n",
    "            message = choice.get(\"message\") if isinstance(choice, dict) else None\n",
    "            content = message.get(\"content\") if isinstance(message, dict) else None\n",
    "            # content may be None in some responses; re.search(None) would raise later.\n",
    "            if isinstance(content, str):\n",
    "                texts.append(content)\n",
    "        return texts\n",
    "\n",
    "# === Call until we collect enough valid responses ===\n",
    "def robust_llm_call(llm, prompt, vote_times=40, n_valid=20, max_total_requests=20):\n",
    "    \"\"\"Collect up to `n_valid` responses that contain a parseable final type.\n",
    "\n",
    "    Each round requests `vote_times` completions; at most `max_total_requests` rounds\n",
    "    are made. Returns whatever was collected (possibly fewer than `n_valid`).\n",
    "    \"\"\"\n",
    "    collected, attempts = [], 0\n",
    "    while attempts < max_total_requests and len(collected) < n_valid:\n",
    "        attempts += 1\n",
    "        batch = llm.call_api(prompt, n=vote_times)\n",
    "        collected += [resp for resp in batch if extract_final_type(resp) != \"UNKNOWN\"]\n",
    "        print(f\"🔄 Collected {len(collected)} valid responses...\")\n",
    "    return collected[:n_valid]\n",
    "\n",
    "# === Main function: CoT-SC majority voting ===\n",
    "def evaluate_with_voting(input_file, output_file, llm: LLM_API_Wrapper, vote_times=40, n_valid=20):\n",
    "    \"\"\"Run CoT self-consistency voting over every retrieved sample and save the results.\n",
    "\n",
    "    Each input record needs \"query_text\", \"topk_cases\" and optionally \"type\" (ground truth).\n",
    "    The most common parsed type among the answers from robust_llm_call is the prediction;\n",
    "    records for which no answer was collected get final_prediction \"ERROR\".\n",
    "\n",
    "    Writes a JSON list to `output_file` and prints overall and per-type accuracy over\n",
    "    records whose ground truth is known.\n",
    "    \"\"\"\n",
    "    with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        all_inputs = json.load(f)\n",
    "\n",
    "    outputs, correct, total = [], 0, 0\n",
    "    type_correct, type_total = defaultdict(int), defaultdict(int)\n",
    "\n",
    "    for item in tqdm(all_inputs, desc=\"🔁 Running CoT-SC voting\"):\n",
    "        prompt = build_prompt(item[\"query_text\"], item[\"topk_cases\"])\n",
    "        responses = robust_llm_call(llm, prompt, vote_times, n_valid)\n",
    "        # `or` also covers \"type\": null, which would crash .upper() mid-run.\n",
    "        ground_truth = (item.get(\"type\") or \"UNKNOWN\").upper()\n",
    "        if not responses:\n",
    "            # Keep ground_truth so scorers such as evaluate_mbti_accuracy can read every record.\n",
    "            outputs.append({\n",
    "                \"query_text\": item[\"query_text\"],\n",
    "                \"ground_truth\": ground_truth,\n",
    "                \"final_prediction\": \"ERROR\",\n",
    "                \"correct\": False,\n",
    "                \"llm_responses\": []\n",
    "            })\n",
    "            continue\n",
    "\n",
    "        predicted_types = [extract_final_type(r) for r in responses]\n",
    "        majority_vote = Counter(predicted_types).most_common(1)[0][0]\n",
    "        is_correct = majority_vote == ground_truth\n",
    "        outputs.append({\n",
    "            \"query_text\": item[\"query_text\"],\n",
    "            \"ground_truth\": ground_truth,\n",
    "            \"final_prediction\": majority_vote,\n",
    "            \"correct\": is_correct,\n",
    "            \"llm_responses\": responses\n",
    "        })\n",
    "        if ground_truth != \"UNKNOWN\":\n",
    "            total += 1\n",
    "            type_total[ground_truth] += 1\n",
    "            if is_correct:\n",
    "                correct += 1\n",
    "                type_correct[ground_truth] += 1\n",
    "\n",
    "    # `with` guarantees the results are flushed and the file is closed.\n",
    "    with open(output_file, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(outputs, f, ensure_ascii=False, indent=2)\n",
    "    print(f\"✅ Voting finished. Results saved to {output_file}\")\n",
    "    if total > 0:\n",
    "        print(f\"🎯 Accuracy: {correct}/{total} = {correct/total:.4f}\")\n",
    "        for mbti_type in sorted(type_total):\n",
    "            c, t = type_correct[mbti_type], type_total[mbti_type]\n",
    "            print(f\"{mbti_type}: {c}/{t} = {c/t:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b69a786",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import json\n",
    "\n",
    "with open(\"test对应的原始数据.json\", \"r\", encoding=\"utf-8\") as fin:\n",
    "    data = json.load(fin)\n",
    "\n",
    "# Label distribution of the raw test data (records without a label count as UNKNOWN)\n",
    "labels = (record.get(\"type\", \"UNKNOWN\").upper() for record in data)\n",
    "cnt = Counter(labels)\n",
    "print(cnt)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e56b5493",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from collections import defaultdict\n",
    "\n",
    "INPUT_FILE  = \"QwenBase_test_topk3.json\"     # input file\n",
    "OUTPUT_FILE = \"QwenBase_test_topk3_clip32.json\"  # output file\n",
    "MAX_PER_TYPE = 32  # keep at most this many samples per MBTI type (first seen wins)\n",
    "\n",
    "with open(INPUT_FILE, \"r\", encoding=\"utf-8\") as fin:\n",
    "    data = json.load(fin)\n",
    "\n",
    "# Group by upper-cased label, dropping anything beyond the per-type cap\n",
    "buckets = defaultdict(list)\n",
    "for item in data:\n",
    "    bucket = buckets[item.get(\"type\", \"UNKNOWN\").upper()]\n",
    "    if len(bucket) < MAX_PER_TYPE:\n",
    "        bucket.append(item)\n",
    "\n",
    "# Flatten the buckets back into one list (types in order of first appearance)\n",
    "clipped = [kept for group in buckets.values() for kept in group]\n",
    "for label, group in buckets.items():\n",
    "    print(f\"{label}: {len(group)} kept\")\n",
    "\n",
    "with open(OUTPUT_FILE, \"w\", encoding=\"utf-8\") as fout:\n",
    "    json.dump(clipped, fout, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"✅ Done! Saved to {OUTPUT_FILE}, total {len(clipped)} items\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9980aae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# NOTE(review): LLM_API_Wrapper / evaluate_with_voting resolve to whichever definition\n",
    "# cell ran last; the first version reads \"query_text\" from each record, the second\n",
    "# reads \"query_post\". Make sure this input file matches the version in memory.\n",
    "llm = LLM_API_Wrapper(\n",
    "    model=\"gpt-4o-mini\",        # or your deployed model name\n",
    "    api_key=os.environ.get(\"OPENAI_API_KEY\", \"\"),  # never hard-code the key in the notebook\n",
    ")\n",
    "\n",
    "evaluate_with_voting(\n",
    "    input_file=\"B_test_20_topk_clip50.json\",         # your TopK retrieved file\n",
    "    output_file=\"B_test_20_clip50_vote_results.json\",# output file\n",
    "    llm=llm,\n",
    "    vote_times=5,   # how many responses per request\n",
    "    n_valid=5       # how many valid responses needed\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd42afed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from collections import defaultdict\n",
    "\n",
    "def evaluate_mbti_accuracy(result_file):\n",
    "    \"\"\"Print per-dimension, exact-match and per-type accuracy for a voting-results file.\n",
    "\n",
    "    Args:\n",
    "        result_file: path to a JSON list of records carrying \"ground_truth\" and\n",
    "            \"final_prediction\". Records where either value is not a 4-character string\n",
    "            (e.g. \"ERROR\", \"UNKNOWN\", or a missing key) are skipped.\n",
    "\n",
    "    Returns:\n",
    "        dict with \"overall\", \"dimensions\" ({\"IE\": acc, ...}) and \"per_type\"\n",
    "        ({\"INTJ\": acc, ...}); an accuracy is 0 when its denominator is 0.\n",
    "    \"\"\"\n",
    "    with open(result_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    # Per-dimension counters (letter positions 0..3)\n",
    "    dim_names = [\"IE\", \"NS\", \"TF\", \"JP\"]\n",
    "    dim_correct = [0, 0, 0, 0]\n",
    "    dim_total = [0, 0, 0, 0]\n",
    "\n",
    "    # Exact four-letter matches\n",
    "    total, correct = 0, 0\n",
    "\n",
    "    # Per-type counters, keyed by ground-truth type\n",
    "    type_total = defaultdict(int)\n",
    "    type_correct = defaultdict(int)\n",
    "\n",
    "    for item in data:\n",
    "        # .get(): ERROR records from some evaluate_with_voting versions have no\n",
    "        # \"ground_truth\" key; indexing would raise KeyError and abort the report.\n",
    "        gt = str(item.get(\"ground_truth\") or \"\").upper()\n",
    "        pred = str(item.get(\"final_prediction\") or \"\").upper()\n",
    "        if len(gt) != 4 or len(pred) != 4:\n",
    "            continue\n",
    "\n",
    "        total += 1\n",
    "        type_total[gt] += 1\n",
    "        if gt == pred:\n",
    "            correct += 1\n",
    "            type_correct[gt] += 1\n",
    "\n",
    "        # Compare letter by letter for each dimension\n",
    "        for i in range(4):\n",
    "            dim_total[i] += 1\n",
    "            if gt[i] == pred[i]:\n",
    "                dim_correct[i] += 1\n",
    "\n",
    "    def _ratio(num, den):\n",
    "        return num / den if den > 0 else 0\n",
    "\n",
    "    print(\"=== 四个维度准确率 ===\")\n",
    "    dimensions = {}\n",
    "    for name, c, t in zip(dim_names, dim_correct, dim_total):\n",
    "        dimensions[name] = _ratio(c, t)\n",
    "        print(f\"{name}: {c}/{t} = {dimensions[name]:.4f}\")\n",
    "\n",
    "    # Guarded: an empty or all-ERROR file used to raise ZeroDivisionError here.\n",
    "    overall = _ratio(correct, total)\n",
    "    print(\"\\n=== 四字母完全匹配准确率 ===\")\n",
    "    print(f\"Overall: {correct}/{total} = {overall:.4f}\")\n",
    "\n",
    "    print(\"\\n=== 每个类型的准确率 ===\")\n",
    "    all_types = [\n",
    "        \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\n",
    "        \"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "        \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\n",
    "        \"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "    ]\n",
    "    per_type = {}\n",
    "    for t in all_types:\n",
    "        c, tt = type_correct[t], type_total[t]\n",
    "        per_type[t] = _ratio(c, tt)\n",
    "        print(f\"{t}: {c}/{tt} = {per_type[t]:.4f}\")\n",
    "\n",
    "    return {\"overall\": overall, \"dimensions\": dimensions, \"per_type\": per_type}\n",
    "\n",
    "\n",
    "# Score the voting results file produced by the gpt-4o-mini evaluation run\n",
    "evaluate_mbti_accuracy(result_file=\"B_test_20_clip50_vote_results.json\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2481f4f-11ff-4ba4-8cf7-6448f668d17f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict, Counter\n",
    "import json\n",
    "import re\n",
    "from tqdm import tqdm\n",
    "import requests\n",
    "\n",
    "# === Extract the final MBTI type ===\n",
    "def extract_final_type(text):\n",
    "    \"\"\"Return the upper-cased MBTI type that follows \"Final Type\"/\"最终类型\", else \"UNKNOWN\".\n",
    "\n",
    "    The label is matched case-insensitively and may be wrapped in markdown emphasis\n",
    "    (e.g. \"**Final Type:** INTJ\"), which models often emit. Empty/None text -> \"UNKNOWN\".\n",
    "    \"\"\"\n",
    "    if not text:\n",
    "        return \"UNKNOWN\"\n",
    "    # Accept both the English (\"Final Type\") and Chinese (\"最终类型\") label.\n",
    "    match = re.search(r'(?:Final\\s*Type|最终类型)[\\s:：*_]*([IE][NS][FT][JP])\\b', text, re.IGNORECASE)\n",
    "    if match:\n",
    "        return match.group(1).upper()\n",
    "    return \"UNKNOWN\"\n",
    "\n",
    "# === Build the prompt ===\n",
    "def build_prompt(user_post, topk_cases):\n",
    "    \"\"\"Return the style-comparison prompt for `user_post` with its retrieved reference cases.\n",
    "\n",
    "    Each case must provide 'post_casebank' and 'type'; cases are numbered from 1.\n",
    "    \"\"\"\n",
    "    examples_text = \"\\n\\n\".join(\n",
    "        f\"[Reference Example {num}]\\nPost Content: {case['post_casebank']}\\nMBTI Type: {case['type']}\"\n",
    "        for num, case in enumerate(topk_cases, 1)\n",
    "    )\n",
    "\n",
    "    return f\"\"\"\n",
    "You are an expert in MBTI personality typing and linguistic style analysis.\n",
    "\n",
    "Your task is to infer the most likely MBTI type of a user's post **by analyzing their writing style** — such as tone, phrasing, language complexity, emotionality, and personality cues — and comparing it to several known examples labeled with MBTI types.\n",
    "\n",
    "**You must treat the reference examples as correct and trustworthy.**\n",
    "Do not ignore them — they provide strong clues about how different types express themselves.\n",
    "\n",
    "---\n",
    "\n",
    "[User Post]\n",
    "{user_post}\n",
    "\n",
    "[Reference Examples]\n",
    "{examples_text}\n",
    "\n",
    "---\n",
    "\n",
    "Now follow these steps:\n",
    "1.Final Type: ______\n",
    "\n",
    "2. **Analyze** the user's post: What is its tone? Abstract or concrete? Emotional or logical? Structured or scattered?\n",
    "\n",
    "3. **Compare** it with each reference example: Which ones are stylistically most similar to the user? Provide reasons.\n",
    "\n",
    "4. **Conclude** with the MBTI type that best matches the user post, based on the closest stylistic match.\n",
    "\"\"\"\n",
    "\n",
    "# === LLM API wrapper ===\n",
    "class LLM_API_Wrapper:\n",
    "    \"\"\"Thin client for an OpenAI-compatible chat-completions endpoint.\n",
    "\n",
    "    Args:\n",
    "        model: model name put in the request payload.\n",
    "        api_key: bearer token (read it from an environment variable rather than\n",
    "            hard-coding it).\n",
    "        url: endpoint URL. The default \"\" is a placeholder: requests then fail and\n",
    "            call_api returns [].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, api_key, url=\"\"):\n",
    "        self.model = model\n",
    "        self.api_key = api_key\n",
    "        self.url = url\n",
    "\n",
    "    def call_api(self, prompt_text: str, n: int = 1):\n",
    "        \"\"\"Ask for `n` completions of `prompt_text` and return their message texts.\n",
    "\n",
    "        Returns [] when the request fails, the body is not JSON, or it has no\n",
    "        \"choices\". Malformed choices and choices without string content are dropped\n",
    "        instead of raising, so one bad response cannot abort a whole evaluation run.\n",
    "        \"\"\"\n",
    "        payload = {\n",
    "            \"model\": self.model,\n",
    "            \"temperature\": 0.8,\n",
    "            \"n\": n,\n",
    "            \"messages\": [{\"role\": \"user\", \"content\": prompt_text}],\n",
    "            \"modalities\": [\"text\"],\n",
    "            \"response_format\": {\"type\": \"text\"},\n",
    "            \"max_completion_tokens\": 512,\n",
    "            \"stream\": False\n",
    "        }\n",
    "        headers = {\n",
    "            \"Content-Type\": \"application/json\",\n",
    "            \"Authorization\": f\"Bearer {self.api_key}\"\n",
    "        }\n",
    "        try:\n",
    "            response = requests.post(self.url, headers=headers, json=payload, timeout=60)\n",
    "            response.raise_for_status()\n",
    "            result = response.json()\n",
    "        # ValueError: non-JSON body on requests versions whose JSONDecodeError is not a RequestException.\n",
    "        except (requests.exceptions.RequestException, ValueError) as e:\n",
    "            print(f\"❌ Request failed: {e}\")\n",
    "            return []\n",
    "        if not isinstance(result, dict) or \"choices\" not in result:\n",
    "            print(\"⚠️ Missing 'choices' in response.\")\n",
    "            return []\n",
    "        texts = []\n",
    "        for choice in result[\"choices\"] or []:\n",
    "            message = choice.get(\"message\") if isinstance(choice, dict) else None\n",
    "            content = message.get(\"content\") if isinstance(message, dict) else None\n",
    "            if isinstance(content, str):\n",
    "                texts.append(content)\n",
    "        return texts\n",
    "\n",
    "# === Keep requesting until n_valid parseable answers are collected ===\n",
    "def robust_llm_call(llm, prompt, vote_times=40, n_valid=20, max_total_requests=20):\n",
    "    \"\"\"Return exactly `n_valid` answers containing a parseable final type, or [].\n",
    "\n",
    "    Each attempt requests `vote_times` completions. If `max_total_requests` attempts\n",
    "    are exhausted first, the sample is abandoned and [] is returned (partial results\n",
    "    are discarded on purpose).\n",
    "    \"\"\"\n",
    "    valid_so_far = []\n",
    "    attempt = 0\n",
    "    while len(valid_so_far) < n_valid:\n",
    "        attempt += 1\n",
    "        if attempt > max_total_requests:\n",
    "            print(\"⚠️ 请求次数过多，跳过该样本\")\n",
    "            return []\n",
    "        batch = llm.call_api(prompt, n=vote_times)\n",
    "        if batch:\n",
    "            valid_so_far += [ans for ans in batch if extract_final_type(ans) != \"UNKNOWN\"]\n",
    "            print(f\"🔄 已累计 {len(valid_so_far)} 条有效回答...\")\n",
    "    return valid_so_far[:n_valid]\n",
    "\n",
    "# === Main: CoT-SC majority-vote prediction over a retrieval file ===\n",
    "def evaluate_with_voting(input_file, output_file, llm: LLM_API_Wrapper, vote_times=40, n_valid=20):\n",
    "    \"\"\"Predict each record's MBTI type by majority vote over sampled LLM answers.\n",
    "\n",
    "    Args:\n",
    "        input_file: JSON list of records with \"query_post\", \"topk_cases\" and\n",
    "            optionally \"type\" (ground truth).\n",
    "        output_file: where per-record results (responses, votes, prompt) are written.\n",
    "        llm: client passed to robust_llm_call.\n",
    "        vote_times: completions requested per API call.\n",
    "        n_valid: parseable answers required per record; if robust_llm_call gives up,\n",
    "            the record gets final_prediction \"ERROR\".\n",
    "\n",
    "    Prints overall and per-type accuracy over records whose ground truth is known.\n",
    "    \"\"\"\n",
    "    with open(input_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        all_inputs = json.load(f)\n",
    "\n",
    "    outputs = []\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    type_correct = defaultdict(int)\n",
    "    type_total = defaultdict(int)\n",
    "\n",
    "    for item in tqdm(all_inputs, desc=\"🔁 正在进行 CoT-SC 预测\"):\n",
    "        prompt = build_prompt(item[\"query_post\"], item[\"topk_cases\"])\n",
    "        responses = robust_llm_call(llm, prompt, vote_times, n_valid)\n",
    "        # Normalise once so success and ERROR records store the label in the same form;\n",
    "        # `or` also covers \"type\": null, which would otherwise crash .upper() mid-run\n",
    "        # (results are only written at the end, so all progress would be lost).\n",
    "        ground_truth = (item.get(\"type\") or \"UNKNOWN\").upper()\n",
    "\n",
    "        if responses:\n",
    "            predicted_types = [extract_final_type(r) for r in responses]\n",
    "            majority_vote = Counter(predicted_types).most_common(1)[0][0]\n",
    "            is_correct = majority_vote == ground_truth\n",
    "\n",
    "            outputs.append({\n",
    "                \"query_post\": item[\"query_post\"],\n",
    "                \"ground_truth\": ground_truth,\n",
    "                \"llm_responses\": responses,\n",
    "                \"predicted_types\": predicted_types,\n",
    "                \"final_prediction\": majority_vote,\n",
    "                \"correct\": is_correct,\n",
    "                \"prompt\": prompt\n",
    "            })\n",
    "\n",
    "            if ground_truth != \"UNKNOWN\":\n",
    "                total += 1\n",
    "                type_total[ground_truth] += 1\n",
    "                if is_correct:\n",
    "                    correct += 1\n",
    "                    type_correct[ground_truth] += 1\n",
    "        else:\n",
    "            outputs.append({\n",
    "                \"query_post\": item[\"query_post\"],\n",
    "                \"ground_truth\": ground_truth,\n",
    "                \"llm_responses\": [],\n",
    "                \"predicted_types\": [],\n",
    "                \"final_prediction\": \"ERROR\",\n",
    "                \"correct\": False,\n",
    "                \"prompt\": prompt\n",
    "            })\n",
    "\n",
    "    with open(output_file, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(outputs, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "    print(f\"\\n✅ 投票预测完成，结果保存至 {output_file}\")\n",
    "    if total > 0:\n",
    "        acc = correct / total\n",
    "        print(f\"\\n🎯 总体准确率：{acc:.4f}（{correct} / {total}）\")\n",
    "        print(\"\\n📊 各类型准确率：\")\n",
    "        for mbti_type in sorted(type_total.keys()):\n",
    "            total_count = type_total[mbti_type]\n",
    "            correct_count = type_correct[mbti_type]\n",
    "            acc_type = correct_count / total_count\n",
    "            print(f\"{mbti_type}: {acc_type:.2f} ({correct_count} / {total_count})\")\n",
    "    else:\n",
    "        print(\"⚠️ 没有包含 ground_truth 的数据，无法计算准确率。\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a79ade41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from collections import Counter\n",
    "\n",
    "with open(\"retrieved_results_for_llm2.json\", \"r\", encoding=\"utf-8\") as fin:\n",
    "    data = json.load(fin)\n",
    "\n",
    "# Upper-cased labels of the records that carry a \"type\" field\n",
    "types = []\n",
    "for record in data:\n",
    "    if \"type\" in record:\n",
    "        types.append(record[\"type\"].upper())\n",
    "\n",
    "type_counts = Counter(types)\n",
    "\n",
    "# Sample count per MBTI type, in alphabetical order\n",
    "for mbti_type, count in sorted(type_counts.items()):\n",
    "    print(f\"{mbti_type}: {count} samples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a7d9dc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# The API key comes from the environment so it is never saved in the notebook.\n",
    "llm = LLM_API_Wrapper(model=\"gpt-4o-mini\", api_key=os.environ.get(\"OPENAI_API_KEY\", \"\"))\n",
    "\n",
    "evaluate_with_voting(\n",
    "    input_file=\"retrieved_results_for_llm2.json\",\n",
    "    output_file=\"voting_eval_results1.json\",\n",
    "    llm=llm,\n",
    "    vote_times=5,\n",
    "    n_valid=5\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dadbcc4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# The API key comes from the environment so it is never saved in the notebook.\n",
    "llm = LLM_API_Wrapper(model=\"gpt-4o\", api_key=os.environ.get(\"OPENAI_API_KEY\", \"\"))\n",
    "\n",
    "evaluate_with_voting(\n",
    "    input_file=\"retrieved_results_for_llm_cosine.json\",\n",
    "    output_file=\"voting_eval_results_cosine.json\",\n",
    "    llm=llm,\n",
    "    vote_times=10,\n",
    "    n_valid=5\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1392ac7-6cf7-4305-8366-fdeaff734431",
   "metadata": {},
   "outputs": [],
   "source": [
    "######  Retain\n",
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "import faiss\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "# ===== 1. Load the encoder =====\n",
    "MODEL_NAME = \"microsoft/deberta-v3-base\"\n",
    "MODEL_PATH = \"best_fem_deberta.pt\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = AutoModel.from_pretrained(MODEL_NAME)\n",
    "state_dict = torch.load(MODEL_PATH, map_location=torch.device(\"cpu\"))\n",
    "# strict=False also silently accepts a checkpoint whose keys match nothing (leaving the\n",
    "# base weights in place); report mismatches so a mis-saved checkpoint is noticed.\n",
    "load_report = model.load_state_dict(state_dict, strict=False)\n",
    "if load_report.missing_keys or load_report.unexpected_keys:\n",
    "    print(f\"⚠️ load_state_dict: {len(load_report.missing_keys)} missing / \"\n",
    "          f\"{len(load_report.unexpected_keys)} unexpected keys from {MODEL_PATH}\")\n",
    "model.eval()\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# ===== 2. Embedding function =====\n",
    "def get_embedding(text):\n",
    "    \"\"\"Return the first-token ([CLS]) last-hidden-state of `text` (<=512 tokens) as a 1-D numpy array.\"\"\"\n",
    "    inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=True, max_length=512)\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "        return outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()\n",
    "\n",
    "# ===== 3. Load the LLM results (each record already carries its predicted type) =====\n",
    "with open(\"llm_outputs_result.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    results = json.load(f)\n",
    "\n",
    "# ===== 4. Build Retain records and embed them =====\n",
    "# NOTE(review): each new case is embedded from its retrieved neighbours' casebank text\n",
    "# (joined with \"|||\"), not from the query post itself, and labelled with the predicted\n",
    "# type -- the same policy as run_single_query_pipeline. Confirm this is intended.\n",
    "VALID_MBTI_TYPES = {a + b + c + d for a in \"IE\" for b in \"NS\" for c in \"TF\" for d in \"JP\"}\n",
    "\n",
    "new_vectors = []\n",
    "new_id_map = {}\n",
    "new_data = []\n",
    "skipped = 0\n",
    "\n",
    "for item in results:\n",
    "    final_pred = str(item.get(\"final_prediction\", \"UNKNOWN\")).upper()\n",
    "    topk_cases = item.get(\"topk_cases\", [])\n",
    "    post_casebank_text = \"|||\".join([case[\"post_casebank\"] for case in topk_cases])\n",
    "\n",
    "    # \"ERROR\"/\"UNKNOWN\" labels or empty text (records without topk_cases) would add\n",
    "    # unusable cases to the retrieval pool, so they are not retained.\n",
    "    if final_pred not in VALID_MBTI_TYPES or not post_casebank_text.strip():\n",
    "        skipped += 1\n",
    "        continue\n",
    "\n",
    "    record = {\n",
    "        \"type\": final_pred,\n",
    "        \"posts\": item[\"query_post\"],\n",
    "        \"posts_cleaned\": item.get(\"posts_cleaned\", \"\"),\n",
    "        \"post_casebank\": post_casebank_text\n",
    "    }\n",
    "\n",
    "    emb = get_embedding(post_casebank_text)\n",
    "    new_id_map[str(len(new_data))] = record\n",
    "    new_vectors.append(emb.astype(np.float32))\n",
    "    new_data.append(record)\n",
    "\n",
    "# ===== 5. Load the existing FAISS index + id_map and append the new cases =====\n",
    "# NOTE: not idempotent -- re-running this cell appends the same cases again.\n",
    "index = faiss.read_index(\"mbti_faiss.index\")\n",
    "with open(\"mbti_idmap.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    id_map = json.load(f)\n",
    "\n",
    "# index.add() assigns ids ntotal, ntotal+1, ...; key id_map by those so that\n",
    "# retrieve_topk's id_map[str(i)] lookups stay aligned even if the two files drifted.\n",
    "start_id = index.ntotal\n",
    "if start_id != len(id_map):\n",
    "    print(f\"⚠️ index.ntotal={index.ntotal} but len(id_map)={len(id_map)}; new ids start at {start_id}\")\n",
    "\n",
    "if new_vectors:\n",
    "    index.add(np.vstack(new_vectors))\n",
    "    for offset, record in enumerate(new_data):\n",
    "        id_map[str(start_id + offset)] = record\n",
    "\n",
    "    # ===== 6. Save the updated index and id_map (overwrites the original files) =====\n",
    "    faiss.write_index(index, \"mbti_faiss.index\")\n",
    "    with open(\"mbti_idmap.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(id_map, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"✅ Retain 完成：已添加 {len(new_vectors)} 条样本至向量数据库。\")\n",
    "if skipped:\n",
    "    print(f\"⚠️ Skipped {skipped} records with an invalid prediction or empty casebank text.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b53e554a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Initialise the LLM client; the API key is read from OPENAI_API_KEY, never hard-coded.\n",
    "llm = LLM_API_Wrapper(model=\"gpt-4o\", api_key=os.environ.get(\"OPENAI_API_KEY\", \"\"))\n",
    "\n",
    "# Run the evaluation on the constructed TopK retrieval results\n",
    "evaluate_with_voting(\n",
    "    input_file=\"retrieved_results_for_llm.json\",\n",
    "    output_file=\"voting_eval_results.json\",\n",
    "    llm=llm,\n",
    "    vote_times=5  # try 5 first; use 20 or 40 for the formal evaluation\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19708b57",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# NOTE(review): this cell repeats the previous cell verbatim -- running both doubles the\n",
    "# paid API calls and overwrites voting_eval_results.json. Consider deleting one of them.\n",
    "# Initialise the LLM client; the API key is read from OPENAI_API_KEY, never hard-coded.\n",
    "llm = LLM_API_Wrapper(model=\"gpt-4o\", api_key=os.environ.get(\"OPENAI_API_KEY\", \"\"))\n",
    "\n",
    "# Run the evaluation on the constructed TopK retrieval results\n",
    "evaluate_with_voting(\n",
    "    input_file=\"retrieved_results_for_llm.json\",\n",
    "    output_file=\"voting_eval_results.json\",\n",
    "    llm=llm,\n",
    "    vote_times=5  # try 5 first; use 20 or 40 for the formal evaluation\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f101b72d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All 16 MBTI labels; only these are ever retained into the case base.\n",
    "_MBTI_TYPES = frozenset(a + b + c + d for a in \"IE\" for b in \"NS\" for c in \"TF\" for d in \"JP\")\n",
    "\n",
    "\n",
    "def _majority_type(predicted_types):\n",
    "    \"\"\"Most common parsed type; \\\"UNKNOWN\\\" if none parsed, \\\"ERROR\\\" if there were no responses.\"\"\"\n",
    "    parsed = [t for t in predicted_types if t != \"UNKNOWN\"]\n",
    "    if parsed:\n",
    "        return Counter(parsed).most_common(1)[0][0]\n",
    "    return \"UNKNOWN\" if predicted_types else \"ERROR\"\n",
    "\n",
    "\n",
    "def _retain_case(post_casebank_text, mbti_type, index_path=\"mbti_faiss.index\", idmap_path=\"mbti_idmap.json\"):\n",
    "    \"\"\"Append one case to the on-disk FAISS index and id map (read-modify-write of both files).\"\"\"\n",
    "    emb = get_embedding(post_casebank_text)\n",
    "    index = faiss.read_index(index_path)\n",
    "    with open(idmap_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        id_map = json.load(f)\n",
    "\n",
    "    # index.add() gives the new vector id == ntotal; key the map by it to stay aligned.\n",
    "    new_id = str(index.ntotal)\n",
    "    index.add(np.asarray([emb], dtype=np.float32))\n",
    "    id_map[new_id] = {\n",
    "        \"post_casebank\": post_casebank_text,\n",
    "        \"type\": mbti_type\n",
    "    }\n",
    "\n",
    "    faiss.write_index(index, index_path)\n",
    "    with open(idmap_path, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(id_map, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "\n",
    "def run_single_query_pipeline(raw_text: str, llm: LLM_API_Wrapper, vote_times: int = 5, do_retain=True, k: int = 5):\n",
    "    \"\"\"Clean -> embed -> retrieve top-k -> prompt -> LLM vote -> (optionally) retain, for one post.\n",
    "\n",
    "    Args:\n",
    "        raw_text: the user's raw post text.\n",
    "        llm: client whose call_api(prompt, n=...) returns a list of response strings.\n",
    "        vote_times: number of completions requested in the single API call.\n",
    "        do_retain: when True and a valid MBTI type was voted, append the case to the\n",
    "            on-disk index and id map.\n",
    "        k: number of neighbours to retrieve (default 5, as before).\n",
    "\n",
    "    Returns:\n",
    "        dict with the cleaned text, retrieved cases, raw responses, per-response types,\n",
    "        the majority prediction (\"UNKNOWN\"/\"ERROR\" when nothing usable came back) and\n",
    "        the prompt.\n",
    "    \"\"\"\n",
    "    # 1. Clean\n",
    "    posts_cleaned = clean_text(raw_text)\n",
    "\n",
    "    # 2. Embed\n",
    "    query_emb = get_embedding(posts_cleaned)\n",
    "\n",
    "    # 3. Retrieve. retrieve_topk searches the global in-memory `index`; cases retained\n",
    "    #    below are written to disk only and are not searchable until that index is reloaded.\n",
    "    topk_cases = retrieve_topk(query_emb, k=k)\n",
    "\n",
    "    # 4. Build the prompt\n",
    "    prompt = build_prompt(raw_text, topk_cases)\n",
    "\n",
    "    # 5. LLM inference; unparseable answers do not vote, and an empty response list\n",
    "    #    (API failure) yields \"ERROR\" instead of an IndexError.\n",
    "    responses = llm.call_api(prompt, n=vote_times)\n",
    "    predicted_types = [extract_final_type(r) for r in responses]\n",
    "    majority_vote = _majority_type(predicted_types)\n",
    "\n",
    "    result = {\n",
    "        \"query_post\": raw_text,\n",
    "        \"posts_cleaned\": posts_cleaned,\n",
    "        \"topk_cases\": topk_cases,\n",
    "        \"llm_responses\": responses,\n",
    "        \"predicted_types\": predicted_types,\n",
    "        \"final_prediction\": majority_vote,\n",
    "        \"prompt\": prompt\n",
    "    }\n",
    "\n",
    "    # 6. RETAIN (optional); skipped when no valid type was voted or nothing was retrieved\n",
    "    if do_retain and majority_vote in _MBTI_TYPES and topk_cases:\n",
    "        _retain_case(\"|||\".join([case[\"post_casebank\"] for case in topk_cases]), majority_vote)\n",
    "\n",
    "    return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab082f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_pipeline_from_json_file(json_path: str, llm: LLM_API_Wrapper, vote_times=5, retain=True):\n",
    "    \"\"\"Run run_single_query_pipeline on every record of a JSON file and save the results.\n",
    "\n",
    "    Args:\n",
    "        json_path: JSON list of records; each record's \"posts\" text is the query and its\n",
    "            optional \"type\" is kept as \"ground_truth\". Records whose \"posts\" is empty,\n",
    "            missing or null are skipped.\n",
    "        llm: client forwarded to run_single_query_pipeline.\n",
    "        vote_times: completions requested per query.\n",
    "        retain: forwarded as do_retain.\n",
    "\n",
    "    Returns:\n",
    "        The list of result dicts (previously None), also written to\n",
    "        \"mbti_predictions_from_<basename of json_path>\" in the working directory.\n",
    "    \"\"\"\n",
    "    with open(json_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        data = json.load(f)\n",
    "\n",
    "    all_results = []\n",
    "    print(f\"\\n📄 正在处理 JSON 文档，共 {len(data)} 条记录...\\n\")\n",
    "\n",
    "    for idx, item in enumerate(data):\n",
    "        # `or \\\"\\\"` also covers \\\"posts\\\": null, which previously raised AttributeError on .strip().\n",
    "        text = (item.get(\"posts\") or \"\").strip()\n",
    "        if not text:\n",
    "            continue\n",
    "\n",
    "        print(f\"🔁 第 {idx+1}/{len(data)} 条用户发言处理中...\")\n",
    "\n",
    "        result = run_single_query_pipeline(text, llm, vote_times=vote_times, do_retain=retain)\n",
    "\n",
    "        # Keep the original label (if any) for later scoring\n",
    "        result[\"ground_truth\"] = item.get(\"type\", \"UNKNOWN\")\n",
    "        all_results.append(result)\n",
    "\n",
    "    output_path = f\"mbti_predictions_from_{os.path.basename(json_path)}\"\n",
    "    with open(output_path, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(all_results, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "    print(f\"\\n✅ 全部完成！预测结果已保存至：{output_path}\")\n",
    "    return all_results\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
