{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9fb873c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Encode & retrieve: 100%|██████████| 325/325 [00:18<00:00, 17.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ 向量化完成： A_test_with_embeddings.json\n",
      "✅ TopK 随机 4~6（取前 k 个）完成： A_test_top4to6_final1.json\n"
     ]
    }
   ],
   "source": [
    "# encode_and_retrieve_topk.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, faiss, numpy as np, torch, random\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig\n",
    "from peft import PeftModel\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# ========= Path configuration =========\n",
    "# Prebuilt vector store (built from casebank_*_with_embeddings.json)\n",
    "FAISS_INDEX_PATH = \"casebank_A_qwen_cls_cosine.index\"\n",
    "ID_MAP_PATH      = \"casebank_A_qwen_cls_idmap.json\"\n",
    "\n",
    "# Test set to embed and retrieve for\n",
    "TEST_INPUT          = \"picked_balanced_around30.json\"\n",
    "TEST_WITH_EMB_OUT   = \"A_test_with_embeddings.json\"\n",
    "TOPK_RESULTS_OUT    = \"A_test_top4to6_final1.json\"   # distinct name for the random k=4..6 run\n",
    "\n",
    "# Qwen + LoRA (must be identical to the setup used to build the index)\n",
    "BASE_MODEL  = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "ADAPTER_DIR = \"mbti_lora_qwen1.5b-split_kaggle_ckpt\"\n",
    "\n",
    "MAX_LEN = 512\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Optional seed for the random choice of k in the retrieval loop.\n",
    "# None keeps the previous unseeded behaviour; set an int (e.g. 42) to reproduce a run.\n",
    "SEED = None\n",
    "if SEED is not None:\n",
    "    random.seed(SEED)\n",
    "\n",
    "# ========= Load index =========\n",
    "index = faiss.read_index(FAISS_INDEX_PATH)\n",
    "# search_topk resolves index row i as id_map[str(i)].\n",
    "with open(ID_MAP_PATH, \"r\", encoding=\"utf-8\") as fh:\n",
    "    id_map = json.load(fh)\n",
    "\n",
    "# ========= Load model (same as when the index was built) =========\n",
    "tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "if tok.pad_token is None:\n",
    "    tok.pad_token = tok.eos_token\n",
    "tok.padding_side = \"right\"\n",
    "\n",
    "cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "cfg.num_labels = 16  # 16-way head; presumably the 16 MBTI types (TODO confirm)\n",
    "\n",
    "base = AutoModelForSequenceClassification.from_pretrained(\n",
    "    BASE_MODEL, config=cfg, trust_remote_code=True,\n",
    "    device_map={\"\": device}, low_cpu_mem_usage=True\n",
    ")\n",
    "model = PeftModel.from_pretrained(base, ADAPTER_DIR, is_trainable=False).eval().to(device)\n",
    "\n",
    "# ========= 工具函数 =========\n",
    "def pick_text(item: dict) -> str:\n",
    "    \"\"\"Return the first non-blank string field usable as encoder input.\n",
    "\n",
    "    Fields are tried in priority order; reorder the tuple to match your data.\n",
    "    Returns \"\" when none of them holds a non-blank string.\n",
    "    \"\"\"\n",
    "    candidate_keys = (\"posts_cleaned_for_embedding\", \"embed_text\", \"post_casebank\",\n",
    "                      \"posts_cleaned\", \"posts\", \"text\", \"query_text\")\n",
    "    return next(\n",
    "        (val for val in map(item.get, candidate_keys)\n",
    "         if isinstance(val, str) and val.strip()),\n",
    "        \"\",\n",
    "    )\n",
    "\n",
    "@torch.no_grad()\n",
    "def encode_text_to_vec(text: str) -> np.ndarray:\n",
    "    \"\"\"Embed one text as an L2-normalised, mask-aware mean-pooled vector.\n",
    "\n",
    "    Runs the model underneath the PEFT wrapper, averages the last hidden layer\n",
    "    over non-padding tokens, then unit-normalises so that inner product equals\n",
    "    cosine similarity (as expected by the IndexFlatIP index).\n",
    "    Returns a 1-D float32 numpy array of length hidden_size.\n",
    "    \"\"\"\n",
    "    batch = tok(text, return_tensors=\"pt\", truncation=True, padding=True, max_length=MAX_LEN).to(device)\n",
    "    outputs = model.base_model.model(**batch, output_hidden_states=True, use_cache=False, return_dict=True)\n",
    "    token_states = outputs.hidden_states[-1]              # (B, T, H)\n",
    "    keep = batch[\"attention_mask\"].unsqueeze(-1).float()  # (B, T, 1); 1 = real token\n",
    "\n",
    "    pooled = (token_states * keep).sum(dim=1)             # (B, H), padding zeroed out\n",
    "    pooled = pooled / keep.sum(dim=1).clamp(min=1.0)      # mean over real tokens\n",
    "    unit = F.normalize(pooled, p=2, dim=1)\n",
    "    return unit.squeeze(0).float().cpu().numpy()          # (H,)\n",
    "\n",
    "def search_topk(vec: np.ndarray, k: int):\n",
    "    \"\"\"Return up to k nearest casebank records for one query vector.\n",
    "\n",
    "    Each hit is the id_map record for the matched index row, merged with its\n",
    "    inner-product \"score\" (cosine for unit vectors), best first. FAISS pads the\n",
    "    result with id -1 when the index holds fewer than k vectors; those slots are\n",
    "    skipped instead of raising KeyError on id_map[\"-1\"].\n",
    "    \"\"\"\n",
    "    # FAISS requires a C-contiguous float32 (1, d) query matrix.\n",
    "    query = np.ascontiguousarray(np.asarray(vec, dtype=\"float32\").reshape(1, -1))\n",
    "    scores, ids = index.search(query, k)\n",
    "    return [{\"score\": float(s), **id_map[str(int(i))]}\n",
    "            for s, i in zip(scores[0], ids[0]) if i >= 0]\n",
    "\n",
    "def filter_self_hits(query_text: str, hits: list, k: int):\n",
    "    \"\"\"Optional: drop hits whose casebank post equals the query text.\n",
    "\n",
    "    Comparison is case-insensitive and whitespace-normalised, so a test item that\n",
    "    also lives in the casebank does not retrieve itself. Returns at most k of the\n",
    "    remaining hits, in their original order.\n",
    "    \"\"\"\n",
    "    def _canon(s):\n",
    "        return \" \".join((s or \"\").lower().split())\n",
    "\n",
    "    target = _canon(query_text)\n",
    "    survivors = []\n",
    "    for hit in hits:\n",
    "        if _canon(hit.get(\"post_casebank\", \"\")) == target:\n",
    "            continue\n",
    "        survivors.append(hit)\n",
    "    return survivors[:k]\n",
    "\n",
    "# ========= Load the test set, encode, retrieve =========\n",
    "with open(TEST_INPUT, \"r\", encoding=\"utf-8\") as fh:\n",
    "    test = json.load(fh)\n",
    "\n",
    "# Extra candidates fetched per query so that dropping self-hits still leaves k results.\n",
    "SELF_HIT_MARGIN = 2\n",
    "\n",
    "with_emb, results = [], []\n",
    "for it in tqdm(test, desc=\"Encode & retrieve\"):\n",
    "    text = pick_text(it)\n",
    "    if not text:\n",
    "        # No usable text: record an empty embedding and an empty result\n",
    "        it_out = dict(it); it_out[\"embedding\"] = None\n",
    "        with_emb.append(it_out)\n",
    "        results.append({\"type\": it.get(\"type\",\"\"), \"query_text\": \"\", \"topk_cases\": []})\n",
    "        continue\n",
    "\n",
    "    emb = encode_text_to_vec(text)              # (H,)\n",
    "    it_out = dict(it); it_out[\"embedding\"] = emb.tolist()\n",
    "    with_emb.append(it_out)\n",
    "\n",
    "    # Randomly pick k in {4, 5, 6} and keep the k most similar cases.\n",
    "    k_dynamic = random.randint(4, 6)\n",
    "    # Over-fetch: searching exactly k and then removing a self-hit would leave only k-1.\n",
    "    candidates = search_topk(emb, k=k_dynamic + SELF_HIT_MARGIN)\n",
    "    # filter_self_hits truncates to k; if you drop it, use candidates[:k_dynamic] instead.\n",
    "    topk = filter_self_hits(text, candidates, k_dynamic)\n",
    "\n",
    "    results.append({\n",
    "        \"type\": it.get(\"type\",\"\"),\n",
    "        \"query_text\": text,\n",
    "        \"topk_cases\": topk\n",
    "    })\n",
    "\n",
    "# ========= Write outputs (context managers so files are flushed and closed) =========\n",
    "with open(TEST_WITH_EMB_OUT, \"w\", encoding=\"utf-8\") as fh:\n",
    "    json.dump(with_emb, fh, ensure_ascii=False, indent=2)\n",
    "with open(TOPK_RESULTS_OUT, \"w\", encoding=\"utf-8\") as fh:\n",
    "    json.dump(results, fh, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(\"✅ 向量化完成：\", TEST_WITH_EMB_OUT)\n",
    "print(\"✅ TopK 随机 4~6（取前 k 个）完成：\", TOPK_RESULTS_OUT)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2d88d2e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Encode & retrieve (Random-proto GLOBAL): 100%|██████████| 325/325 [00:16<00:00, 19.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ 向量化完成： A_test_with_embeddings.json\n",
      "✅ Random-proto (全局随机抽 3 个) 完成： A_test_random_global3_final1.json\n"
     ]
    }
   ],
   "source": [
    "# encode_and_retrieve_random_proto_global.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, faiss, numpy as np, torch, random\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig\n",
    "from peft import PeftModel\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# ========= Path configuration =========\n",
    "# FAISS_INDEX_PATH is NOT loaded in this cell: cases are drawn at random, so only the\n",
    "# id_map is needed. It is kept for parity with the top-k cell's configuration.\n",
    "FAISS_INDEX_PATH = \"casebank_A_qwen_cls_cosine.index\"\n",
    "ID_MAP_PATH      = \"casebank_A_qwen_cls_idmap.json\"     # casebank id -> record mapping\n",
    "\n",
    "TEST_INPUT          = \"picked_balanced_around30.json\"\n",
    "# NOTE: same path as the top-k cell, so running this cell overwrites that file.\n",
    "TEST_WITH_EMB_OUT   = \"A_test_with_embeddings.json\"\n",
    "TOPK_RESULTS_OUT    = \"A_test_random_global3_final1.json\"   # Random-proto (3 global random cases)\n",
    "\n",
    "BASE_MODEL  = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "ADAPTER_DIR = \"mbti_lora_qwen1.5b-split_kaggle_ckpt\"\n",
    "\n",
    "MAX_LEN = 512\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Optional seed for the random case sampling.\n",
    "# None keeps the previous unseeded behaviour; set an int (e.g. 42) to reproduce a run.\n",
    "SEED = None\n",
    "if SEED is not None:\n",
    "    random.seed(SEED)\n",
    "\n",
    "# ========= Load casebank =========\n",
    "with open(ID_MAP_PATH, \"r\", encoding=\"utf-8\") as fh:\n",
    "    id_map = json.load(fh)\n",
    "casebank = list(id_map.values())   # every casebank record; sampled from directly below\n",
    "\n",
    "# ========= Load model =========\n",
    "tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "if tok.pad_token is None:\n",
    "    tok.pad_token = tok.eos_token\n",
    "tok.padding_side = \"right\"\n",
    "\n",
    "cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "cfg.num_labels = 16  # 16-way head; presumably the 16 MBTI types (TODO confirm)\n",
    "\n",
    "base = AutoModelForSequenceClassification.from_pretrained(\n",
    "    BASE_MODEL, config=cfg, trust_remote_code=True,\n",
    "    device_map={\"\": device}, low_cpu_mem_usage=True\n",
    ")\n",
    "model = PeftModel.from_pretrained(base, ADAPTER_DIR, is_trainable=False).eval().to(device)\n",
    "\n",
    "# ========= 工具函数 =========\n",
    "def pick_text(item: dict) -> str:\n",
    "    \"\"\"Pick the encoder input: first non-blank string among known fields, else \"\".\"\"\"\n",
    "    priority = [\"posts_cleaned_for_embedding\", \"embed_text\", \"post_casebank\",\n",
    "                \"posts_cleaned\", \"posts\", \"text\", \"query_text\"]\n",
    "    for field in priority:\n",
    "        value = item.get(field)\n",
    "        if not isinstance(value, str):\n",
    "            continue\n",
    "        if value.strip():\n",
    "            return value\n",
    "    return \"\"\n",
    "\n",
    "@torch.no_grad()\n",
    "def encode_text_to_vec(text: str) -> np.ndarray:\n",
    "    \"\"\"Sentence embedding for one text: masked mean of the last hidden layer, L2-normalised.\n",
    "\n",
    "    Uses the model underneath the PEFT wrapper and returns a 1-D float32 numpy vector.\n",
    "    \"\"\"\n",
    "    enc = tok(text, return_tensors=\"pt\", truncation=True, padding=True, max_length=MAX_LEN).to(device)\n",
    "    hidden = model.base_model.model(\n",
    "        **enc, output_hidden_states=True, use_cache=False, return_dict=True\n",
    "    ).hidden_states[-1]\n",
    "    weights = enc[\"attention_mask\"].unsqueeze(-1).float()  # 1 for real tokens, 0 for padding\n",
    "    summed = (hidden * weights).sum(dim=1)\n",
    "    counts = weights.sum(dim=1).clamp(min=1.0)\n",
    "    normalized = F.normalize(summed / counts, p=2, dim=1)\n",
    "    return normalized.squeeze(0).float().cpu().numpy()\n",
    "\n",
    "# ========= Main loop =========\n",
    "with open(TEST_INPUT, \"r\", encoding=\"utf-8\") as fh:\n",
    "    test = json.load(fh)\n",
    "\n",
    "N_RANDOM_CASES = 3  # cases drawn per query\n",
    "\n",
    "with_emb, results = [], []\n",
    "for it in tqdm(test, desc=\"Encode & retrieve (Random-proto GLOBAL)\"):\n",
    "    text = pick_text(it)\n",
    "    if not text:\n",
    "        it_out = dict(it); it_out[\"embedding\"] = None\n",
    "        with_emb.append(it_out)\n",
    "        results.append({\"type\": it.get(\"type\",\"\"), \"query_text\": \"\", \"topk_cases\": []})\n",
    "        continue\n",
    "\n",
    "    # The embedding is only written to disk; it plays no part in the random draw.\n",
    "    emb = encode_text_to_vec(text)\n",
    "    it_out = dict(it); it_out[\"embedding\"] = emb.tolist()\n",
    "    with_emb.append(it_out)\n",
    "\n",
    "    # === Draw N_RANDOM_CASES cases uniformly from the whole casebank ===\n",
    "    # The fallback copies, so results never alias the shared casebank list.\n",
    "    if len(casebank) >= N_RANDOM_CASES:\n",
    "        topk = random.sample(casebank, N_RANDOM_CASES)\n",
    "    else:\n",
    "        topk = list(casebank)\n",
    "\n",
    "    results.append({\n",
    "        \"type\": it.get(\"type\",\"\"),\n",
    "        \"query_text\": text,\n",
    "        \"topk_cases\": topk\n",
    "    })\n",
    "\n",
    "# ========= Write outputs (context managers so files are flushed and closed) =========\n",
    "with open(TEST_WITH_EMB_OUT, \"w\", encoding=\"utf-8\") as fh:\n",
    "    json.dump(with_emb, fh, ensure_ascii=False, indent=2)\n",
    "with open(TOPK_RESULTS_OUT, \"w\", encoding=\"utf-8\") as fh:\n",
    "    json.dump(results, fh, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(\"✅ 向量化完成：\", TEST_WITH_EMB_OUT)\n",
    "print(\"✅ Random-proto (全局随机抽 3 个) 完成：\", TOPK_RESULTS_OUT)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f768973c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Encode casebank with Qwen-base: 100%|██████████| 426/426 [02:39<00:00,  2.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ 索引已写入: casebank_qwenbase_cosine.index\n",
      "✅ id_map 已写入: casebank_qwenbase_idmap.json\n",
      "📦 向量数: 27224, 维度: 1536\n"
     ]
    }
   ],
   "source": [
    "# build_casebank_index_qwen_base.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, faiss, numpy as np, torch\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# ======= Path configuration (edit for your local setup) =======\n",
    "ID_MAP_IN          = \"casebank_A_qwen_cls_idmap.json\"   # existing casebank metadata (read-only)\n",
    "OUT_INDEX_PATH     = \"casebank_qwenbase_cosine.index\"   # new index built from base-Qwen vectors\n",
    "OUT_ID_MAP_COPY    = \"casebank_qwenbase_idmap.json\"     # id_map written alongside the new index\n",
    "\n",
    "BASE_MODEL  = \"Qwen/Qwen2.5-1.5B-Instruct\"              # plain base model, no LoRA adapter\n",
    "MAX_LEN     = 512\n",
    "BATCH_SIZE  = 64\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# ======= Load base Qwen =======\n",
    "tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "if tok.pad_token is None:\n",
    "    tok.pad_token = tok.eos_token\n",
    "tok.padding_side = \"right\"\n",
    "\n",
    "# Half precision only on GPU; None leaves transformers' default dtype on CPU.\n",
    "model = AutoModel.from_pretrained(\n",
    "    BASE_MODEL,\n",
    "    trust_remote_code=True,\n",
    "    torch_dtype=(torch.float16 if torch.cuda.is_available() else None),\n",
    ")\n",
    "model = model.eval().to(device)\n",
    "\n",
    "@torch.no_grad()\n",
    "def embed_texts(texts):\n",
    "    \"\"\"Encode a batch of texts into L2-normalised mean-pooled vectors.\n",
    "\n",
    "    Args:\n",
    "        texts: list of strings, padded to the batch max and truncated to MAX_LEN tokens.\n",
    "    Returns:\n",
    "        float32 numpy array of shape (len(texts), hidden_size) with L2-normalised rows,\n",
    "        so inner product equals cosine similarity.\n",
    "    \"\"\"\n",
    "    ins = tok(texts, return_tensors=\"pt\", padding=True, truncation=True, max_length=MAX_LEN).to(device)\n",
    "    # Only last_hidden_state is used: don't materialise every layer's hidden states\n",
    "    # (output_hidden_states) or build a KV cache; both just cost memory per batch.\n",
    "    out = model(**ins, use_cache=False, return_dict=True)\n",
    "    last = out.last_hidden_state                                # (B, T, H)\n",
    "    mask = ins[\"attention_mask\"].unsqueeze(-1).float()          # (B, T, 1)\n",
    "    vec = (last * mask).sum(1) / mask.sum(1).clamp(min=1e-6)    # masked mean pooling\n",
    "    vec = F.normalize(vec, p=2, dim=1)                          # L2 normalise -> cosine via IP\n",
    "    return vec.detach().cpu().numpy().astype(\"float32\")         # (B, H)\n",
    "\n",
    "# ======= Read the casebank and encode it =======\n",
    "with open(ID_MAP_IN, \"r\", encoding=\"utf-8\") as fh:\n",
    "    id_map = json.load(fh)\n",
    "# Stable order: sort ids numerically; index row p will hold items[p].\n",
    "items = [id_map[k] for k in sorted(id_map.keys(), key=lambda x: int(x))]\n",
    "if not items:\n",
    "    raise ValueError(f\"{ID_MAP_IN} contains no casebank entries\")\n",
    "\n",
    "vecs = []\n",
    "for i in tqdm(range(0, len(items), BATCH_SIZE), desc=\"Encode casebank with Qwen-base\"):\n",
    "    texts = [(it.get(\"post_casebank\") or it.get(\"post\") or \"\") for it in items[i:i+BATCH_SIZE]]\n",
    "    vecs.append(embed_texts(texts))\n",
    "vecs = np.vstack(vecs)  # (N, H), float32\n",
    "\n",
    "# ======= Build the Faiss index (cosine = L2-normalised vectors + inner product) =======\n",
    "index = faiss.IndexFlatIP(vecs.shape[1])\n",
    "index.add(vecs)\n",
    "faiss.write_index(index, OUT_INDEX_PATH)\n",
    "\n",
    "# Retrieval resolves hit row p as id_map[str(p)], so key the written map by row position.\n",
    "# This is identical to the input map when its ids are already 0..N-1; with gaps or an\n",
    "# offset, copying the original ids verbatim would point hits at the wrong records.\n",
    "id_map_aligned = {str(pos): item for pos, item in enumerate(items)}\n",
    "with open(OUT_ID_MAP_COPY, \"w\", encoding=\"utf-8\") as fh:\n",
    "    json.dump(id_map_aligned, fh, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"✅ 索引已写入: {OUT_INDEX_PATH}\")\n",
    "print(f\"✅ id_map 已写入: {OUT_ID_MAP_COPY}\")\n",
    "print(f\"📦 向量数: {vecs.shape[0]}, 维度: {vecs.shape[1]}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f0703242",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Encode & retrieve · Top-k (Qwen-base): 100%|██████████| 325/325 [00:06<00:00, 49.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ 向量化完成： QwenBase_test_with_embeddings.json\n",
      "✅ 检索完成（Top-k (Qwen-base)）: QwenBase_test_topk3_final1.json\n"
     ]
    }
   ],
   "source": [
    "# encode_and_retrieve_topk_qwen_base.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, faiss, numpy as np, torch, random\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# ========= Path configuration =========\n",
    "FAISS_INDEX_PATH   = \"casebank_qwenbase_cosine.index\"   # index produced by the base-Qwen build cell\n",
    "ID_MAP_PATH        = \"casebank_qwenbase_idmap.json\"     # id_map paired with that index\n",
    "\n",
    "TEST_INPUT         = \"picked_balanced_around30.json\"\n",
    "TEST_WITH_EMB_OUT  = \"QwenBase_test_with_embeddings.json\"\n",
    "TOPK_RESULTS_OUT   = \"QwenBase_test_topk3_final1.json\"         # e.g., top-3\n",
    "K                  = 3                                  # k used by MODE=\"topk\" (baseline = 3)\n",
    "\n",
    "# MODE=\"random_global\": 3 cases drawn at random from the whole casebank (Random-proto)\n",
    "# MODE=\"topk_4to6\":     k drawn at random from 4..6, then the k most similar cases\n",
    "MODE               = \"topk\"  # \"topk\" | \"topk_4to6\" | \"random_global\"\n",
    "\n",
    "BASE_MODEL  = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "MAX_LEN     = 512\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Optional seed for the random modes.\n",
    "# None keeps the previous unseeded behaviour; set an int (e.g. 42) to reproduce a run.\n",
    "SEED = None\n",
    "if SEED is not None:\n",
    "    random.seed(SEED)\n",
    "\n",
    "# ========= Load index & id_map =========\n",
    "index = faiss.read_index(FAISS_INDEX_PATH)\n",
    "with open(ID_MAP_PATH, \"r\", encoding=\"utf-8\") as fh:\n",
    "    id_map = json.load(fh)\n",
    "casebank_all = list(id_map.values())  # sampling pool for MODE=\"random_global\"\n",
    "\n",
    "# ========= Load base Qwen =========\n",
    "tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "if tok.pad_token is None:\n",
    "    tok.pad_token = tok.eos_token\n",
    "tok.padding_side = \"right\"\n",
    "\n",
    "model = AutoModel.from_pretrained(\n",
    "    BASE_MODEL, trust_remote_code=True,\n",
    "    torch_dtype=torch.float16 if torch.cuda.is_available() else None\n",
    ").eval().to(device)\n",
    "\n",
    "@torch.no_grad()\n",
    "def encode_text_to_vec(text: str) -> np.ndarray:\n",
    "    \"\"\"Embed one text with the base Qwen model (masked mean pooling + L2 norm).\n",
    "\n",
    "    Mirrors embed_texts() from the index-building cell so query and casebank\n",
    "    vectors live in the same space. Returns a 1-D float32 numpy array.\n",
    "    \"\"\"\n",
    "    ins = tok(text, return_tensors=\"pt\", truncation=True, padding=True, max_length=MAX_LEN).to(device)\n",
    "    # Only last_hidden_state is needed: skip all-layer hidden states and the KV cache.\n",
    "    out = model(**ins, use_cache=False, return_dict=True)\n",
    "    last = out.last_hidden_state                                # (1, T, H)\n",
    "    mask = ins[\"attention_mask\"].unsqueeze(-1).float()          # (1, T, 1)\n",
    "    vec = (last * mask).sum(1) / mask.sum(1).clamp(min=1e-6)\n",
    "    vec = F.normalize(vec, p=2, dim=1)\n",
    "    return vec.squeeze(0).detach().cpu().numpy().astype(\"float32\")\n",
    "\n",
    "def search_topk(vec: np.ndarray, k: int):\n",
    "    \"\"\"Return up to k nearest casebank records for one query vector.\n",
    "\n",
    "    Each hit is the id_map record for the matched index row, merged with its\n",
    "    inner-product \"score\" (cosine for unit vectors), best first. FAISS pads the\n",
    "    result with id -1 when the index holds fewer than k vectors; those slots are\n",
    "    skipped instead of raising KeyError on id_map[\"-1\"].\n",
    "    \"\"\"\n",
    "    # FAISS requires a C-contiguous float32 (1, d) query matrix.\n",
    "    query = np.ascontiguousarray(np.asarray(vec, dtype=\"float32\").reshape(1, -1))\n",
    "    scores, ids = index.search(query, k)\n",
    "    return [{\"score\": float(s), **id_map[str(int(i))]}\n",
    "            for s, i in zip(scores[0], ids[0]) if i >= 0]\n",
    "\n",
    "def filter_self_hits(query_text: str, hits: list, k: int = None):\n",
    "    \"\"\"Remove hits whose post_casebank equals the query (case/whitespace-insensitive).\n",
    "\n",
    "    Remaining hits keep their order; the list is truncated to k unless k is None.\n",
    "    \"\"\"\n",
    "    def canonical(s):\n",
    "        return \" \".join((s or \"\").lower().split())\n",
    "\n",
    "    query_key = canonical(query_text)\n",
    "    kept = list(filter(lambda h: canonical(h.get(\"post_casebank\", \"\")) != query_key, hits))\n",
    "    if k is None:\n",
    "        return kept\n",
    "    return kept[:k]\n",
    "\n",
    "# ========= Load the test set, encode, retrieve =========\n",
    "with open(TEST_INPUT, \"r\", encoding=\"utf-8\") as fh:\n",
    "    test = json.load(fh)\n",
    "\n",
    "MODE_DESCRIPTIONS = {\n",
    "    \"topk\": \"Top-k (Qwen-base)\",\n",
    "    \"topk_4to6\": \"Top-k random 4–6 (Qwen-base)\",\n",
    "    \"random_global\": \"Random-proto global 3 (Qwen-base)\"\n",
    "}\n",
    "# Validate MODE up front: a bare dict lookup raised KeyError here, which made the\n",
    "# ValueError branch inside the loop unreachable.\n",
    "if MODE not in MODE_DESCRIPTIONS:\n",
    "    raise ValueError(f\"Unknown MODE: {MODE}\")\n",
    "desc = MODE_DESCRIPTIONS[MODE]\n",
    "\n",
    "# Extra candidates fetched per query so that dropping self-hits still leaves k results.\n",
    "SELF_HIT_MARGIN = 2\n",
    "# Candidate text fields, in priority order (same order as pick_text in the other cells).\n",
    "TEXT_KEYS = (\"posts_cleaned_for_embedding\", \"embed_text\", \"post_casebank\",\n",
    "             \"posts_cleaned\", \"posts\", \"text\", \"query_text\")\n",
    "\n",
    "with_emb, results = [], []\n",
    "for it in tqdm(test, desc=f\"Encode & retrieve · {desc}\"):\n",
    "    # First non-blank string field, or None\n",
    "    text = next((v for v in map(it.get, TEXT_KEYS) if isinstance(v, str) and v.strip()), None)\n",
    "    if not text:\n",
    "        it_out = dict(it); it_out[\"embedding\"] = None\n",
    "        with_emb.append(it_out)\n",
    "        results.append({\"type\": it.get(\"type\",\"\"), \"query_text\": \"\", \"topk_cases\": []})\n",
    "        continue\n",
    "\n",
    "    emb = encode_text_to_vec(text)\n",
    "    it_out = dict(it); it_out[\"embedding\"] = emb.tolist()\n",
    "    with_emb.append(it_out)\n",
    "\n",
    "    # Three modes. The top-k modes over-fetch, because searching exactly k and then\n",
    "    # removing a self-hit would leave only k-1 cases.\n",
    "    if MODE == \"topk\":\n",
    "        topk = filter_self_hits(text, search_topk(emb, k=K + SELF_HIT_MARGIN), K)\n",
    "\n",
    "    elif MODE == \"topk_4to6\":\n",
    "        k_dynamic = random.randint(4, 6)\n",
    "        topk = filter_self_hits(text, search_topk(emb, k=k_dynamic + SELF_HIT_MARGIN), k_dynamic)\n",
    "\n",
    "    elif MODE == \"random_global\":\n",
    "        # 3 cases drawn uniformly from the whole casebank (Random-proto); copy in the\n",
    "        # fallback so results never alias the shared casebank list.\n",
    "        pool = casebank_all\n",
    "        topk = random.sample(pool, 3) if len(pool) >= 3 else list(pool)\n",
    "\n",
    "    else:  # unreachable after the up-front check; kept as a guard\n",
    "        raise ValueError(f\"Unknown MODE: {MODE}\")\n",
    "\n",
    "    results.append({\n",
    "        \"type\": it.get(\"type\",\"\"),\n",
    "        \"query_text\": text,\n",
    "        \"topk_cases\": topk\n",
    "    })\n",
    "\n",
    "# ========= Write outputs (context managers so files are flushed and closed) =========\n",
    "with open(TEST_WITH_EMB_OUT, \"w\", encoding=\"utf-8\") as fh:\n",
    "    json.dump(with_emb, fh, ensure_ascii=False, indent=2)\n",
    "with open(TOPK_RESULTS_OUT, \"w\", encoding=\"utf-8\") as fh:\n",
    "    json.dump(results, fh, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(\"✅ 向量化完成：\", TEST_WITH_EMB_OUT)\n",
    "print(f\"✅ 检索完成（{desc}）: {TOPK_RESULTS_OUT}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
