{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "609719cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/tmp/ipykernel_2377910/2718948663.py:165: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
      "  trainer = Trainer(model=model, args=args, tokenizer=tok, data_collator=collator)\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "Cannot handle batch sizes > 1 if no padding token is defined.",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mValueError\u001b[39m                                Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 188\u001b[39m\n\u001b[32m    185\u001b[39m     plot_confusion_and_roc(test_true, test_probs, MBTI_16, OUTPUT_DIR, tag=\u001b[33m\"\u001b[39m\u001b[33mnew_test\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m    187\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[34m__name__\u001b[39m == \u001b[33m\"\u001b[39m\u001b[33m__main__\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m188\u001b[39m     \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 168\u001b[39m, in \u001b[36mmain\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m    165\u001b[39m trainer = Trainer(model=model, args=args, tokenizer=tok, data_collator=collator)\n\u001b[32m    167\u001b[39m \u001b[38;5;66;03m# 验证集评测\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m168\u001b[39m val_out   = \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mval_ds\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    169\u001b[39m val_logits = val_out.predictions[\u001b[32m0\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(val_out.predictions, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)) \u001b[38;5;28;01melse\u001b[39;00m val_out.predictions\n\u001b[32m    170\u001b[39m val_probs = F.softmax(torch.tensor(val_logits, dtype=torch.float32), dim=-\u001b[32m1\u001b[39m).cpu().numpy()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:4332\u001b[39m, in \u001b[36mTrainer.predict\u001b[39m\u001b[34m(self, test_dataset, ignore_keys, metric_key_prefix)\u001b[39m\n\u001b[32m   4329\u001b[39m start_time = time.time()\n\u001b[32m   4331\u001b[39m eval_loop = \u001b[38;5;28mself\u001b[39m.prediction_loop \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.args.use_legacy_prediction_loop \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m.evaluation_loop\n\u001b[32m-> \u001b[39m\u001b[32m4332\u001b[39m output = \u001b[43meval_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m   4333\u001b[39m \u001b[43m    \u001b[49m\u001b[43mtest_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdescription\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mPrediction\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[43m=\u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetric_key_prefix\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmetric_key_prefix\u001b[49m\n\u001b[32m   4334\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   4335\u001b[39m total_batch_size = \u001b[38;5;28mself\u001b[39m.args.eval_batch_size * \u001b[38;5;28mself\u001b[39m.args.world_size\n\u001b[32m   4336\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmetric_key_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m_jit_compilation_time\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m output.metrics:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:4449\u001b[39m, in \u001b[36mTrainer.evaluation_loop\u001b[39m\u001b[34m(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)\u001b[39m\n\u001b[32m   4446\u001b[39m         batch_size = observed_batch_size\n\u001b[32m   4448\u001b[39m \u001b[38;5;66;03m# Prediction step\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m4449\u001b[39m losses, logits, labels = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mprediction_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprediction_loss_only\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[43m=\u001b[49m\u001b[43mignore_keys\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   4450\u001b[39m main_input_name = \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m.model, \u001b[33m\"\u001b[39m\u001b[33mmain_input_name\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33minput_ids\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m   4451\u001b[39m inputs_decode = (\n\u001b[32m   4452\u001b[39m     \u001b[38;5;28mself\u001b[39m._prepare_input(inputs[main_input_name]) \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33minputs\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m args.include_for_metrics \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m   4453\u001b[39m )\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:4665\u001b[39m, in \u001b[36mTrainer.prediction_step\u001b[39m\u001b[34m(self, model, inputs, prediction_loss_only, ignore_keys)\u001b[39m\n\u001b[32m   4663\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m has_labels \u001b[38;5;129;01mor\u001b[39;00m loss_without_labels:\n\u001b[32m   4664\u001b[39m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m.compute_loss_context_manager():\n\u001b[32m-> \u001b[39m\u001b[32m4665\u001b[39m         loss, outputs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_outputs\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m   4666\u001b[39m     loss = loss.detach().mean()\n\u001b[32m   4668\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(outputs, \u001b[38;5;28mdict\u001b[39m):\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/transformers/trainer.py:3884\u001b[39m, in \u001b[36mTrainer.compute_loss\u001b[39m\u001b[34m(self, model, inputs, return_outputs, num_items_in_batch)\u001b[39m\n\u001b[32m   3882\u001b[39m         kwargs[\u001b[33m\"\u001b[39m\u001b[33mnum_items_in_batch\u001b[39m\u001b[33m\"\u001b[39m] = num_items_in_batch\n\u001b[32m   3883\u001b[39m     inputs = {**inputs, **kwargs}\n\u001b[32m-> \u001b[39m\u001b[32m3884\u001b[39m outputs = \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   3885\u001b[39m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[32m   3886\u001b[39m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[32m   3887\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.args.past_index >= \u001b[32m0\u001b[39m:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1751\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1749\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m   1750\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1751\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py:1762\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1757\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m   1758\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m   1759\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m   1760\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m   1761\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1762\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1764\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m   1765\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/transformers/utils/generic.py:959\u001b[39m, in \u001b[36mcan_return_tuple.<locals>.wrapper\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m    957\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m return_dict_passed \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m    958\u001b[39m     return_dict = return_dict_passed\n\u001b[32m--> \u001b[39m\u001b[32m959\u001b[39m output = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    960\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[32m    961\u001b[39m     output = output.to_tuple()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/site-packages/transformers/modeling_layers.py:141\u001b[39m, in \u001b[36mGenericForSequenceClassification.forward\u001b[39m\u001b[34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, **kwargs)\u001b[39m\n\u001b[32m    138\u001b[39m     batch_size = inputs_embeds.shape[\u001b[32m0\u001b[39m]\n\u001b[32m    140\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.config.pad_token_id \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m batch_size != \u001b[32m1\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m141\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mCannot handle batch sizes > 1 if no padding token is defined.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m    142\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.config.pad_token_id \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m    143\u001b[39m     last_non_pad_token = -\u001b[32m1\u001b[39m\n",
      "\u001b[31mValueError\u001b[39m: Cannot handle batch sizes > 1 if no padding token is defined."
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import os, json, torch\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from typing import Dict, Any\n",
    "from transformers import (\n",
    "    AutoTokenizer, AutoConfig, AutoModelForSequenceClassification,\n",
    "    Trainer, TrainingArguments, DataCollatorWithPadding\n",
    ")\n",
    "from peft import PeftModel\n",
    "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc, accuracy_score\n",
    "from sklearn.preprocessing import label_binarize\n",
    "\n",
    "# ================== 配置 ==================\n",
    "BASE_MODEL   = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "MODEL_DIR    = \"qwen-test-on-pandora_new\"                # 你的模型主目录\n",
    "ADAPTER_DIR  = os.path.join(MODEL_DIR, \"lora_adapter\")   # LoRA 适配器路径\n",
    "\n",
    "VAL_PATH     = \"val.json\"     # 新的验证集路径（你改这里）\n",
    "TEST_PATH    = \"test.json\"      # 新的测试集路径（你改这里）\n",
    "\n",
    "OUTPUT_DIR   = \"eval_new_val_test\"  # 输出目录（混淆矩阵、ROC）\n",
    "\n",
    "MAX_LEN      = 384\n",
    "BATCH_SIZE   = 16\n",
    "\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "# ================== 工具函数 ==================\n",
    "def truncate_to_budget(tok, text, budget):\n",
    "    enc = tok(text or \"\", add_special_tokens=False)\n",
    "    ids = enc[\"input_ids\"][: budget]\n",
    "    return tok.decode(ids)\n",
    "\n",
    "BUDGET = {\"posts_cleaned\": 192, \"semantic_view\": 64, \"sentiment_view\": 32, \"linguistic_view\": 24}\n",
    "\n",
    "def build_input(item: Dict[str, Any], tok) -> str:\n",
    "    p_raw = item.get(\"posts_cleaned\", item.get(\"posts\",\"\")) or item.get(\"text\",\"\") or \"\"\n",
    "    sem   = item.get(\"semantic_view\",\"\")  or \"\"\n",
    "    sen   = item.get(\"sentiment_view\",\"\") or \"\"\n",
    "    lin   = item.get(\"linguistic_view\",\"\") or \"\"\n",
    "\n",
    "    p   = truncate_to_budget(tok, p_raw, BUDGET[\"posts_cleaned\"])\n",
    "    sem = truncate_to_budget(tok, sem,   BUDGET[\"semantic_view\"])\n",
    "    sen = truncate_to_budget(tok, sen,   BUDGET[\"sentiment_view\"])\n",
    "    lin = truncate_to_budget(tok, lin,   BUDGET[\"linguistic_view\"])\n",
    "\n",
    "    return (\n",
    "        f\"[POSTS]\\n{p}\\n[SEMANTIC]\\n{sem}\\n[SENTIMENT]\\n{sen}\\n[LINGUISTIC]\\n{lin}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "\n",
    "def load_rows(path: str):\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    clean = []\n",
    "    for r in rows:\n",
    "        t = (r.get(\"type\") or r.get(\"label\") or \"\").upper().strip()\n",
    "        if t in MBTI2ID:\n",
    "            r[\"type\"] = t\n",
    "            clean.append(r)\n",
    "    return clean\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok  = tokenizer\n",
    "        self.max_len = max_len\n",
    "    def __len__(self): return len(self.rows)\n",
    "    def __getitem__(self, idx):\n",
    "        it   = self.rows[idx]\n",
    "        text = build_input(it, self.tok)\n",
    "        y    = MBTI2ID[it[\"type\"]]\n",
    "        enc  = self.tok(text, truncation=True, max_length=self.max_len)\n",
    "        return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"], \"labels\": y}\n",
    "\n",
    "def plot_confusion_and_roc(y_true, y_prob, class_names, out_dir, tag=\"eval\"):\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    y_pred = np.argmax(y_prob, axis=-1)\n",
    "    # Confusion Matrix\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig_cm, ax_cm = plt.subplots(figsize=(8, 8), dpi=150)\n",
    "    disp.plot(ax=ax_cm, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax_cm.set_title(f\"Confusion Matrix ({tag})\")\n",
    "    fig_cm.tight_layout()\n",
    "    fig_cm.savefig(os.path.join(out_dir, f\"{tag}_confusion_matrix.png\"))\n",
    "    plt.close(fig_cm)\n",
    "\n",
    "    # ROC\n",
    "    Y_true_bin = label_binarize(y_true, classes=list(range(len(class_names))))\n",
    "    fpr, tpr, roc_auc = {}, {}, {}\n",
    "    valid = []\n",
    "    for i in range(len(class_names)):\n",
    "        if Y_true_bin[:, i].sum() == 0:\n",
    "            continue\n",
    "        fpr[i], tpr[i], _ = roc_curve(Y_true_bin[:, i], y_prob[:, i])\n",
    "        roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "        valid.append(i)\n",
    "    if len(valid) >= 2:\n",
    "        fpr[\"micro\"], tpr[\"micro\"], _ = roc_curve(\n",
    "            Y_true_bin[:, valid].ravel(), y_prob[:, valid].ravel()\n",
    "        )\n",
    "        roc_auc[\"micro\"] = auc(fpr[\"micro\"], tpr[\"micro\"])\n",
    "        all_fpr = np.unique(np.concatenate([fpr[i] for i in valid]))\n",
    "        mean_tpr = np.zeros_like(all_fpr)\n",
    "        for i in valid:\n",
    "            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])\n",
    "        mean_tpr /= len(valid)\n",
    "        fpr[\"macro\"] = all_fpr; tpr[\"macro\"] = mean_tpr\n",
    "        roc_auc[\"macro\"] = auc(fpr[\"macro\"], tpr[\"macro\"])\n",
    "\n",
    "        fig_roc, ax_roc = plt.subplots(figsize=(7, 7), dpi=150)\n",
    "        ax_roc.plot(fpr[\"micro\"], tpr[\"micro\"],\n",
    "                    label=f\"micro-average ROC (AUC = {roc_auc['micro']:.3f})\", linewidth=2)\n",
    "        ax_roc.plot(fpr[\"macro\"], tpr[\"macro\"],\n",
    "                    label=f\"macro-average ROC (AUC = {roc_auc['macro']:.3f})\", linewidth=2)\n",
    "        ax_roc.plot([0, 1], [0, 1], \"k--\", linewidth=1)\n",
    "        ax_roc.set_xlim([0.0, 1.0]); ax_roc.set_ylim([0.0, 1.05])\n",
    "        ax_roc.set_xlabel(\"False Positive Rate\"); ax_roc.set_ylabel(\"True Positive Rate\")\n",
    "        ax_roc.set_title(f\"Multiclass ROC ({tag})\")\n",
    "        ax_roc.legend(loc=\"lower right\")\n",
    "        fig_roc.tight_layout()\n",
    "        fig_roc.savefig(os.path.join(out_dir, f\"{tag}_roc_micro_macro.png\"))\n",
    "        plt.close(fig_roc)\n",
    "\n",
    "# ================== 主流程 ==================\n",
    "def main():\n",
    "    # Tokenizer\n",
    "    tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "    if tok.pad_token is None:\n",
    "        tok.pad_token = tok.eos_token\n",
    "    tok.padding_side = \"right\"\n",
    "\n",
    "    # 加载基座 + LoRA\n",
    "    base_cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "    base_cfg.num_labels = len(MBTI_16)\n",
    "    base = AutoModelForSequenceClassification.from_pretrained(\n",
    "        BASE_MODEL, config=base_cfg, device_map=\"auto\", trust_remote_code=True\n",
    "    )\n",
    "    model = PeftModel.from_pretrained(base, ADAPTER_DIR)\n",
    "    model = model.merge_and_unload()  # 合并权重，推理更快\n",
    "    model.eval()\n",
    "\n",
    "    # 构建数据集\n",
    "    val_rows  = load_rows(VAL_PATH)\n",
    "    test_rows = load_rows(TEST_PATH)\n",
    "    val_ds  = MBTIDataset(val_rows,  tok, max_len=MAX_LEN)\n",
    "    test_ds = MBTIDataset(test_rows, tok, max_len=MAX_LEN)\n",
    "    collator = DataCollatorWithPadding(tok, pad_to_multiple_of=8)\n",
    "\n",
    "    args = TrainingArguments(\n",
    "        output_dir=OUTPUT_DIR,\n",
    "        per_device_eval_batch_size=BATCH_SIZE,\n",
    "        report_to=\"none\"\n",
    "    )\n",
    "    trainer = Trainer(model=model, args=args, tokenizer=tok, data_collator=collator)\n",
    "\n",
    "    # 验证集评测\n",
    "    val_out   = trainer.predict(val_ds)\n",
    "    val_logits = val_out.predictions[0] if isinstance(val_out.predictions, (list, tuple)) else val_out.predictions\n",
    "    val_probs = F.softmax(torch.tensor(val_logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    val_true  = val_out.label_ids\n",
    "    val_pred  = np.argmax(val_probs, axis=-1)\n",
    "    val_acc   = accuracy_score(val_true, val_pred)\n",
    "    print(f\"\\n=== Validation Accuracy: {val_acc:.4f}\")\n",
    "    plot_confusion_and_roc(val_true, val_probs, MBTI_16, OUTPUT_DIR, tag=\"new_val\")\n",
    "\n",
    "    # 测试集评测\n",
    "    test_out   = trainer.predict(test_ds)\n",
    "    test_logits = test_out.predictions[0] if isinstance(test_out.predictions, (list, tuple)) else test_out.predictions\n",
    "    test_probs = F.softmax(torch.tensor(test_logits, dtype=torch.float32), dim=-1).cpu().numpy()\n",
    "    test_true  = test_out.label_ids\n",
    "    test_pred  = np.argmax(test_probs, axis=-1)\n",
    "    test_acc   = accuracy_score(test_true, test_pred)\n",
    "    print(f\"=== Test Accuracy: {test_acc:.4f}\")\n",
    "    plot_confusion_and_roc(test_true, test_probs, MBTI_16, OUTPUT_DIR, tag=\"new_test\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1abf43b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen2.5-1.5B-Instruct and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Test Accuracy on test对应的原始数据.json: 0.8054\n"
     ]
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m在当前单元格或上一个单元格中执行代码时 Kernel 崩溃。\n",
      "\u001b[1;31m请查看单元格中的代码，以确定故障的可能原因。\n",
      "\u001b[1;31m单击<a href='https://aka.ms/vscodeJupyterKernelCrash'>此处</a>了解详细信息。\n",
      "\u001b[1;31m有关更多详细信息，请查看 Jupyter <a href='command:jupyter.viewOutput'>log</a>。"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import os, json, torch\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, DataCollatorWithPadding\n",
    "from peft import PeftModel\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay\n",
    "\n",
    "# ===== 路径配置 =====\n",
    "BASE_MODEL   = \"Qwen/Qwen2.5-1.5B-Instruct\"     # 训练用的基座\n",
    "MODEL_DIR    = \"qwen-test-on-pandora_new\"       # 你的模型目录\n",
    "ADAPTER_DIR  = os.path.join(MODEL_DIR, \"lora_adapter\")  # LoRA adapter\n",
    "TEST_PATH    = \"test对应的原始数据.json\"     # 你要评测的新测试集\n",
    "OUTPUT_DIR   = \"eval_only_test\"\n",
    "\n",
    "MAX_LEN    = 384\n",
    "BATCH_SIZE = 16\n",
    "\n",
    "MBTI_16 = [\n",
    "    \"INTJ\",\"INTP\",\"ENTJ\",\"ENTP\",\"INFJ\",\"INFP\",\"ENFJ\",\"ENFP\",\n",
    "    \"ISTJ\",\"ISFJ\",\"ESTJ\",\"ESFJ\",\"ISTP\",\"ISFP\",\"ESTP\",\"ESFP\"\n",
    "]\n",
    "MBTI2ID = {t:i for i,t in enumerate(MBTI_16)}\n",
    "\n",
    "# ===== 数据处理函数 =====\n",
    "def truncate_to_budget(tok, text, budget):\n",
    "    enc = tok(text or \"\", add_special_tokens=False)\n",
    "    ids = enc[\"input_ids\"][: budget]\n",
    "    return tok.decode(ids)\n",
    "\n",
    "BUDGET = {\"posts_cleaned\":192,\"semantic_view\":64,\"sentiment_view\":32,\"linguistic_view\":24}\n",
    "\n",
    "def build_input(item, tok):\n",
    "    p_raw = item.get(\"posts_cleaned\", item.get(\"posts\",\"\")) or item.get(\"text\",\"\") or \"\"\n",
    "    sem   = item.get(\"semantic_view\",\"\")  or \"\"\n",
    "    sen   = item.get(\"sentiment_view\",\"\") or \"\"\n",
    "    lin   = item.get(\"linguistic_view\",\"\") or \"\"\n",
    "\n",
    "    p   = truncate_to_budget(tok, p_raw, BUDGET[\"posts_cleaned\"])\n",
    "    sem = truncate_to_budget(tok, sem,   BUDGET[\"semantic_view\"])\n",
    "    sen = truncate_to_budget(tok, sen,   BUDGET[\"sentiment_view\"])\n",
    "    lin = truncate_to_budget(tok, lin,   BUDGET[\"linguistic_view\"])\n",
    "\n",
    "    return (\n",
    "        f\"[POSTS]\\n{p}\\n[SEMANTIC]\\n{sem}\\n[SENTIMENT]\\n{sen}\\n[LINGUISTIC]\\n{lin}\\n\"\n",
    "        f\"[TASK] Predict MBTI type among {', '.join(MBTI_16)}.\"\n",
    "    )\n",
    "\n",
    "def load_rows(path: str):\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        rows = json.load(f)\n",
    "    clean = []\n",
    "    for r in rows:\n",
    "        t = (r.get(\"type\") or r.get(\"label\") or \"\").upper().strip()\n",
    "        if t in MBTI2ID:\n",
    "            r[\"type\"] = t\n",
    "            clean.append(r)\n",
    "    return clean\n",
    "\n",
    "class MBTIDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, rows, tokenizer, max_len=512):\n",
    "        self.rows = rows\n",
    "        self.tok  = tokenizer\n",
    "        self.max_len = max_len\n",
    "    def __len__(self): return len(self.rows)\n",
    "    def __getitem__(self, idx):\n",
    "        it   = self.rows[idx]\n",
    "        text = build_input(it, self.tok)\n",
    "        y    = MBTI2ID[it[\"type\"]]\n",
    "        enc  = self.tok(text, truncation=True, max_length=self.max_len)\n",
    "        return {\"input_ids\": enc[\"input_ids\"], \"attention_mask\": enc[\"attention_mask\"], \"labels\": y}\n",
    "\n",
    "# ===== 混淆矩阵可视化 =====\n",
    "def plot_confusion(y_true, y_pred, class_names, out_dir, tag=\"test\"):\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))\n",
    "    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)\n",
    "    fig, ax = plt.subplots(figsize=(8,8), dpi=150)\n",
    "    disp.plot(ax=ax, xticks_rotation=45, cmap=\"Blues\", colorbar=False)\n",
    "    ax.set_title(f\"Confusion Matrix ({tag})\")\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(out_dir, f\"{tag}_confusion_matrix.png\"))\n",
    "    plt.close(fig)\n",
    "\n",
    "# ===== 主流程 =====\n",
    "def main():\n",
    "    # Tokenizer\n",
    "    tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "    if tok.pad_token is None:\n",
    "        tok.pad_token = tok.eos_token\n",
    "    tok.padding_side = \"right\"\n",
    "\n",
    "    # 加载基座 + LoRA\n",
    "    base_cfg = AutoConfig.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "    base_cfg.num_labels = len(MBTI_16)\n",
    "    base = AutoModelForSequenceClassification.from_pretrained(\n",
    "        BASE_MODEL, config=base_cfg, device_map=\"auto\", trust_remote_code=True\n",
    "    )\n",
    "    model = PeftModel.from_pretrained(base, ADAPTER_DIR)\n",
    "    model = model.merge_and_unload()\n",
    "    model.config.pad_token_id = tok.pad_token_id  # 关键修复\n",
    "    model.eval()\n",
    "\n",
    "    # 数据\n",
    "    test_rows = load_rows(TEST_PATH)\n",
    "    test_ds   = MBTIDataset(test_rows, tok, max_len=MAX_LEN)\n",
    "\n",
    "    # 推理\n",
    "    all_logits, all_labels = [], []\n",
    "    for batch in torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,\n",
    "                                             collate_fn=DataCollatorWithPadding(tok, pad_to_multiple_of=8)):\n",
    "        batch = {k:v.to(model.device) for k,v in batch.items()}\n",
    "        with torch.no_grad():\n",
    "            out = model(**batch).logits\n",
    "        all_logits.append(out.cpu())\n",
    "        all_labels.append(batch[\"labels\"].cpu())\n",
    "    logits = torch.cat(all_logits).numpy()\n",
    "    y_true = torch.cat(all_labels).numpy()\n",
    "    y_pred = np.argmax(logits, axis=-1)\n",
    "\n",
    "    # 结果\n",
    "    acc = accuracy_score(y_true, y_pred)\n",
    "    print(f\"=== Test Accuracy on {TEST_PATH}: {acc:.4f}\")\n",
    "    plot_confusion(y_true, y_pred, MBTI_16, OUTPUT_DIR, tag=\"new_test\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6ecaa71b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ 已写出：casebank_A_train_80_with_ids.json  和  casebank_text2id.json（文本->id 映射）\n"
     ]
    }
   ],
   "source": [
    "# add_case_ids.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, os, hashlib\n",
    "\n",
    "CASEBANK_IN  = \"casebank_A_train_80_with_embeddings.json\"   # 你的 casebank embedding 文件\n",
    "CASEBANK_OUT = \"casebank_A_train_80_with_ids.json\"          # 输出：补了 case_id 的文件\n",
    "\n",
    "def canon_key(text: str) -> str:\n",
    "    \"\"\"把文本规整成稳定键（去空白/大小写）再哈希，避免微小差异导致匹配失败。\"\"\"\n",
    "    t = (text or \"\").strip().lower()\n",
    "    t = \" \".join(t.split())          # 折叠多空格/换行\n",
    "    return hashlib.md5(t.encode(\"utf-8\")).hexdigest()\n",
    "\n",
    "with open(CASEBANK_IN, \"r\", encoding=\"utf-8\") as f:\n",
    "    bank = json.load(f)\n",
    "\n",
    "for i, item in enumerate(bank):\n",
    "    item[\"case_id\"] = i  # 用顺序 index 当稳定 id\n",
    "\n",
    "# 额外生成：文本->id 的查找表，方便下一步匹配\n",
    "id_lookup = {}\n",
    "for it in bank:\n",
    "    key = canon_key(it.get(\"post_casebank\") or it.get(\"embed_text\") or it.get(\"posts_cleaned\") or it.get(\"posts\") or \"\")\n",
    "    if key:\n",
    "        id_lookup[key] = it[\"case_id\"]\n",
    "\n",
    "with open(CASEBANK_OUT, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(bank, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "with open(\"casebank_text2id.json\", \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(id_lookup, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"✅ 已写出：{CASEBANK_OUT}  和  casebank_text2id.json（文本->id 映射）\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cd5a9848",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CaseBank] 条数=27224；新增case_id=27224；可索引文本=27223；空文本=1\n",
      "[Index] 唯一前缀键数=27120\n",
      "[覆盖率] matched=10212, ambig=0, miss=0, total=10212 → 100.00%\n",
      "[匹配路径统计]（用于诊断）\n",
      "                ok@start: 10082\n",
      "                  ok@120: 130\n",
      "✅ 输出：case_usage_counts_prefix.csv / test_to_cases_with_ids_prefix.jsonl\n"
     ]
    }
   ],
   "source": [
    "# count_case_usage_prefix_exact.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, csv, re\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "# ===== 路径（改成你的）=====\n",
    "CASEBANK_FILE = \"casebank_A_train_80_with_embeddings.json\"  # 含 posts_cleaned 的 casebank\n",
    "TOPK_FILE     = \"A_test_top3.json\"                          # topk=3 结果文件\n",
    "OUT_COUNTS    = \"case_usage_counts_prefix.csv\"\n",
    "OUT_PAIRS     = \"test_to_cases_with_ids_prefix.jsonl\"\n",
    "\n",
    "# ===== 字段（按你的数据）=====\n",
    "BANK_TEXT   = \"posts_cleaned\"     # casebank 的原文字段\n",
    "TOPK_LIST   = \"topk_cases\"        # topk 列表字段\n",
    "TOPK_TEXT   = \"post_casebank\"     # topk 每条候选里的原文字段\n",
    "K           = 3                   # top-k\n",
    "\n",
    "# ===== 匹配策略参数 =====\n",
    "PREFIX_START = 80     # 诊断已证实：80 能全覆盖\n",
    "PREFIX_STEP  = 40     # 不唯一时，每次增加的前缀长度\n",
    "MAX_PREFIX   = 2000   # 前缀最长检查到多少字符（防止极端长文本）\n",
    "CHOOSE_FIRST_IF_AMBIG = False  # True=在仍不唯一时取第一个；False=跳过以保证严格\n",
    "\n",
    "# --- 文本归一化：只折叠空白，不改大小写 ---\n",
    "WS = re.compile(r\"\\s+\")\n",
    "def norm_space(s: str) -> str:\n",
    "    return WS.sub(\" \", (s or \"\").strip())\n",
    "\n",
    "# ---------- 读取 casebank，确保有 case_id ----------\n",
    "with open(CASEBANK_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "    bank = json.load(f)\n",
    "\n",
    "added_ids = 0\n",
    "for i, it in enumerate(bank):\n",
    "    if \"case_id\" not in it:\n",
    "        it[\"case_id\"] = i\n",
    "        added_ids += 1\n",
    "\n",
    "# 预计算：规范化全文 & 建立“前缀->候选id列表”的索引（只建起始前缀，后续按需收缩候选）\n",
    "norm_text_by_id = {}\n",
    "prefix_index = defaultdict(list)  # key = norm_text[:PREFIX_START] -> [case_id,...]\n",
    "\n",
    "empty_cnt = 0\n",
    "for it in bank:\n",
    "    t = it.get(BANK_TEXT, \"\")\n",
    "    if not isinstance(t, str) or not t:\n",
    "        empty_cnt += 1\n",
    "        continue\n",
    "    nt = norm_space(t)\n",
    "    norm_text_by_id[it[\"case_id\"]] = nt\n",
    "    key0 = nt[:PREFIX_START]\n",
    "    prefix_index[key0].append(it[\"case_id\"])\n",
    "\n",
    "print(f\"[CaseBank] 条数={len(bank)}；新增case_id={added_ids}；可索引文本={len(norm_text_by_id)}；空文本={empty_cnt}\")\n",
    "print(f\"[Index] 唯一前缀键数={len(prefix_index)}\")\n",
    "\n",
    "# ---------- 读取 topk ----------\n",
    "with open(TOPK_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# ---------- 匹配函数：空白折叠 + 递增前缀直到唯一 ----------\n",
    "def resolve_case_id(raw_text: str):\n",
    "    if not isinstance(raw_text, str) or not raw_text:\n",
    "        return None, \"empty\"\n",
    "    nt = norm_space(raw_text)\n",
    "    pref_len = min(PREFIX_START, len(nt), MAX_PREFIX)\n",
    "    # 初始候选：用起始前缀在索引里查\n",
    "    cand = prefix_index.get(nt[:pref_len], [])\n",
    "    if not cand:\n",
    "        return None, \"no_prefix_key\"\n",
    "\n",
    "    # 若只有一个候选，直接返回\n",
    "    if len(cand) == 1:\n",
    "        return cand[0], \"ok@start\"\n",
    "\n",
    "    # 尝试递增前缀长度筛掉不匹配项\n",
    "    while pref_len < min(len(nt), MAX_PREFIX):\n",
    "        pref_len = min(pref_len + PREFIX_STEP, len(nt), MAX_PREFIX)\n",
    "        pref = nt[:pref_len]\n",
    "        cand = [cid for cid in cand if norm_text_by_id[cid].startswith(pref)]\n",
    "        if len(cand) <= 1:\n",
    "            break\n",
    "\n",
    "    if len(cand) == 1:\n",
    "        return cand[0], f\"ok@{pref_len}\"\n",
    "    else:\n",
    "        # 仍不唯一\n",
    "        if CHOOSE_FIRST_IF_AMBIG and len(cand) > 0:\n",
    "            return cand[0], f\"ambig_choose_first@{pref_len}\"\n",
    "        return None, f\"ambig_skip@{pref_len}\"\n",
    "\n",
    "# ---------- 统计 ----------\n",
    "M = len(data)\n",
    "total_slots = 0\n",
    "matched = 0\n",
    "ambig = 0\n",
    "miss  = 0\n",
    "counter = Counter()\n",
    "pairs  = []\n",
    "\n",
    "# 统计各匹配路径（便于排查）\n",
    "route_counter = Counter()\n",
    "\n",
    "for i, rec in enumerate(data):\n",
    "    ids = []\n",
    "    for c in rec.get(TOPK_LIST, [])[:K]:\n",
    "        total_slots += 1\n",
    "        cid, route = resolve_case_id(c.get(TOPK_TEXT, \"\"))\n",
    "        route_counter[route] += 1\n",
    "        if cid is None:\n",
    "            ids.append(None)\n",
    "            if route.startswith(\"ambig\"): ambig += 1\n",
    "            else: miss += 1\n",
    "        else:\n",
    "            ids.append(int(cid))\n",
    "            counter[cid] += 1\n",
    "            matched += 1\n",
    "    pairs.append({\"test_index\": i, \"topk_case_ids\": ids})\n",
    "\n",
    "cov = matched / total_slots if total_slots else 0.0\n",
    "print(f\"[覆盖率] matched={matched}, ambig={ambig}, miss={miss}, total={total_slots} → {cov:.2%}\")\n",
    "print(\"[匹配路径统计]（用于诊断）\")\n",
    "for k, v in route_counter.most_common():\n",
    "    print(f\"  {k:>22}: {v}\")\n",
    "\n",
    "# ---------- 输出 ----------\n",
    "with open(OUT_COUNTS, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
    "    w = csv.writer(f)\n",
    "    w.writerow([\"case_id\", \"hits\", \"hit_rate_per_query\", \"hit_rate_over_slots\"])\n",
    "    for cid, hits in sorted(counter.items(), key=lambda x: x[1], reverse=True):\n",
    "        w.writerow([cid, hits, hits / M if M else 0.0, hits / total_slots if total_slots else 0.0])\n",
    "\n",
    "with open(OUT_PAIRS, \"w\", encoding=\"utf-8\") as f:\n",
    "    for row in pairs:\n",
    "        f.write(json.dumps(row, ensure_ascii=False) + \"\\n\")\n",
    "\n",
    "print(f\"✅ 输出：{OUT_COUNTS} / {OUT_PAIRS}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f91cced",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ 已写出：case_usage_all.csv\n",
      "样本数 M=3404, 总槽位=10212\n",
      "\n",
      "Top cases by hits:\n",
      " 1. cid=5851   hits=13    uq=13    top1=4     pos[0/1/2]=(4/5/4)\n",
      " 2. cid=6277   hits=11    uq=11    top1=7     pos[0/1/2]=(7/1/3)\n",
      " 3. cid=10021  hits=10    uq=10    top1=6     pos[0/1/2]=(6/4/0)\n",
      " 4. cid=18467  hits=9     uq=9     top1=3     pos[0/1/2]=(3/2/4)\n",
      " 5. cid=6704   hits=9     uq=9     top1=2     pos[0/1/2]=(2/1/6)\n",
      " 6. cid=1866   hits=9     uq=9     top1=2     pos[0/1/2]=(2/6/1)\n",
      " 7. cid=20808  hits=9     uq=9     top1=3     pos[0/1/2]=(3/2/4)\n",
      " 8. cid=3792   hits=9     uq=9     top1=3     pos[0/1/2]=(3/3/3)\n",
      " 9. cid=10748  hits=8     uq=8     top1=7     pos[0/1/2]=(7/0/1)\n",
      "10. cid=18716  hits=8     uq=8     top1=4     pos[0/1/2]=(4/2/2)\n"
     ]
    }
   ],
   "source": [
    "# count_all_case_usage.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, csv\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "# ======= 配置 =======\n",
    "PAIRS_FILE    = \"test_to_cases_with_ids_prefix.jsonl\"   # 或 test_to_cases_with_ids_exact.jsonl\n",
    "CASEBANK_FILE = \"casebank_A_train_80_with_embeddings.json\"  # 可选：若提供会输出 type/文本片段\n",
    "TEXT_FIELD    = \"posts_cleaned\"     # 用于展示的文本字段（仅在提供 CASEBANK_FILE 时生效）\n",
    "SHOW_TOPN     = 10                  # 终端打印前N名\n",
    "OUT_CSV       = \"case_usage_all.csv\"\n",
    "\n",
    "# ======= 读取 pairs（每行：{\"test_index\": i, \"topk_case_ids\": [cid0,cid1,cid2]}）=======\n",
    "pairs = [json.loads(l) for l in open(PAIRS_FILE, \"r\", encoding=\"utf-8\")]\n",
    "M = len(pairs)\n",
    "lists = [r[\"topk_case_ids\"] for r in pairs]\n",
    "total_slots = sum(len(x) for x in lists)\n",
    "\n",
    "# 扁平化命中（忽略 None）\n",
    "flat_ids = [cid for lst in lists for cid in lst if cid is not None]\n",
    "hits_counter = Counter(flat_ids)\n",
    "\n",
    "# 每个 case 被多少“不同样本”命中（去重）\n",
    "unique_queries_counter = Counter()\n",
    "for lst in lists:\n",
    "    uniq = {cid for cid in lst if cid is not None}\n",
    "    unique_queries_counter.update(uniq)\n",
    "\n",
    "# Top1 命中次数 & 各位置分布\n",
    "top1_counter = Counter()\n",
    "pos_counters = defaultdict(Counter)  # pos_counters[cid][pos] 计数\n",
    "for lst in lists:\n",
    "    if not lst: continue\n",
    "    for pos, cid in enumerate(lst):\n",
    "        if cid is None: continue\n",
    "        pos_counters[cid][pos] += 1\n",
    "    cid0 = lst[0]\n",
    "    if cid0 is not None:\n",
    "        top1_counter[cid0] += 1\n",
    "\n",
    "# ======= （可选）读取 casebank 做展示增强 =======\n",
    "cid2type, cid2text = {}, {}\n",
    "if CASEBANK_FILE:\n",
    "    with open(CASEBANK_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "        bank = json.load(f)\n",
    "    for i, it in enumerate(bank):\n",
    "        cid = it.get(\"case_id\", i)           # 若没写case_id，用顺序索引兜底\n",
    "        cid2type[cid] = it.get(\"type\", \"\")\n",
    "        txt = it.get(TEXT_FIELD, \"\") or it.get(\"posts\", \"\") or \"\"\n",
    "        cid2text[cid] = (txt[:80].replace(\"\\n\", \" \") if isinstance(txt, str) else \"\")\n",
    "\n",
    "# ======= 写出总表（按 hits 降序）=======\n",
    "with open(OUT_CSV, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
    "    w = csv.writer(f)\n",
    "    header = [\"case_id\",\"hits\",\"unique_queries\",\"top1_hits\",\"pos0\",\"pos1\",\"pos2\",\n",
    "              \"rate_over_slots\",\"rate_per_query\",\"top1_rate\"]\n",
    "    if CASEBANK_FILE:\n",
    "        header += [\"type\",\"text_head\"]\n",
    "    w.writerow(header)\n",
    "\n",
    "    # 所有出现过的 id（也可以选择把未出现过的 id 一并写出，命中为0）\n",
    "    all_ids = sorted(hits_counter.keys(), key=lambda x: hits_counter[x], reverse=True)\n",
    "\n",
    "    for cid in all_ids:\n",
    "        hits = hits_counter[cid]\n",
    "        uq   = unique_queries_counter.get(cid, 0)\n",
    "        t1   = top1_counter.get(cid, 0)\n",
    "        p0   = pos_counters[cid].get(0, 0)\n",
    "        p1   = pos_counters[cid].get(1, 0)\n",
    "        p2   = pos_counters[cid].get(2, 0)\n",
    "\n",
    "        rate_over_slots = hits / total_slots if total_slots else 0.0\n",
    "        rate_per_query  = uq   / M if M else 0.0\n",
    "        top1_rate       = t1   / M if M else 0.0\n",
    "\n",
    "        row = [cid, hits, uq, t1, p0, p1, p2, rate_over_slots, rate_per_query, top1_rate]\n",
    "        if CASEBANK_FILE:\n",
    "            row += [cid2type.get(cid, \"\"), cid2text.get(cid, \"\")]\n",
    "        w.writerow(row)\n",
    "\n",
    "print(f\"✅ 已写出：{OUT_CSV}\")\n",
    "print(f\"样本数 M={M}, 总槽位={total_slots}\")\n",
    "\n",
    "# ======= 终端快速查看 TopN =======\n",
    "print(\"\\nTop cases by hits:\")\n",
    "for i, cid in enumerate(sorted(hits_counter.keys(), key=lambda x: hits_counter[x], reverse=True)[:SHOW_TOPN], 1):\n",
    "    hits = hits_counter[cid]\n",
    "    uq   = unique_queries_counter.get(cid, 0)\n",
    "    t1   = top1_counter.get(cid, 0)\n",
    "    p0   = pos_counters[cid].get(0, 0)\n",
    "    p1   = pos_counters[cid].get(1, 0)\n",
    "    p2   = pos_counters[cid].get(2, 0)\n",
    "    print(f\"{i:>2}. cid={cid:<6} hits={hits:<5} uq={uq:<5} top1={t1:<5} pos[0/1/2]=({p0}/{p1}/{p2})\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4f123a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[train] 总=27224，键数=27120，重复文=103，空文本=1\n",
      "[映射topk] 覆盖率=100.00%  (matched=10212, total=10212)\n",
      "→ 写出 pairs_case_ids.jsonl\n",
      "→ 写出 case_usage_all.csv\n",
      "→ 合并完成：27223 条（缺失向量 1 条），写出 case_meta_with_hits.json\n",
      "✅ 聚类图已保存：cluster_highlight.png\n",
      "✅ 已保存带标注版本：cluster_highlight_annotated.png\n"
     ]
    }
   ],
   "source": [
    "# cluster_and_highlight_from_train_topk.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, csv, re, os\n",
    "from collections import Counter, defaultdict\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# ========= 路径（按需修改）=========\n",
    "TRAIN_FILE   = \"train.json\"                                   # 权威原文库\n",
    "TOPK_FILE    = \"A_test_top3.json\"                             # top-k 结果\n",
    "EMB_FILE     = \"casebank_A_train_80_with_embeddings.json\"     # 含 embedding（已算好的）\n",
    "OUT_PAIRS    = \"pairs_case_ids.jsonl\"                         # 每个样本命中的 case_id 列表\n",
    "OUT_COUNTS   = \"case_usage_all.csv\"                           # 每个 id 的使用统计\n",
    "OUT_MERGED   = \"case_meta_with_hits.json\"                     # id + type + text + hits + embedding\n",
    "OUT_PLOT     = \"cluster_highlight.png\"                        # 聚类散点图\n",
    "\n",
    "# ========= 字段名 =========\n",
    "TRAIN_TEXT_FIELD = \"posts_cleaned\"    # train.json 用于对齐的文本字段\n",
    "TOPK_LIST_FIELD  = \"topk_cases\"       # topk 列表字段\n",
    "TOPK_TEXT_FIELD  = \"post_casebank\"    # topk 候选里的文本字段\n",
    "TOPK_K           = 3\n",
    "\n",
    "# ========= 匹配模式（与你的诊断一致）=========\n",
    "# 'raw' 精确；'space' 仅折叠空白；'prefix80_space' 折叠空白后取前80字精确\n",
    "MATCH_MODE = \"prefix80_space\"\n",
    "\n",
    "_WS = re.compile(r\"\\s+\")\n",
    "def key_raw(s: str) -> str: return s\n",
    "def key_space(s: str) -> str: return _WS.sub(\" \", (s or \"\").strip())\n",
    "def key_prefix80_space(s: str) -> str: return _WS.sub(\" \", (s or \"\").strip())[:80]\n",
    "KEY = {\"raw\": key_raw, \"space\": key_space, \"prefix80_space\": key_prefix80_space}[MATCH_MODE]\n",
    "\n",
    "# ========= Step 1. 用 train 建立 文本->case_id 索引 =========\n",
    "with open(TRAIN_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "    train = json.load(f)\n",
    "N = len(train)\n",
    "\n",
    "text2id = {}\n",
    "id2type, id2text = {}, {}\n",
    "dup_in_train = empty_in_train = 0\n",
    "for i, it in enumerate(train):\n",
    "    t = it.get(TRAIN_TEXT_FIELD, \"\")\n",
    "    if not isinstance(t, str) or not t:\n",
    "        empty_in_train += 1\n",
    "        continue\n",
    "    k = KEY(t)\n",
    "    if k not in text2id:\n",
    "        text2id[k] = i\n",
    "    else:\n",
    "        dup_in_train += 1\n",
    "    id2type[i] = it.get(\"type\", \"\")\n",
    "    id2text[i] = t\n",
    "\n",
    "print(f\"[train] 总={N}，键数={len(text2id)}，重复文={dup_in_train}，空文本={empty_in_train}\")\n",
    "\n",
    "# ========= Step 2. 映射 topk -> case_id =========\n",
    "with open(TOPK_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "    topk_data = json.load(f)\n",
    "\n",
    "pairs = []\n",
    "M = len(topk_data)\n",
    "total_slots = matched = miss = 0\n",
    "for i, rec in enumerate(topk_data):\n",
    "    ids = []\n",
    "    for c in rec.get(TOPK_LIST_FIELD, [])[:TOPK_K]:\n",
    "        total_slots += 1\n",
    "        t = c.get(TOPK_TEXT_FIELD, \"\")\n",
    "        cid = text2id.get(KEY(t)) if isinstance(t, str) and t else None\n",
    "        if cid is None:\n",
    "            miss += 1; ids.append(None)\n",
    "        else:\n",
    "            matched += 1; ids.append(int(cid))\n",
    "    pairs.append({\"test_index\": i, \"topk_case_ids\": ids})\n",
    "\n",
    "with open(OUT_PAIRS, \"w\", encoding=\"utf-8\") as f:\n",
    "    for row in pairs:\n",
    "        f.write(json.dumps(row, ensure_ascii=False) + \"\\n\")\n",
    "print(f\"[映射topk] 覆盖率={matched/total_slots:.2%}  (matched={matched}, total={total_slots})\")\n",
    "print(f\"→ 写出 {OUT_PAIRS}\")\n",
    "\n",
    "# ========= Step 3. 对所有 id 统计使用次数/命中率 =========\n",
    "lists = [r[\"topk_case_ids\"] for r in pairs]\n",
    "flat_ids = [cid for lst in lists for cid in lst if cid is not None]\n",
    "\n",
    "hits_counter = Counter(flat_ids)      # 槽位计数\n",
    "unique_queries_counter = Counter()    # 被多少个不同样本命中\n",
    "top1_counter = Counter()\n",
    "pos_counters = defaultdict(Counter)\n",
    "\n",
    "for lst in lists:\n",
    "    uniq = {cid for cid in lst if cid is not None}\n",
    "    unique_queries_counter.update(uniq)\n",
    "    if lst and lst[0] is not None:\n",
    "        top1_counter[lst[0]] += 1\n",
    "    for pos, cid in enumerate(lst):\n",
    "        if cid is not None:\n",
    "            pos_counters[cid][pos] += 1\n",
    "\n",
    "total_slots = sum(len(x) for x in lists)\n",
    "\n",
    "with open(OUT_COUNTS, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
    "    w = csv.writer(f)\n",
    "    w.writerow([\"case_id\",\"hits\",\"unique_queries\",\"top1_hits\",\"pos0\",\"pos1\",\"pos2\",\n",
    "                \"rate_over_slots\",\"rate_per_query\",\"top1_rate\",\"type\",\"text_head\"])\n",
    "    for cid in range(N):\n",
    "        h  = hits_counter.get(cid, 0)\n",
    "        uq = unique_queries_counter.get(cid, 0)\n",
    "        t1 = top1_counter.get(cid, 0)\n",
    "        p0 = pos_counters[cid].get(0, 0)\n",
    "        p1 = pos_counters[cid].get(1, 0)\n",
    "        p2 = pos_counters[cid].get(2, 0)\n",
    "        w.writerow([\n",
    "            cid, h, uq, t1, p0, p1, p2,\n",
    "            h/total_slots if total_slots else 0.0,\n",
    "            uq/M if M else 0.0,\n",
    "            t1/M if M else 0.0,\n",
    "            id2type.get(cid, \"\"),\n",
    "            (id2text.get(cid, \"\")[:80].replace(\"\\n\",\" \"))\n",
    "        ])\n",
    "print(f\"→ 写出 {OUT_COUNTS}\")\n",
    "\n",
    "# ========= Step 4. 合并 embedding，准备做聚类图 =========\n",
    "# 说明：聚类/降维需要向量，这里直接复用你已有的 embedding 文件\n",
    "with open(EMB_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "    emb_rows = json.load(f)\n",
    "\n",
    "# 建立 key -> embedding 映射（与上面同一 KEY 规则）\n",
    "key2emb = {}\n",
    "for it in emb_rows:\n",
    "    t = it.get(TRAIN_TEXT_FIELD, \"\") or it.get(\"posts\", \"\")\n",
    "    if isinstance(t, str) and t and \"embedding\" in it:\n",
    "        k = KEY(t)\n",
    "        key2emb[k] = it[\"embedding\"]\n",
    "\n",
    "# 组装：cid、type、text、hits、embedding\n",
    "merged = []\n",
    "miss_emb = 0\n",
    "for cid in range(N):\n",
    "    t = id2text.get(cid, \"\")\n",
    "    k = KEY(t)\n",
    "    emb = key2emb.get(k)\n",
    "    if emb is None:\n",
    "        miss_emb += 1\n",
    "        continue\n",
    "    merged.append({\n",
    "        \"case_id\": cid,\n",
    "        \"type\": id2type.get(cid, \"\"),\n",
    "        \"text\": t,\n",
    "        \"hits\": hits_counter.get(cid, 0),\n",
    "        \"embedding\": emb\n",
    "    })\n",
    "\n",
    "with open(OUT_MERGED, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(merged, f, ensure_ascii=False, indent=2)\n",
    "print(f\"→ 合并完成：{len(merged)} 条（缺失向量 {miss_emb} 条），写出 {OUT_MERGED}\")\n",
    "\n",
    "# ========= Step 5. 降维 + 绘图（高使用高亮）=========\n",
    "# 准备矩阵\n",
    "X = np.array([m[\"embedding\"] for m in merged], dtype=np.float32)\n",
    "hits = np.array([m[\"hits\"] for m in merged], dtype=np.int32)\n",
    "types = [m[\"type\"] for m in merged]\n",
    "cids  = [m[\"case_id\"] for m in merged]\n",
    "\n",
    "# 先 PCA->50，再尝试 UMAP；没有 UMAP 就退化到 PCA2 或 TSNE(可能慢)\n",
    "Z_2d = None\n",
    "try:\n",
    "    import umap\n",
    "    X50 = PCA(n_components=50, random_state=0).fit_transform(X)\n",
    "    reducer = umap.UMAP(n_neighbors=30, min_dist=0.1, metric=\"cosine\", random_state=0)\n",
    "    Z_2d = reducer.fit_transform(X50)\n",
    "except Exception:\n",
    "    try:\n",
    "        Z_2d = PCA(n_components=2, random_state=0).fit_transform(X)\n",
    "    except Exception:\n",
    "        from sklearn.manifold import TSNE\n",
    "        Z_2d = TSNE(n_components=2, random_state=0, perplexity=30, init=\"pca\").fit_transform(X)\n",
    "\n",
    "# 高亮规则：命中 TopN 或 hits >= 分位数阈值\n",
    "TOPN_HIGHLIGHT = 300\n",
    "Q_PERCENTILE   = 99  # 或设为 None 只用 TopN\n",
    "thr = np.percentile(hits, Q_PERCENTILE) if Q_PERCENTILE is not None else None\n",
    "order = np.argsort(-hits)\n",
    "mask_high = np.zeros_like(hits, dtype=bool)\n",
    "mask_high[order[:TOPN_HIGHLIGHT]] = True\n",
    "if thr is not None:\n",
    "    mask_high |= (hits >= thr)\n",
    "\n",
    "# 绘图\n",
    "plt.figure(figsize=(9, 7), dpi=150)\n",
    "# 背景（低使用）\n",
    "plt.scatter(Z_2d[~mask_high, 0], Z_2d[~mask_high, 1],\n",
    "            s=5, alpha=0.25, linewidths=0, label=\"others\")\n",
    "# 高使用\n",
    "plt.scatter(Z_2d[mask_high, 0], Z_2d[mask_high, 1],\n",
    "            s=18, alpha=0.9, linewidths=0.5, edgecolors=\"k\", label=\"high-usage\")\n",
    "\n",
    "plt.title(\"Casebank clustering (high-usage highlighted)\")\n",
    "plt.legend(loc=\"best\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(OUT_PLOT)\n",
    "print(f\"✅ 聚类图已保存：{OUT_PLOT}\")\n",
    "\n",
    "#（可选）标注前几十个典型点\n",
    "ANNOTATE_TOPK = 40\n",
    "for idx in order[:ANNOTATE_TOPK]:\n",
    "    x, y = Z_2d[idx]\n",
    "    lbl = f\"{cids[idx]} | {types[idx]}\"\n",
    "    plt.text(x, y, lbl, fontsize=6)\n",
    "plt.tight_layout()\n",
    "plt.savefig(OUT_PLOT.replace(\".png\",\"_annotated.png\"))\n",
    "print(f\"✅ 已保存带标注版本：{OUT_PLOT.replace('.png','_annotated.png')}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "66a3c774",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top cases by hits:\n",
      " 1. cid=5851  hits=13  uq=13  top1=4  pos[0/1/2]=(4/5/4)\n",
      " 2. cid=6277  hits=11  uq=11  top1=7  pos[0/1/2]=(7/1/3)\n",
      " 3. cid=10021  hits=10  uq=10  top1=6  pos[0/1/2]=(6/4/0)\n",
      " 4. cid=698  hits=9  uq=8  top1=5  pos[0/1/2]=(5/3/1)\n",
      " 5. cid=3792  hits=9  uq=9  top1=3  pos[0/1/2]=(3/3/3)\n",
      " 6. cid=18467  hits=9  uq=9  top1=3  pos[0/1/2]=(3/2/4)\n",
      " 7. cid=20808  hits=9  uq=9  top1=3  pos[0/1/2]=(3/2/4)\n",
      " 8. cid=1866  hits=9  uq=9  top1=2  pos[0/1/2]=(2/6/1)\n",
      " 9. cid=6704  hits=9  uq=9  top1=2  pos[0/1/2]=(2/1/6)\n",
      "10. cid=10748  hits=8  uq=8  top1=7  pos[0/1/2]=(7/0/1)\n",
      "✅ 已保存：topN_cases_with_post.txt\n",
      "✅ 已保存：all_cases_with_post.jsonl\n"
     ]
    }
   ],
   "source": [
    "# print_topN_with_post.py\n",
    "# -*- coding: utf-8 -*-\n",
    "import json, csv\n",
    "\n",
    "COUNTS_FILE = \"case_usage_all.csv\"   # 上一步输出\n",
    "TRAIN_FILE  = \"train.json\"           # 权威原文库\n",
    "TRAIN_TEXT_FIELD = \"posts_cleaned\"   # 如需换字段，改这里\n",
    "\n",
    "TOPN = 10                            # 要查看的前N\n",
    "EXPORT_ALL_JSONL = True              # 是否导出所有id到 JSONL（含完整post）\n",
    "ALL_JSONL_FILE   = \"all_cases_with_post.jsonl\"\n",
    "TOPN_TXT_FILE    = \"topN_cases_with_post.txt\"\n",
    "\n",
    "# 1) 读 train.json，建立 id -> post / type\n",
    "with open(TRAIN_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "    train = json.load(f)\n",
    "id2post = {}\n",
    "id2type = {}\n",
    "for i, it in enumerate(train):\n",
    "    id2post[i] = it.get(TRAIN_TEXT_FIELD, \"\") or it.get(\"posts\", \"\") or \"\"\n",
    "    id2type[i] = it.get(\"type\", \"\")\n",
    "\n",
    "# 2) 读统计表\n",
    "rows = []\n",
    "with open(COUNTS_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "    r = csv.DictReader(f)\n",
    "    for row in r:\n",
    "        row[\"case_id\"] = int(row[\"case_id\"])\n",
    "        row[\"hits\"] = int(row[\"hits\"])\n",
    "        row[\"unique_queries\"] = int(row[\"unique_queries\"])\n",
    "        row[\"top1_hits\"] = int(row[\"top1_hits\"])\n",
    "        row[\"pos0\"] = int(row[\"pos0\"])\n",
    "        row[\"pos1\"] = int(row[\"pos1\"])\n",
    "        row[\"pos2\"] = int(row[\"pos2\"])\n",
    "        rows.append(row)\n",
    "\n",
    "# 3) 排序：hits desc -> top1_hits desc -> unique_queries desc\n",
    "rows.sort(key=lambda x: (x[\"hits\"], x[\"top1_hits\"], x[\"unique_queries\"]), reverse=True)\n",
    "\n",
    "# 4) 终端打印摘要 + 写入带post的TXT\n",
    "print(\"Top cases by hits:\")\n",
    "out_lines = [\"Top cases by hits:\"]\n",
    "for i, r in enumerate(rows[:TOPN], 1):\n",
    "    cid = r[\"case_id\"]\n",
    "    post = id2post.get(cid, \"\")\n",
    "    line = (f\"{i:>2}. cid={cid}  \"\n",
    "            f\"hits={r['hits']}  uq={r['unique_queries']}  top1={r['top1_hits']}  \"\n",
    "            f\"pos[0/1/2]=({r['pos0']}/{r['pos1']}/{r['pos2']})\")\n",
    "    print(line)\n",
    "    out_lines.append(line)\n",
    "    out_lines.append(f\"[type]={id2type.get(cid,'')}\")\n",
    "    out_lines.append(post)          # 完整 posts_cleaned\n",
    "    out_lines.append(\"\")            # 空行分隔\n",
    "\n",
    "with open(TOPN_TXT_FILE, \"w\", encoding=\"utf-8\") as f:\n",
    "    f.write(\"\\n\".join(out_lines))\n",
    "print(f\"✅ 已保存：{TOPN_TXT_FILE}\")\n",
    "\n",
    "# 5) （可选）导出所有id到JSONL，便于后续做可视化/检索\n",
    "if EXPORT_ALL_JSONL:\n",
    "    with open(ALL_JSONL_FILE, \"w\", encoding=\"utf-8\") as f:\n",
    "        for r in rows:\n",
    "            cid = r[\"case_id\"]\n",
    "            rec = {\n",
    "                \"case_id\": cid,\n",
    "                \"type\": id2type.get(cid, \"\"),\n",
    "                \"post\": id2post.get(cid, \"\"),\n",
    "                \"hits\": r[\"hits\"],\n",
    "                \"unique_queries\": r[\"unique_queries\"],\n",
    "                \"top1_hits\": r[\"top1_hits\"],\n",
    "                \"pos0\": r[\"pos0\"], \"pos1\": r[\"pos1\"], \"pos2\": r[\"pos2\"],\n",
    "                \"rate_over_slots\": float(r[\"rate_over_slots\"]),\n",
    "                \"rate_per_query\": float(r[\"rate_per_query\"]),\n",
    "                \"top1_rate\": float(r[\"top1_rate\"]),\n",
    "            }\n",
    "            f.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
    "    print(f\"✅ 已保存：{ALL_JSONL_FILE}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "06905da3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "===== ISTP | top4 by hits =====\n",
      "1. cid=18467 hits=9 uq=9 top1=3\n",
      "people say istps cold unemotional lazy thinking process feeling display fix problem instead talking want drama elsewhere give task challenge space alone recharge show made need explain every move thought observing adapting core emotion come get measured analyzed either used discarded fuss extra noise yeah sometimes get impatient social nicety small talk waste time serve purpose prefer straight answer action want understand watch handle pressure fix something broken real story end fitting mold proving something others living life make sense cutting bullshit\n",
      "\n",
      "2. cid=15075 hits=5 uq=5 top1=4\n",
      "cut noise something need fixing figure broken fix waste time whining explaining people talk much little chase approval drama keep circle small quality quantity say done mean game feeling fact action want advice ready work otherwise save breath life short nonsense overthink past future focus front tool task result trust judgment anyone else screw fix excuse apology unless necessary people think istps cold maybe real waste energy thing make sense add value emotion data dictator stay sharp keep moving solve problem\n",
      "\n",
      "3. cid=587 hits=5 uq=5 top1=3\n",
      "people overcomplicate istp behavior need explain every move feeling make sense drop emotion checklist data point process ignore depending situation expect dwell stuff endlessly solve adapt move relationship deal worth effort lean mess waste time fixing something hold communication direct silent guesswork needed yeah like thing practical good tool solid plan clear outcome mean cold efficient drama pointless conversation act observe adjust repeat\n",
      "\n",
      "4. cid=16409 hits=5 uq=5 top1=3\n",
      "people say istps detached unemotional maybe feel waste energy explaining feeling trying justify action speak louder matter handle quietly efficiently move overthinking get way want know stand watch say also yeah get frustrated fluff many word much drama cut noise get point babysit emotion decode cryptic hint something say straight fix relationship complicated simple chase commit real looking soap opera someone get roll punch compatibility matching feeling syncing action respect finally hobby keep grounded building tearing figuring stuff work pastime process world get probably scene\n",
      "\n",
      "\n",
      "===== ISFJ | top4 by hits =====\n",
      "1. cid=698 hits=9 uq=8 top1=5\n",
      "sometimes find quietly worrying people care even everything seems fine surface like gentle tug heart reminding check present offer listening ear always felt part role hold space others reliable need stability help carry burden little always easy balance feeling learned small act kindness even word encouragement mean world honestly think lot strength come quietly showing day day even hard believe love duty look like lately trying gentle reminding okay ask support isfjs tend caretaker human sometimes tired sometimes unsure learning taking moment rest selfish necessary keep people love anyone else found way balance caring deeply others also protecting peace love hear helped even smallest thing\n",
      "\n",
      "2. cid=20619 hits=5 uq=5 top1=3\n",
      "never thought starting dream journal never really held interest dream forget almost immediately feel worse also pretty weirded recently thinking starting journal try organize thought bit maybe spend much time thought sure expect always bit entirely sure question make sense feel best tool connecting people able easily sympathize relate people honesty man debate religion based around loaded question atheism always fun stuff try live value enough god live yeah understand fairly well bit problem well question thing recreation alone thing trouble expressing emotion manage anger resentment daily life also top favorite musician band aaron worse thinking paranoid knowing aaron safe abe aaron imagine way thing could considered anywhere plastic beach album enjoy front back really great stuff demon day also nice track city god would lovely selection would love watch movie rarely like discus much less personal feeling really nail specific reason dislike much except problem relying people asking gay sorry thing fault man equivalent man nut etc everything happens reason kind sweeping generalization going disagree stand harry potter deathly hallows opinion great example good book bad ending really sure know answer question obviously nothing compare suppose lot trait isfj might seen typically feminine know mystifying someone please explain mention like zanimus work crippling introversion well suited hanging cool crowd never however would like inquire supposed delight seeing isfj male let mistake think need tighter security door unsure dishonesty disloyalty usually maintain eye contact close friend well bordering somewhat weird stare try much feel make uncomfortable put idea marriage hand feel like would much like child honestly early really thinking much told witty sense humor around close friend usually silly serious however would agree general trend serious agree self flattery also quite astounding also use every diversionary tactic conceived back completely bogus claim thread completely seeking romantic relationship actually want friend connect personal level kind pressure people seek sexual relationship seems teenager comparing middle aged adult still feel like adolescent child rightfully time laughing favorite movie would probably either goodfellas reservoir dog though violent person tongue find battle royale hunger game good story absolutely horrible ending typing katniss would agree istp prediction people made harry potter ronald weasley winne pooh piglet dragon ball krillin cautious fast cautious follow speed limit exactly follow speed limit closely would absolutely hate get ticket likely cut people man forget jeremy classic example another popular example find tri type wing confusing see part realize title might misleading honestly think different one always found interesting song taboo subject pedophilia stalking incest spam world everyone proud woman cool congrats legally drink awful stuff throw happen drink much wooo really whole problem stimulus never really go away put friend something irritates might dwell bit forget yes aware even sure classify adele lana del rey usually cup tea wink got disappointed enjoyed first second third much currently reading second artemis fowl sure feel right really style especially loathe dubstep support anything take attention lady gaga opinion symbiotic power interesting give way plot development later story perfect example would venom spider man universe mmmmmmmm ekac yet really explode keeping many feeling becoming increasingly worried might understand 
mean knowing whether concern seem problem whenever faced confrontation hurt feeling bring try push away forget aware far healthiest way funny also find conversation imaginary version people know maybe somewhat common also completely obsessive idea song grab like add something post pretty much sum everything could said say time honesty really biggest component getting open yes view everyone real literal sense also large gap present real discussion relationship completely non threatening provide small example limited experience matter interest girl caught remembering small detail thing get satisfaction either held gun head guess pick masochism always imagined fuck impressive looking isfj enneagram maybe looked much kind caring punctual silent boring may reflect certain type find much creature habit really bothered plan changed even worse thing happen completely unexpectedly never physical confrontation kind pissed want fight somebody watch whole lot television walking dead show way plan watch however also enjoy met mother colbert report son hmm tell sure look like puzzle might multiple solution moment think get least prop\n",
      "\n",
      "3. cid=7459 hits=5 uq=5 top1=1\n",
      "hello everyone hope well today wanted share little something heart lately sometimes find caught wanting please everyone needing take care emotional well delicate balance realized important set gentle boundary care want able show fully people love without feeling drained lately trying give permission say need even feel uncomfortable first easy remind caring part responsibility count want steady dependable mean recognizing need recharge also slowly organizing space little sanctuary really comforting process putting thing order creating peaceful environment feel like small meaningful way care family thank sharing thought reassuring know alone struggle support one another kindness understanding wishing calm gentle day isfj\n",
      "\n",
      "4. cid=16106 hits=5 uq=5 top1=1\n",
      "hello friend lately reflecting much value quiet moment time pause really listen others heart feel important honor feeling even small shy sometimes find hesitant share inside worried might burden someone else learning holding everything help anyone least also noticed deeply rewarding remember little detail people care favorite song story told something made day better small way showing paying attention matter hope nurture connection gentle patience kindness honestly world feel kinder take care anyone else find comfort routine also crave little space breathe spontaneous trying find balance always easy believe responsible reliable mean also gentle open even thank part thoughtful warm community quiet strength cherish dearly wishing peace steady heart isfj soul\n",
      "\n",
      "\n",
      "✅ 已保存: top4_posts_ISTP_ISFJ.txt\n"
     ]
    }
   ],
   "source": [
    "# top4_posts_ISTP_ISFJ_from_jsonl.py\n",
    "import json\n",
    "\n",
    "JSONL = \"all_cases_with_post.jsonl\"\n",
    "TYPES = [\"ISTP\", \"ISFJ\"]\n",
    "TOPK  = 4\n",
    "\n",
    "# 读入\n",
    "items = []\n",
    "with open(JSONL, \"r\", encoding=\"utf-8\") as f:\n",
    "    for line in f:\n",
    "        items.append(json.loads(line))\n",
    "\n",
    "# 逐 type 过滤 + 排序 + 取前K\n",
    "out_lines = []\n",
    "for t in TYPES:\n",
    "    sub = [x for x in items if x.get(\"type\") == t]\n",
    "    sub.sort(key=lambda x: (x.get(\"hits\", 0), x.get(\"top1_hits\", 0), x.get(\"unique_queries\", 0)), reverse=True)\n",
    "    top = sub[:TOPK]\n",
    "    out_lines.append(f\"===== {t} | top{TOPK} by hits =====\")\n",
    "    for i, r in enumerate(top, 1):\n",
    "        out_lines.append(f\"{i}. cid={r['case_id']} hits={r['hits']} uq={r.get('unique_queries',0)} top1={r.get('top1_hits',0)}\")\n",
    "        out_lines.append(r.get(\"post\", \"\"))  # posts_cleaned\n",
    "        out_lines.append(\"\")\n",
    "    out_lines.append(\"\")\n",
    "\n",
    "# 打印 + 保存\n",
    "print(\"\\n\".join(out_lines))\n",
    "with open(\"top4_posts_ISTP_ISFJ.txt\", \"w\", encoding=\"utf-8\") as f:\n",
    "    f.write(\"\\n\".join(out_lines))\n",
    "print(\"✅ 已保存: top4_posts_ISTP_ISFJ.txt\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "467bbd89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[load] 读取 merged：27223 条 ← case_meta_with_hits.json\n",
      "[save] 图已保存：casebank_umap_highusage.pdf / casebank_umap_highusage.svg / casebank_umap_highusage.png\n",
      "[save] 带标注版本：casebank_umap_highusage_annotated.png  (fallback offsets)\n",
      "\n",
      "[FIGURE CAPTION SUGGESTION]\n",
      "2D visualization of casebank embeddings using PCA(50)→UMAP(2) (n_neighbors=30, min_dist=0.1, cosine, random_state=0). Orange points denote high-usage cases (Top-300 ∪ 99th percentile). Labels show Top-10 by hits with overlap avoidance.\n"
     ]
    }
   ],
   "source": [
    "# plot_casebank_highusage_top10.py\n",
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "功能：\n",
    "- 读取合并文件 case_meta_with_hits.json（若不存在则自动由 counts+embeddings+train 构建）\n",
    "- PCA(50)->UMAP(2)（没装 umap 就退化到 PCA2；再不行用 t-SNE）\n",
    "- 背景点淡化，高使用点高亮\n",
    "- 仅标注 Top10，并用 adjustText 自动避让（无该库则回退到偏移+箭头）\n",
    "- 导出 PDF/SVG/600dpi PNG\n",
    "\n",
    "可修改参数见【参数区】。\n",
    "\"\"\"\n",
    "import os, json, csv, re, random\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# ======================= 参数区 =======================\n",
    "MERGED_JSON = \"case_meta_with_hits.json\"                 # 优先读取（不存在则自动构建）\n",
    "COUNTS_CSV  = \"case_usage_all.csv\"                       # 构建所需：命中统计\n",
    "EMB_JSON    = \"casebank_A_train_80_with_embeddings.json\" # 构建所需：embedding 文件\n",
    "TRAIN_JSON  = \"train.json\"                               # 构建所需：原始文本\n",
    "TEXT_FIELD  = \"posts_cleaned\"                            # 文本字段名（train/embedding）\n",
    "\n",
    "TOPN_HIGHLIGHT = 300     # 高使用点高亮：命中次数Top-N（None仅按分位阈值）\n",
    "Q_PERCENTILE   = 99      # 或命中次数≥该分位也高亮（None关闭）\n",
    "LABEL_TOPN     = 10      # 只标注前10个\n",
    "LABEL_SCOPE    = \"global\"  # \"global\" 在全体里选Top10；\"mask\" 只在高使用里选Top10\n",
    "LABEL_FMT      = \"{cid}|{typ}\"  # 标注文本格式\n",
    "SEED           = 0\n",
    "\n",
    "OUT_PDF       = \"casebank_umap_highusage.pdf\"\n",
    "OUT_SVG       = \"casebank_umap_highusage.svg\"\n",
    "OUT_PNG       = \"casebank_umap_highusage.png\"\n",
    "OUT_PNG_ANNOT = \"casebank_umap_highusage_annotated.png\"\n",
    "# =====================================================\n",
    "\n",
    "# 工具函数\n",
    "WS = re.compile(r\"\\s+\")\n",
    "def key_prefix80_space(s: str) -> str:\n",
    "    return WS.sub(\" \", (s or \"\").strip())[:80]\n",
    "\n",
    "def set_seeds(seed=0):\n",
    "    random.seed(seed); np.random.seed(seed)\n",
    "\n",
    "try:\n",
    "    import umap  # noqa: F401\n",
    "    HAS_UMAP = True\n",
    "except Exception:\n",
    "    HAS_UMAP = False\n",
    "\n",
    "def build_merged_from_sources():\n",
    "    \"\"\"当 MERGED_JSON 不存在时，从 COUNTS_CSV + EMB_JSON + TRAIN_JSON 合并。\"\"\"\n",
    "    # 1) train：id->type/text/key\n",
    "    with open(TRAIN_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        train = json.load(f)\n",
    "    id2type, id2text, id2key = {}, {}, {}\n",
    "    for i, it in enumerate(train):\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\", \"\")\n",
    "        id2type[i] = it.get(\"type\", \"\")\n",
    "        id2text[i] = t\n",
    "        id2key[i]  = key_prefix80_space(t)\n",
    "\n",
    "    # 2) embeddings：key->emb\n",
    "    with open(EMB_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        emb_rows = json.load(f)\n",
    "    key2emb = {}\n",
    "    for it in emb_rows:\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\", \"\")\n",
    "        if isinstance(t, str) and t and \"embedding\" in it:\n",
    "            key2emb[key_prefix80_space(t)] = it[\"embedding\"]\n",
    "\n",
    "    # 3) 命中统计：cid->hits\n",
    "    hits = {}\n",
    "    with open(COUNTS_CSV, \"r\", encoding=\"utf-8\") as f:\n",
    "        r = csv.DictReader(f)\n",
    "        for row in r:\n",
    "            cid = int(row[\"case_id\"])\n",
    "            hits[cid] = int(row[\"hits\"])\n",
    "\n",
    "    # 4) 合并\n",
    "    merged, miss_emb = [], 0\n",
    "    for cid in range(len(train)):\n",
    "        k = id2key.get(cid, None)\n",
    "        emb = key2emb.get(k)\n",
    "        if emb is None:\n",
    "            miss_emb += 1\n",
    "            continue\n",
    "        merged.append({\n",
    "            \"case_id\": cid,\n",
    "            \"type\": id2type.get(cid, \"\"),\n",
    "            \"text\": id2text.get(cid, \"\"),\n",
    "            \"hits\": hits.get(cid, 0),\n",
    "            \"embedding\": emb\n",
    "        })\n",
    "    with open(MERGED_JSON, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(merged, f, ensure_ascii=False, indent=2)\n",
    "    print(f\"[merge] 构建 merged：{len(merged)} 条（缺失向量 {miss_emb} 条）→ {MERGED_JSON}\")\n",
    "    return merged\n",
    "\n",
    "def load_merged():\n",
    "    if os.path.exists(MERGED_JSON):\n",
    "        with open(MERGED_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "            merged = json.load(f)\n",
    "        print(f\"[load] 读取 merged：{len(merged)} 条 ← {MERGED_JSON}\")\n",
    "        return merged\n",
    "    print(\"[load] 未找到 case_meta_with_hits.json，自动从三源合并构建…\")\n",
    "    return build_merged_from_sources()\n",
    "\n",
    "def main():\n",
    "    set_seeds(SEED)\n",
    "    merged = load_merged()\n",
    "    if not merged:\n",
    "        raise RuntimeError(\"没有可用数据，请检查输入路径。\")\n",
    "\n",
    "    # 数据取出\n",
    "    X = np.array([m[\"embedding\"] for m in merged], dtype=np.float32)\n",
    "    hits = np.array([m[\"hits\"] for m in merged], dtype=np.int32)\n",
    "    types = [m.get(\"type\",\"\") for m in merged]\n",
    "    cids  = [m[\"case_id\"] for m in merged]\n",
    "\n",
    "    # 降维：PCA(50)->UMAP(2)；无 UMAP 则 PCA2；再不行用 TSNE\n",
    "    try:\n",
    "        X50 = PCA(n_components=min(50, max(2, X.shape[1]-1)), random_state=SEED).fit_transform(X)\n",
    "        if HAS_UMAP:\n",
    "            import umap\n",
    "            reducer = umap.UMAP(n_neighbors=30, min_dist=0.1, metric=\"cosine\", random_state=SEED)\n",
    "            Z_2d = reducer.fit_transform(X50)\n",
    "        else:\n",
    "            Z_2d = PCA(n_components=2, random_state=SEED).fit_transform(X50)\n",
    "    except Exception:\n",
    "        from sklearn.manifold import TSNE\n",
    "        Z_2d = TSNE(n_components=2, random_state=SEED, perplexity=30, init=\"pca\").fit_transform(X)\n",
    "\n",
    "    # 高使用集合 mask\n",
    "    N = len(hits)\n",
    "    order = np.argsort(-hits)\n",
    "    mask = np.zeros(N, dtype=bool)\n",
    "    if TOPN_HIGHLIGHT:\n",
    "        mask[order[:min(TOPN_HIGHLIGHT, N)]] = True\n",
    "    if Q_PERCENTILE is not None:\n",
    "        thr = float(np.percentile(hits, Q_PERCENTILE))\n",
    "        mask |= (hits >= thr)\n",
    "\n",
    "    # 绘图（论文风格）\n",
    "    fig = plt.figure(figsize=(6.5, 4.5), dpi=300)  # 单栏尺寸\n",
    "    ax = plt.gca()\n",
    "    ax.scatter(Z_2d[~mask, 0], Z_2d[~mask, 1], s=4,  alpha=0.15, linewidths=0, label=\"others\")\n",
    "    ax.scatter(Z_2d[mask, 0],  Z_2d[mask, 1],  s=18, alpha=0.9,  linewidths=0.4, edgecolors=\"k\", label=\"high-usage\")\n",
    "    ax.set_xlabel(\"UMAP-1\"); ax.set_ylabel(\"UMAP-2\")\n",
    "    ax.legend(frameon=False, loc=\"upper left\")\n",
    "    ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(OUT_PDF); plt.savefig(OUT_SVG); plt.savefig(OUT_PNG, dpi=600)\n",
    "    print(f\"[save] 图已保存：{OUT_PDF} / {OUT_SVG} / {OUT_PNG}\")\n",
    "\n",
    "    # —— 只标注 Top10，并尽量避免遮挡 —— #\n",
    "    if LABEL_TOPN and LABEL_TOPN > 0:\n",
    "        if LABEL_SCOPE == \"mask\":\n",
    "            mask_idx  = np.where(mask)[0]\n",
    "            top_idx   = mask_idx[np.argsort(-hits[mask])][:min(LABEL_TOPN, len(mask_idx))]\n",
    "        else:  # global\n",
    "            top_idx = order[:min(LABEL_TOPN, N)]\n",
    "\n",
    "        # 优先使用 adjustText 自动避让\n",
    "        used_adjust = False\n",
    "        try:\n",
    "            from adjustText import adjust_text\n",
    "            texts = []\n",
    "            for idx in top_idx:\n",
    "                x, y = Z_2d[idx]\n",
    "                lbl = LABEL_FMT.format(cid=cids[idx], typ=types[idx])\n",
    "                t = ax.text(\n",
    "                    x, y, lbl, fontsize=7, zorder=5,\n",
    "                    bbox=dict(facecolor=\"white\", alpha=0.85, lw=0, pad=0.6)\n",
    "                )\n",
    "                texts.append(t)\n",
    "            adjust_text(\n",
    "                texts,\n",
    "                only_move={'points': 'y', 'text': 'xy'},\n",
    "                expand_points=(1.2, 1.2), expand_text=(1.2, 1.2),\n",
    "                arrowprops=dict(arrowstyle='-', lw=0.6, color='0.25', alpha=0.8)\n",
    "            )\n",
    "            used_adjust = True\n",
    "        except Exception:\n",
    "            # 回退方案：固定少量偏移 + 细箭头\n",
    "            offsets = [(12,6), (-12,6), (12,-6), (-12,-6), (18,0),\n",
    "                       (-18,0), (0,10), (0,-10), (20,8), (-20,8)]\n",
    "            for off, idx in zip(offsets, top_idx):\n",
    "                x, y = Z_2d[idx]\n",
    "                lbl = LABEL_FMT.format(cid=cids[idx], typ=types[idx])\n",
    "                ax.annotate(\n",
    "                    lbl, xy=(x, y), xycoords='data',\n",
    "                    xytext=off, textcoords='offset points',\n",
    "                    fontsize=7, zorder=5,\n",
    "                    bbox=dict(facecolor=\"white\", alpha=0.85, lw=0, pad=0.6),\n",
    "                    arrowprops=dict(arrowstyle='-', lw=0.6, color='0.25', alpha=0.8)\n",
    "                )\n",
    "\n",
    "        plt.tight_layout()\n",
    "        plt.savefig(OUT_PNG_ANNOT, dpi=600)\n",
    "        print(f\"[save] 带标注版本：{OUT_PNG_ANNOT}  ({'adjustText' if used_adjust else 'fallback offsets'})\")\n",
    "\n",
    "    # 图注模板（粘到论文里）\n",
    "    desc = (\n",
    "        \"2D visualization of casebank embeddings using PCA(50)→UMAP(2) \"\n",
    "        f\"(n_neighbors=30, min_dist=0.1, cosine, random_state={SEED}). \"\n",
    "        f\"Orange points denote high-usage cases (Top-{TOPN_HIGHLIGHT} ∪ {Q_PERCENTILE}th percentile). \"\n",
    "        f\"Labels show Top-{LABEL_TOPN} by hits with overlap avoidance.\"\n",
    "    )\n",
    "    print(\"\\n[FIGURE CAPTION SUGGESTION]\\n\" + desc)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "02d9974a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[load] 27223 条 ← case_meta_with_hits.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[t-SNE] Computing 91 nearest neighbors...\n",
      "[t-SNE] Indexed 27223 samples in 0.001s...\n",
      "[t-SNE] Computed neighbors for 27223 samples in 0.455s...\n",
      "[t-SNE] Computed conditional probabilities for sample 1000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 2000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 3000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 4000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 5000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 6000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 7000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 8000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 9000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 10000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 11000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 12000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 13000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 14000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 15000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 16000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 17000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 18000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 19000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 20000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 21000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 22000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 23000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 24000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 25000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 26000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 27000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 27223 / 27223\n",
      "[t-SNE] Mean sigma: 0.072341\n",
      "[t-SNE] KL divergence after 250 iterations with early exaggeration: 81.424820\n",
      "[t-SNE] KL divergence after 1500 iterations: 1.738065\n",
      "✅ 保存：types_tsne_starburst_idonly.pdf / types_tsne_starburst_idonly.svg / types_tsne_starburst_idonly.png\n"
     ]
    }
   ],
   "source": [
    "# plot_types_tsne_starburst_top1_id_only.py\n",
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "t-SNE（天女散花）可视化：\n",
    "- 读取 case_meta_with_hits.json（若不存在则自动由 counts+embeddings+train 合并生成）\n",
    "- 16个MBTI类型分别着色（颜色更鲜明），底层点较大且不透明度高一些\n",
    "- 每个type仅高亮命中最多的1个样本（黑色细边），标签只显示 case_id\n",
    "- 导出 PDF/SVG/600dpi PNG\n",
    "\"\"\"\n",
    "\n",
    "import os, json, csv, re, random\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# ====== 路径与字段 ======\n",
    "MERGED_JSON = \"case_meta_with_hits.json\"                 # 优先读取\n",
    "COUNTS_CSV  = \"case_usage_all.csv\"                       # 备选合并来源\n",
    "EMB_JSON    = \"casebank_A_train_80_with_embeddings.json\" # 备选合并来源\n",
    "TRAIN_JSON  = \"train.json\"                               # 备选合并来源\n",
    "TEXT_FIELD  = \"posts_cleaned\"\n",
    "\n",
    "# ====== t-SNE 参数（决定“天女散花”效果）======\n",
    "SEED = 0\n",
    "TSNE_PERPLEXITY = 30\n",
    "TSNE_EE         = 12          # early_exaggeration\n",
    "TSNE_LR         = 200         # learning_rate\n",
    "TSNE_ITER       = 1500\n",
    "PCA_DIM         = 50          # 先PCA再t-SNE更稳定\n",
    "\n",
    "# ====== 绘图样式（更鲜明） ======\n",
    "# 图例顺序（仅用于图例，不影响数据）\n",
    "MBTI_16 = [\n",
    "    \"ENFJ\",\"ENFP\",\"ENTJ\",\"ENTP\",\"ESFJ\",\"ESFP\",\"ESTJ\",\"ESTP\",\n",
    "    \"INFJ\",\"INFP\",\"INTJ\",\"INTP\",\"ISFJ\",\"ISFP\",\"ISTJ\",\"ISTP\"\n",
    "]\n",
    "# 底层点（所有样本，按type着色）\n",
    "BASE_SIZE  = 16\n",
    "BASE_ALPHA = 0.85\n",
    "BASE_EDGE_LW = 0.2            # 白色细边让群簇边界更清晰\n",
    "# 高亮点（各type命中Top-1）\n",
    "HL_SIZE    = 140\n",
    "HL_EDGE_LW = 1.0\n",
    "LABEL_FMT  = \"{cid}\"          # 只显示 case_id\n",
    "LABEL_FONTSIZE = 9\n",
    "LABEL_WITH_BOX = True         # 给标签加白底\n",
    "\n",
    "OUT_PDF = \"types_tsne_starburst_idonly.pdf\"\n",
    "OUT_SVG = \"types_tsne_starburst_idonly.svg\"\n",
    "OUT_PNG = \"types_tsne_starburst_idonly.png\"\n",
    "\n",
    "# ====== 小工具 ======\n",
    "WS = re.compile(r\"\\s+\")\n",
    "def key_prefix80_space(s: str) -> str:\n",
    "    return WS.sub(\" \", (s or \"\").strip())[:80]\n",
    "\n",
    "def set_seeds(seed=0):\n",
    "    random.seed(seed); np.random.seed(seed)\n",
    "\n",
    "def vivid_palette_16():\n",
    "    \"\"\"\n",
    "    使用更鲜明的调色：以 tab10 为主色，补充 tab20 中对比较强的颜色，共16个。\n",
    "    \"\"\"\n",
    "    tab10 = list(plt.get_cmap(\"tab10\").colors)  # 10个、鲜明\n",
    "    tab20 = list(plt.get_cmap(\"tab20\").colors)\n",
    "    # 从 tab20 里挑 6 个对比度高的颜色补足 16\n",
    "    extra_idx = [1, 3, 5, 7, 9, 11]             # 交替抽取饱和度更高的一组\n",
    "    extras = [tab20[i] for i in extra_idx[:6]]\n",
    "    cols = tab10 + extras\n",
    "    return {t: cols[i] for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "def build_merged_from_sources():\n",
    "    # 1) train：id->type/text/key\n",
    "    with open(TRAIN_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        train = json.load(f)\n",
    "    id2type, id2text, id2key = {}, {}, {}\n",
    "    for i, it in enumerate(train):\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\",\"\")\n",
    "        id2type[i] = it.get(\"type\",\"\")\n",
    "        id2text[i] = t\n",
    "        id2key[i]  = key_prefix80_space(t)\n",
    "\n",
    "    # 2) embedding：key->emb\n",
    "    with open(EMB_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        emb_rows = json.load(f)\n",
    "    key2emb = {}\n",
    "    for it in emb_rows:\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\",\"\")\n",
    "        if isinstance(t, str) and t and \"embedding\" in it:\n",
    "            key2emb[key_prefix80_space(t)] = it[\"embedding\"]\n",
    "\n",
    "    # 3) hits：cid->hits\n",
    "    hits = {}\n",
    "    with open(COUNTS_CSV, \"r\", encoding=\"utf-8\") as f:\n",
    "        r = csv.DictReader(f)\n",
    "        for row in r:\n",
    "            hits[int(row[\"case_id\"])] = int(row[\"hits\"])\n",
    "\n",
    "    # 4) 合并\n",
    "    merged, miss = [], 0\n",
    "    for cid in range(len(train)):\n",
    "        emb = key2emb.get(id2key.get(cid,\"\"))\n",
    "        if emb is None:\n",
    "            miss += 1\n",
    "            continue\n",
    "        merged.append({\n",
    "            \"case_id\": cid,\n",
    "            \"type\": id2type.get(cid,\"\"),\n",
    "            \"text\": id2text.get(cid,\"\"),\n",
    "            \"hits\": hits.get(cid, 0),\n",
    "            \"embedding\": emb\n",
    "        })\n",
    "    with open(MERGED_JSON, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(merged, f, ensure_ascii=False, indent=2)\n",
    "    print(f\"[merge] 构建 merged：{len(merged)} 条（缺失向量 {miss}）→ {MERGED_JSON}\")\n",
    "    return merged\n",
    "\n",
    "def load_merged():\n",
    "    if os.path.exists(MERGED_JSON):\n",
    "        with open(MERGED_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "            merged = json.load(f)\n",
    "        print(f\"[load] {len(merged)} 条 ← {MERGED_JSON}\")\n",
    "        return merged\n",
    "    print(\"[load] 未找到 merged，自动从三源合并…\")\n",
    "    return build_merged_from_sources()\n",
    "\n",
    "# ====== 主流程 ======\n",
    "def main():\n",
    "    set_seeds(SEED)\n",
    "    merged = load_merged()\n",
    "    assert merged, \"没有可用数据\"\n",
    "\n",
    "    X   = np.array([m[\"embedding\"] for m in merged], dtype=np.float32)\n",
    "    H   = np.array([m[\"hits\"] for m in merged], dtype=np.int32)\n",
    "    T   = np.array([m.get(\"type\",\"\") for m in merged])\n",
    "    CID = np.array([m[\"case_id\"] for m in merged])\n",
    "\n",
    "    # 先 PCA 再 t-SNE（更稳定、更容易出“放射状”）\n",
    "    X50 = PCA(n_components=min(PCA_DIM, max(2, X.shape[1]-1)), random_state=SEED).fit_transform(X)\n",
    "    tsne = TSNE(\n",
    "        n_components=2,\n",
    "        perplexity=TSNE_PERPLEXITY,\n",
    "        early_exaggeration=TSNE_EE,\n",
    "        learning_rate=TSNE_LR,\n",
    "        n_iter=TSNE_ITER,\n",
    "        init=\"pca\",\n",
    "        random_state=SEED,\n",
    "        angle=0.5,         # Barnes–Hut\n",
    "        verbose=1,\n",
    "    )\n",
    "    Z = tsne.fit_transform(X50)\n",
    "\n",
    "    # 每个type找“命中最多”的那个索引；若并列取第一个\n",
    "    top_idx_per_type = {}\n",
    "    for typ in np.unique(T):\n",
    "        mask = (T == typ)\n",
    "        if not mask.any(): \n",
    "            continue\n",
    "        idxs = np.where(mask)[0]\n",
    "        best = idxs[np.argmax(H[idxs])]\n",
    "        top_idx_per_type[typ] = best\n",
    "\n",
    "    # 颜色（更鲜明）\n",
    "    pal = vivid_palette_16()\n",
    "\n",
    "    # 绘图\n",
    "    fig = plt.figure(figsize=(7.5, 5.5), dpi=300)\n",
    "    ax = plt.gca()\n",
    "\n",
    "    # 底层点：颜色更“实”、稍大、加白色细边\n",
    "    for t in MBTI_16:\n",
    "        mt = (T == t)\n",
    "        if not mt.any(): \n",
    "            continue\n",
    "        ax.scatter(Z[mt,0], Z[mt,1],\n",
    "                   s=BASE_SIZE, alpha=BASE_ALPHA,\n",
    "                   linewidths=BASE_EDGE_LW, edgecolors=\"white\",\n",
    "                   color=pal[t], label=t)\n",
    "\n",
    "    # 高亮每个type的Top-1：黑边+更大\n",
    "    for t in MBTI_16:\n",
    "        if t not in top_idx_per_type: \n",
    "            continue\n",
    "        i = top_idx_per_type[t]\n",
    "        ax.scatter(Z[i,0], Z[i,1],\n",
    "                   s=HL_SIZE, alpha=1.0,\n",
    "                   linewidths=HL_EDGE_LW, edgecolors=\"black\",\n",
    "                   color=pal.get(t, \"k\"), zorder=5)\n",
    "        lbl = LABEL_FMT.format(cid=CID[i])\n",
    "        kw = dict(fontsize=LABEL_FONTSIZE, zorder=6)\n",
    "        if LABEL_WITH_BOX:\n",
    "            kw[\"bbox\"] = dict(facecolor=\"white\", alpha=0.9, lw=0, pad=0.5)\n",
    "        ax.text(Z[i,0], Z[i,1], lbl, **kw)\n",
    "\n",
    "    ax.set_xlabel(\"t-SNE-1\")\n",
    "    ax.set_ylabel(\"t-SNE-2\")\n",
    "    ax.set_title(\"FEM Personality Embedding Space (t-SNE)\\nTop-1 by hits per Type (ID labels)\", fontsize=12)\n",
    "\n",
    "    # 图例（只显示type，不再显示 n）\n",
    "    lg = ax.legend(frameon=False, bbox_to_anchor=(1.02, 1), loc=\"upper left\", borderaxespad=0.)\n",
    "    for txt in lg.get_texts(): txt.set_fontsize(9)\n",
    "\n",
    "    # 去掉上/右边框\n",
    "    ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(OUT_PDF,  bbox_inches=\"tight\")\n",
    "    plt.savefig(OUT_SVG,  bbox_inches=\"tight\")\n",
    "    plt.savefig(OUT_PNG,  dpi=600, bbox_inches=\"tight\")\n",
    "    print(f\"✅ 保存：{OUT_PDF} / {OUT_SVG} / {OUT_PNG}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cd7e4ad8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[t-SNE] Computing 91 nearest neighbors...\n",
      "[t-SNE] Indexed 27223 samples in 0.001s...\n",
      "[t-SNE] Computed neighbors for 27223 samples in 0.426s...\n",
      "[t-SNE] Computed conditional probabilities for sample 1000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 2000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 3000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 4000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 5000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 6000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 7000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 8000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 9000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 10000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 11000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 12000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 13000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 14000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 15000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 16000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 17000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 18000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 19000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 20000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 21000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 22000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 23000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 24000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 25000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 26000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 27000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 27223 / 27223\n",
      "[t-SNE] Mean sigma: 0.072341\n",
      "[t-SNE] KL divergence after 250 iterations with early exaggeration: 81.424835\n",
      "[t-SNE] KL divergence after 1500 iterations: 1.738065\n",
      "✅ 保存：types_tsne_starburst_hitsonly.pdf / types_tsne_starburst_hitsonly.svg / types_tsne_starburst_hitsonly.png\n"
     ]
    }
   ],
   "source": [
    "# plot_types_tsne_starburst_top1_hits_only.py\n",
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "t-SNE（天女散花）可视化：\n",
    "- 读取 case_meta_with_hits.json（若不存在则自动由 counts+embeddings+train 合并生成）\n",
    "- 16个MBTI类型分别着色（颜色更鲜明），底层点较大且不透明度高一些\n",
    "- 每个type仅高亮命中最多的1个样本（黑色细边），标签显示“命中次数 hits”\n",
    "- 右侧增加脚注图示：说明 label = hits\n",
    "- 导出 PDF/SVG/600dpi PNG\n",
    "\"\"\"\n",
    "\n",
    "import os, json, csv, re, random\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# ====== 路径与字段 ======\n",
    "MERGED_JSON = \"case_meta_with_hits.json\"                 # 优先读取\n",
    "COUNTS_CSV  = \"case_usage_all.csv\"                       # 备选合并来源\n",
    "EMB_JSON    = \"casebank_A_train_80_with_embeddings.json\" # 备选合并来源\n",
    "TRAIN_JSON  = \"train.json\"                               # 备选合并来源\n",
    "TEXT_FIELD  = \"posts_cleaned\"\n",
    "\n",
    "# ====== t-SNE 参数（决定“天女散花”效果）======\n",
    "SEED = 0\n",
    "TSNE_PERPLEXITY = 30\n",
    "TSNE_EE         = 12\n",
    "TSNE_LR         = 200\n",
    "TSNE_ITER       = 1500\n",
    "PCA_DIM         = 50\n",
    "\n",
    "# ====== 绘图样式（更鲜明） ======\n",
    "MBTI_16 = [\n",
    "    \"ENFJ\",\"ENFP\",\"ENTJ\",\"ENTP\",\"ESFJ\",\"ESFP\",\"ESTJ\",\"ESTP\",\n",
    "    \"INFJ\",\"INFP\",\"INTJ\",\"INTP\",\"ISFJ\",\"ISFP\",\"ISTJ\",\"ISTP\"\n",
    "]\n",
    "BASE_SIZE   = 16\n",
    "BASE_ALPHA  = 0.85\n",
    "BASE_EDGE_LW = 0.2            # 白色细边让群簇边界更清晰\n",
    "\n",
    "HL_SIZE     = 140             # 高亮点大小\n",
    "HL_EDGE_LW  = 1.0             # 高亮黑边\n",
    "LABEL_FONTSIZE = 9\n",
    "LABEL_WITH_BOX = True         # 标签白底\n",
    "\n",
    "# 想避免出现“0”，把它设为 1；保持 0 则也会显示 0 命中（当该类都未命中时）\n",
    "MIN_HITS_FOR_HIGHLIGHT = 0\n",
    "\n",
    "OUT_PDF = \"types_tsne_starburst_hitsonly.pdf\"\n",
    "OUT_SVG = \"types_tsne_starburst_hitsonly.svg\"\n",
    "OUT_PNG = \"types_tsne_starburst_hitsonly.png\"\n",
    "\n",
    "# ====== 小工具 ======\n",
    "WS = re.compile(r\"\\s+\")\n",
    "def key_prefix80_space(s: str) -> str:\n",
    "    return WS.sub(\" \", (s or \"\").strip())[:80]\n",
    "\n",
    "def set_seeds(seed=0):\n",
    "    random.seed(seed); np.random.seed(seed)\n",
    "\n",
    "def vivid_palette_16():\n",
    "    # tab10 主色 + tab20 里对比度高的色，拼成 16 色\n",
    "    tab10 = list(plt.get_cmap(\"tab10\").colors)\n",
    "    tab20 = list(plt.get_cmap(\"tab20\").colors)\n",
    "    extras = [tab20[i] for i in [1,3,5,7,9,11]]\n",
    "    cols = tab10 + extras\n",
    "    return {t: cols[i] for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "def build_merged_from_sources():\n",
    "    with open(TRAIN_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        train = json.load(f)\n",
    "    id2type, id2text, id2key = {}, {}, {}\n",
    "    for i, it in enumerate(train):\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\",\"\")\n",
    "        id2type[i] = it.get(\"type\",\"\")\n",
    "        id2text[i] = t\n",
    "        id2key[i]  = key_prefix80_space(t)\n",
    "\n",
    "    with open(EMB_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        emb_rows = json.load(f)\n",
    "    key2emb = {}\n",
    "    for it in emb_rows:\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\",\"\")\n",
    "        if isinstance(t, str) and t and \"embedding\" in it:\n",
    "            key2emb[key_prefix80_space(t)] = it[\"embedding\"]\n",
    "\n",
    "    hits = {}\n",
    "    with open(COUNTS_CSV, \"r\", encoding=\"utf-8\") as f:\n",
    "        r = csv.DictReader(f)\n",
    "        for row in r:\n",
    "            hits[int(row[\"case_id\"])] = int(row[\"hits\"])\n",
    "\n",
    "    merged, miss = [], 0\n",
    "    for cid in range(len(train)):\n",
    "        emb = key2emb.get(id2key.get(cid,\"\"))\n",
    "        if emb is None:\n",
    "            miss += 1\n",
    "            continue\n",
    "        merged.append({\n",
    "            \"case_id\": cid,\n",
    "            \"type\": id2type.get(cid,\"\"),\n",
    "            \"text\": id2text.get(cid,\"\"),\n",
    "            \"hits\": hits.get(cid, 0),\n",
    "            \"embedding\": emb\n",
    "        })\n",
    "    with open(MERGED_JSON, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(merged, f, ensure_ascii=False, indent=2)\n",
    "    print(f\"[merge] 构建 merged：{len(merged)} 条（缺失向量 {miss}）→ {MERGED_JSON}\")\n",
    "    return merged\n",
    "\n",
    "def load_merged():\n",
    "    if os.path.exists(MERGED_JSON):\n",
    "        with open(MERGED_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "            return json.load(f)\n",
    "    print(\"[load] 未找到 merged，自动从三源合并…\")\n",
    "    return build_merged_from_sources()\n",
    "\n",
    "# ====== 主流程 ======\n",
    "def main():\n",
    "    set_seeds(SEED)\n",
    "    merged = load_merged()\n",
    "    assert merged, \"没有可用数据\"\n",
    "\n",
    "    X   = np.array([m[\"embedding\"] for m in merged], dtype=np.float32)\n",
    "    H   = np.array([m[\"hits\"] for m in merged], dtype=np.int32)\n",
    "    T   = np.array([m.get(\"type\",\"\") for m in merged])\n",
    "\n",
    "    # PCA -> t-SNE\n",
    "    X50 = PCA(n_components=min(PCA_DIM, max(2, X.shape[1]-1)), random_state=SEED).fit_transform(X)\n",
    "    Z = TSNE(n_components=2, perplexity=TSNE_PERPLEXITY, early_exaggeration=TSNE_EE,\n",
    "             learning_rate=TSNE_LR, n_iter=TSNE_ITER, init=\"pca\",\n",
    "             random_state=SEED, angle=0.5, verbose=1).fit_transform(X50)\n",
    "\n",
    "    # 每个type找“命中最多”的那个；若并列取第一个；可选过滤 min hits\n",
    "    top_idx_per_type = {}\n",
    "    for typ in np.unique(T):\n",
    "        mask = (T == typ)\n",
    "        if not mask.any():\n",
    "            continue\n",
    "        idxs_all = np.where(mask)[0]\n",
    "        idxs_pos = idxs_all[H[idxs_all] >= MIN_HITS_FOR_HIGHLIGHT]\n",
    "        if idxs_pos.size == 0:\n",
    "            # 该类没有达到阈值的命中样本，则不高亮\n",
    "            continue\n",
    "        best = idxs_pos[np.argmax(H[idxs_pos])]\n",
    "        top_idx_per_type[typ] = best\n",
    "\n",
    "    pal = vivid_palette_16()\n",
    "\n",
    "    # 绘图\n",
    "    fig = plt.figure(figsize=(7.5, 5.5), dpi=300)\n",
    "    ax = plt.gca()\n",
    "\n",
    "    # 背景点（所有样本）\n",
    "    type_handles = []\n",
    "    for t in MBTI_16:\n",
    "        mt = (T == t)\n",
    "        if not mt.any(): \n",
    "            continue\n",
    "        sc = ax.scatter(Z[mt,0], Z[mt,1],\n",
    "                        s=BASE_SIZE, alpha=BASE_ALPHA,\n",
    "                        linewidths=BASE_EDGE_LW, edgecolors=\"white\",\n",
    "                        color=pal[t], label=t)\n",
    "        type_handles.append(sc)\n",
    "\n",
    "    # 高亮每个type的Top-1：标签=hits\n",
    "    for t in MBTI_16:\n",
    "        if t not in top_idx_per_type:\n",
    "            continue\n",
    "        i = top_idx_per_type[t]\n",
    "        ax.scatter(Z[i,0], Z[i,1],\n",
    "                   s=HL_SIZE, alpha=1.0,\n",
    "                   linewidths=HL_EDGE_LW, edgecolors=\"black\",\n",
    "                   color=pal.get(t, \"k\"), zorder=5)\n",
    "        lbl = f\"{int(H[i])}\"          # ← 只显示命中次数\n",
    "        kw  = dict(fontsize=LABEL_FONTSIZE, zorder=6)\n",
    "        if LABEL_WITH_BOX:\n",
    "            kw[\"bbox\"] = dict(facecolor=\"white\", alpha=0.9, lw=0, pad=0.5)\n",
    "        ax.text(Z[i,0], Z[i,1], lbl, **kw)\n",
    "\n",
    "    ax.set_xlabel(\"t-SNE-1\"); ax.set_ylabel(\"t-SNE-2\")\n",
    "    ax.set_title(\"FEM Personality Embedding Space (t-SNE)\\nTop-1 by hits per Type (label = hits)\", fontsize=12)\n",
    "\n",
    "    # 图例：类型颜色\n",
    "    leg1 = ax.legend(frameon=False, bbox_to_anchor=(1.02, 1), loc=\"upper left\", borderaxespad=0., title=\"Types\")\n",
    "    for txt in leg1.get_texts(): txt.set_fontsize(9)\n",
    "    ax.add_artist(leg1)\n",
    "\n",
    "    # 右侧脚注：显示规则（样本 / Top-1，label=hits）\n",
    "    sample_handle = Line2D([0],[0], marker='o', color='w',\n",
    "                           markerfacecolor='0.6', markeredgecolor='white',\n",
    "                           markeredgewidth=BASE_EDGE_LW, markersize=6, label='sample')\n",
    "    top1_handle   = Line2D([0],[0], marker='o', color='black',\n",
    "                           markerfacecolor='tab:orange', markeredgewidth=HL_EDGE_LW,\n",
    "                           markersize=8, label='Top-1 (label = hits)')\n",
    "    leg2 = ax.legend(handles=[sample_handle, top1_handle],\n",
    "                     frameon=False, bbox_to_anchor=(1.02, 0.0), loc=\"lower left\",\n",
    "                     title=\"Display guide\")\n",
    "    for txt in leg2.get_texts(): txt.set_fontsize(9)\n",
    "\n",
    "    ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(OUT_PDF, bbox_inches=\"tight\")\n",
    "    plt.savefig(OUT_SVG, bbox_inches=\"tight\")\n",
    "    plt.savefig(OUT_PNG, dpi=600, bbox_inches=\"tight\")\n",
    "    print(f\"✅ 保存：{OUT_PDF} / {OUT_SVG} / {OUT_PNG}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0b11a7de",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[load] 27223 条 ← case_meta_with_hits.json\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[t-SNE] Computing 91 nearest neighbors...\n",
      "[t-SNE] Indexed 27223 samples in 0.001s...\n",
      "[t-SNE] Computed neighbors for 27223 samples in 0.407s...\n",
      "[t-SNE] Computed conditional probabilities for sample 1000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 2000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 3000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 4000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 5000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 6000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 7000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 8000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 9000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 10000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 11000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 12000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 13000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 14000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 15000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 16000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 17000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 18000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 19000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 20000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 21000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 22000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 23000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 24000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 25000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 26000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 27000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 27223 / 27223\n",
      "[t-SNE] Mean sigma: 0.072341\n",
      "[t-SNE] KL divergence after 250 iterations with early exaggeration: 81.424835\n",
      "[t-SNE] KL divergence after 1500 iterations: 1.738065\n",
      "✅ 保存：types_tsne_starburst_topk_hits_ring.pdf / types_tsne_starburst_topk_hits_ring.svg / types_tsne_starburst_topk_hits_ring.png   （标签自动避让：×）\n"
     ]
    }
   ],
   "source": [
    "# plot_types_tsne_starburst_topk_hits_ring.py\n",
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "t-SNE（天女散花）可视化：\n",
    "- 读取 case_meta_with_hits.json（若无则由 counts+embeddings+train 自动合并生成）\n",
    "- 16 个 MBTI 类型着色（鲜明）\n",
    "- 每个 type 高亮命中最多的 TOPK 个样本：彩色实心点 + 黑色细边的小圆点；标签仅显示“命中次数”\n",
    "- 右侧双图例：类型颜色 & 形状说明（Top-K，label=hits）\n",
    "- 导出 PDF/SVG/600dpi PNG\n",
    "\"\"\"\n",
    "\n",
    "import os, json, csv, re, random\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# ========= 路径与字段 =========\n",
    "MERGED_JSON = \"case_meta_with_hits.json\"                 # 优先读取\n",
    "COUNTS_CSV  = \"case_usage_all.csv\"                       # 备选合并来源\n",
    "EMB_JSON    = \"casebank_A_train_80_with_embeddings.json\" # 备选合并来源\n",
    "TRAIN_JSON  = \"train.json\"                               # 备选合并来源\n",
    "TEXT_FIELD  = \"posts_cleaned\"\n",
    "\n",
    "# ========= t-SNE 超参（决定“天女散花”感）=========\n",
    "SEED = 0\n",
    "TSNE_PERPLEXITY = 30\n",
    "TSNE_EE         = 12\n",
    "TSNE_LR         = 200\n",
    "TSNE_ITER       = 1500\n",
    "PCA_DIM         = 50\n",
    "\n",
    "# ========= 绘图样式（鲜明）=========\n",
    "MBTI_16 = [\n",
    "    \"ENFJ\",\"ENFP\",\"ENTJ\",\"ENTP\",\"ESFJ\",\"ESFP\",\"ESTJ\",\"ESTP\",\n",
    "    \"INFJ\",\"INFP\",\"INTJ\",\"INTP\",\"ISFJ\",\"ISFP\",\"ISTJ\",\"ISTP\"\n",
    "]\n",
    "BASE_SIZE    = 16\n",
    "BASE_ALPHA   = 0.85\n",
    "BASE_EDGE_LW = 0.2          # 白色细边让簇边界更清晰\n",
    "\n",
    "TOPK_PER_TYPE = 1         # ★ 每类高亮几个\n",
    "HL_RING_SIZE  = 80          # 高亮点尺寸（小圆点）\n",
    "HL_EDGE_LW    = 0.9         # 黑色细边\n",
    "LABEL_FONTSZ  = 9           # 标签字号\n",
    "LABEL_BOX     = True        # 标签白底\n",
    "USE_ADJUSTTEXT = True       # 若环境有 adjustText，则自动避让；否则用小偏移\n",
    "\n",
    "OUT_PDF = \"types_tsne_starburst_topk_hits_ring.pdf\"\n",
    "OUT_SVG = \"types_tsne_starburst_topk_hits_ring.svg\"\n",
    "OUT_PNG = \"types_tsne_starburst_topk_hits_ring.png\"\n",
    "\n",
    "# ========= 工具 =========\n",
    "WS = re.compile(r\"\\s+\")\n",
    "def key_prefix80_space(s: str) -> str:\n",
    "    return WS.sub(\" \", (s or \"\").strip())[:80]\n",
    "\n",
    "def set_seeds(seed=0):\n",
    "    random.seed(seed); np.random.seed(seed)\n",
    "\n",
    "def vivid_palette_16():\n",
    "    \"\"\"更鲜明的 16 色：tab10 + tab20 中饱和色\"\"\"\n",
    "    tab10 = list(plt.get_cmap(\"tab10\").colors)          # 10\n",
    "    tab20 = list(plt.get_cmap(\"tab20\").colors)          # 20\n",
    "    extras = [tab20[i] for i in [1,3,5,7,9,11]]          # 取 6 个对比度高的补足\n",
    "    cols = tab10 + extras\n",
    "    return {t: cols[i] for i, t in enumerate(MBTI_16)}\n",
    "\n",
    "def build_merged_from_sources():\n",
    "    # 1) train：id->type/text/key\n",
    "    with open(TRAIN_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        train = json.load(f)\n",
    "    id2type, id2text, id2key = {}, {}, {}\n",
    "    for i, it in enumerate(train):\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\",\"\")\n",
    "        id2type[i] = it.get(\"type\",\"\")\n",
    "        id2text[i] = t\n",
    "        id2key[i]  = key_prefix80_space(t)\n",
    "\n",
    "    # 2) embedding：key->emb\n",
    "    with open(EMB_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        emb_rows = json.load(f)\n",
    "    key2emb = {}\n",
    "    for it in emb_rows:\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\",\"\")\n",
    "        if isinstance(t, str) and t and \"embedding\" in it:\n",
    "            key2emb[key_prefix80_space(t)] = it[\"embedding\"]\n",
    "\n",
    "    # 3) hits：cid->hits\n",
    "    hits = {}\n",
    "    with open(COUNTS_CSV, \"r\", encoding=\"utf-8\") as f:\n",
    "        r = csv.DictReader(f)\n",
    "        for row in r:\n",
    "            hits[int(row[\"case_id\"])] = int(row[\"hits\"])\n",
    "\n",
    "    # 4) 合并\n",
    "    merged, miss = [], 0\n",
    "    for cid in range(len(train)):\n",
    "        emb = key2emb.get(id2key.get(cid,\"\"))\n",
    "        if emb is None:\n",
    "            miss += 1\n",
    "            continue\n",
    "        merged.append({\n",
    "            \"case_id\": cid,\n",
    "            \"type\": id2type.get(cid,\"\"),\n",
    "            \"text\": id2text.get(cid,\"\"),\n",
    "            \"hits\": hits.get(cid, 0),\n",
    "            \"embedding\": emb\n",
    "        })\n",
    "    with open(MERGED_JSON, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(merged, f, ensure_ascii=False, indent=2)\n",
    "    print(f\"[merge] 构建 merged：{len(merged)} 条（缺失向量 {miss}）→ {MERGED_JSON}\")\n",
    "    return merged\n",
    "\n",
    "def load_merged():\n",
    "    if os.path.exists(MERGED_JSON):\n",
    "        with open(MERGED_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "            merged = json.load(f)\n",
    "        print(f\"[load] {len(merged)} 条 ← {MERGED_JSON}\")\n",
    "        return merged\n",
    "    print(\"[load] 未找到 merged，自动从三源合并…\")\n",
    "    return build_merged_from_sources()\n",
    "\n",
    "# ========= 主流程 =========\n",
    "def main():\n",
    "    set_seeds(SEED)\n",
    "    merged = load_merged()\n",
    "    assert merged, \"没有可用数据\"\n",
    "\n",
    "    X = np.array([m[\"embedding\"] for m in merged], dtype=np.float32)\n",
    "    H = np.array([m[\"hits\"] for m in merged], dtype=np.int32)\n",
    "    T = np.array([m.get(\"type\",\"\") for m in merged])\n",
    "\n",
    "    # PCA -> t-SNE（稳定出“放射/花朵”）\n",
    "    X50 = PCA(n_components=min(PCA_DIM, max(2, X.shape[1]-1)), random_state=SEED).fit_transform(X)\n",
    "    Z = TSNE(\n",
    "        n_components=2, perplexity=TSNE_PERPLEXITY,\n",
    "        early_exaggeration=TSNE_EE, learning_rate=TSNE_LR,\n",
    "        n_iter=TSNE_ITER, init=\"pca\", random_state=SEED, angle=0.5, verbose=1\n",
    "    ).fit_transform(X50)\n",
    "\n",
    "    # 每个 type 取命中最多的 TOPK\n",
    "    top_idx_per_type = {}\n",
    "    for typ in np.unique(T):\n",
    "        mt = (T == typ)\n",
    "        if not mt.any(): \n",
    "            continue\n",
    "        idx = np.where(mt)[0]\n",
    "        order = idx[np.argsort(-H[idx])]\n",
    "        top_idx_per_type[typ] = order[:min(TOPK_PER_TYPE, len(order))]\n",
    "\n",
    "    pal = vivid_palette_16()\n",
    "\n",
    "    fig = plt.figure(figsize=(7.5, 5.5), dpi=300)\n",
    "    ax = plt.gca()\n",
    "\n",
    "    # 先画“底层”所有点\n",
    "    legends_type = []\n",
    "    for t in MBTI_16:\n",
    "        mt = (T == t)\n",
    "        if not mt.any(): continue\n",
    "        ax.scatter(\n",
    "            Z[mt,0], Z[mt,1],\n",
    "            s=BASE_SIZE, alpha=BASE_ALPHA,\n",
    "            linewidths=BASE_EDGE_LW, edgecolors=\"white\",\n",
    "            color=pal[t]\n",
    "        )\n",
    "        legends_type.append(Line2D([0],[0], marker='o', color='w',\n",
    "                                   markerfacecolor=pal[t], markeredgecolor='white',\n",
    "                                   markeredgewidth=BASE_EDGE_LW, markersize=7, label=t))\n",
    "\n",
    "    # 高亮 Top-K：彩色实心点 + 黑色细边；标签=命中次数\n",
    "    texts = []\n",
    "    use_adjust = False\n",
    "    for t in MBTI_16:\n",
    "        if t not in top_idx_per_type: \n",
    "            continue\n",
    "        for i in top_idx_per_type[t]:\n",
    "            x, y = Z[i,0], Z[i,1]\n",
    "            ax.scatter(x, y,\n",
    "                       s=HL_RING_SIZE, alpha=1.0,\n",
    "                       linewidths=HL_EDGE_LW, edgecolors=\"black\",\n",
    "                       color=pal.get(t, \"k\"), zorder=6)\n",
    "            # 仅 hits\n",
    "            txt = f\"{int(H[i])}\"\n",
    "            if USE_ADJUSTTEXT:\n",
    "                texts.append(ax.text(x, y, txt,\n",
    "                                     fontsize=LABEL_FONTSZ, zorder=7,\n",
    "                                     bbox=dict(facecolor=\"white\", alpha=0.95, lw=0, pad=0.5) if LABEL_BOX else None))\n",
    "            else:\n",
    "                # 简单偏移防遮挡\n",
    "                ax.annotate(txt, xy=(x, y), xycoords='data',\n",
    "                            xytext=(10, 6), textcoords='offset points',\n",
    "                            fontsize=LABEL_FONTSZ, zorder=7,\n",
    "                            bbox=dict(facecolor=\"white\", alpha=0.95, lw=0, pad=0.5) if LABEL_BOX else None)\n",
    "\n",
    "    # 调整标签避免遮挡（若可用）\n",
    "    if USE_ADJUSTTEXT and texts:\n",
    "        try:\n",
    "            from adjustText import adjust_text\n",
    "            adjust_text(texts,\n",
    "                        only_move={'points': 'y', 'text': 'xy'},\n",
    "                        expand_points=(1.2, 1.2), expand_text=(1.2, 1.2))\n",
    "            use_adjust = True\n",
    "        except Exception:\n",
    "            pass\n",
    "\n",
    "    # ===== 图例：类型颜色 =====\n",
    "    leg1 = ax.legend(handles=legends_type, frameon=False,\n",
    "                     bbox_to_anchor=(1.02, 1), loc=\"upper left\", borderaxespad=0., title=\"Types\")\n",
    "    for txt in leg1.get_texts(): txt.set_fontsize(9)\n",
    "    ax.add_artist(leg1)\n",
    "\n",
    "    # ===== 图例：形状说明 =====\n",
    "    sample_handle = Line2D([0],[0], marker='o', color='w',\n",
    "                           markerfacecolor='0.5', markeredgecolor='white',\n",
    "                           markeredgewidth=BASE_EDGE_LW, markersize=6, label='sample')\n",
    "    topk_handle = Line2D([0],[0], marker='o', color='black',\n",
    "                         markerfacecolor='tab:orange', markeredgewidth=HL_EDGE_LW,\n",
    "                         markersize=8, label=f'Top-{TOPK_PER_TYPE} (label = hits)')\n",
    "    leg2 = ax.legend(handles=[sample_handle, topk_handle],\n",
    "                     frameon=False, bbox_to_anchor=(1.02, 0.0), loc=\"lower left\",\n",
    "                     title=\"Display guide\")\n",
    "    for txt in leg2.get_texts(): txt.set_fontsize(9)\n",
    "\n",
    "    ax.set_xlabel(\"t-SNE-1\"); ax.set_ylabel(\"t-SNE-2\")\n",
    "    ax.set_title(f\"FEM Personality Embedding Space (t-SNE)\\nTop-{TOPK_PER_TYPE} per Type (label shows hits)\", fontsize=12)\n",
    "    ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(OUT_PDF, bbox_inches=\"tight\")\n",
    "    plt.savefig(OUT_SVG, bbox_inches=\"tight\")\n",
    "    plt.savefig(OUT_PNG, dpi=600, bbox_inches=\"tight\")\n",
    "    print(f\"✅ 保存：{OUT_PDF} / {OUT_SVG} / {OUT_PNG}   （标签自动避让：{'√' if use_adjust else '×'}）\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "eadee089",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/hli962/miniconda3/lib/python3.12/site-packages/sklearn/manifold/_t_sne.py:1164: FutureWarning: 'n_iter' was renamed to 'max_iter' in version 1.5 and will be removed in 1.7.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[t-SNE] Computing 91 nearest neighbors...\n",
      "[t-SNE] Indexed 27223 samples in 0.001s...\n",
      "[t-SNE] Computed neighbors for 27223 samples in 0.426s...\n",
      "[t-SNE] Computed conditional probabilities for sample 1000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 2000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 3000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 4000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 5000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 6000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 7000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 8000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 9000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 10000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 11000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 12000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 13000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 14000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 15000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 16000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 17000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 18000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 19000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 20000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 21000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 22000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 23000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 24000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 25000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 26000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 27000 / 27223\n",
      "[t-SNE] Computed conditional probabilities for sample 27223 / 27223\n",
      "[t-SNE] Mean sigma: 0.072341\n",
      "[t-SNE] KL divergence after 250 iterations with early exaggeration: 81.424820\n",
      "[t-SNE] KL divergence after 1500 iterations: 1.738065\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2529676/1807329542.py:202: UserWarning: Glyph 32500 (\\N{CJK UNIFIED IDEOGRAPH-7EF4}) missing from font(s) DejaVu Sans.\n",
      "  plt.tight_layout(rect=[0,0,1,0.95])\n",
      "/tmp/ipykernel_2529676/1807329542.py:202: UserWarning: Glyph 24230 (\\N{CJK UNIFIED IDEOGRAPH-5EA6}) missing from font(s) DejaVu Sans.\n",
      "  plt.tight_layout(rect=[0,0,1,0.95])\n",
      "/tmp/ipykernel_2529676/1807329542.py:203: UserWarning: Glyph 32500 (\\N{CJK UNIFIED IDEOGRAPH-7EF4}) missing from font(s) DejaVu Sans.\n",
      "  plt.savefig(OUT_PDF, bbox_inches=\"tight\")\n",
      "/tmp/ipykernel_2529676/1807329542.py:203: UserWarning: Glyph 24230 (\\N{CJK UNIFIED IDEOGRAPH-5EA6}) missing from font(s) DejaVu Sans.\n",
      "  plt.savefig(OUT_PDF, bbox_inches=\"tight\")\n",
      "/tmp/ipykernel_2529676/1807329542.py:203: UserWarning: Glyph 32500 (\\N{CJK UNIFIED IDEOGRAPH-7EF4}) missing from font(s) DejaVu Sans.\n",
      "  plt.savefig(OUT_PDF, bbox_inches=\"tight\")\n",
      "/tmp/ipykernel_2529676/1807329542.py:203: UserWarning: Glyph 24230 (\\N{CJK UNIFIED IDEOGRAPH-5EA6}) missing from font(s) DejaVu Sans.\n",
      "  plt.savefig(OUT_PDF, bbox_inches=\"tight\")\n",
      "/tmp/ipykernel_2529676/1807329542.py:204: UserWarning: Glyph 32500 (\\N{CJK UNIFIED IDEOGRAPH-7EF4}) missing from font(s) DejaVu Sans.\n",
      "  plt.savefig(OUT_SVG, bbox_inches=\"tight\")\n",
      "/tmp/ipykernel_2529676/1807329542.py:204: UserWarning: Glyph 24230 (\\N{CJK UNIFIED IDEOGRAPH-5EA6}) missing from font(s) DejaVu Sans.\n",
      "  plt.savefig(OUT_SVG, bbox_inches=\"tight\")\n",
      "/tmp/ipykernel_2529676/1807329542.py:205: UserWarning: Glyph 32500 (\\N{CJK UNIFIED IDEOGRAPH-7EF4}) missing from font(s) DejaVu Sans.\n",
      "  plt.savefig(OUT_PNG, dpi=600, bbox_inches=\"tight\")\n",
      "/tmp/ipykernel_2529676/1807329542.py:205: UserWarning: Glyph 24230 (\\N{CJK UNIFIED IDEOGRAPH-5EA6}) missing from font(s) DejaVu Sans.\n",
      "  plt.savefig(OUT_PNG, dpi=600, bbox_inches=\"tight\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ 保存：mbti_dims_tsne_top3.pdf / mbti_dims_tsne_top3.svg / mbti_dims_tsne_top3.png\n"
     ]
    }
   ],
   "source": [
    "# plot_mbti_4dims_tsne_top3.py\n",
    "# -*- coding: utf-8 -*-\n",
    "\"\"\"\n",
    "四维度(E/I, S/N, T/F, J/P)可视化：\n",
    "- 读取 case_meta_with_hits.json（如无则由 counts+embeddings+train 自动合并）\n",
    "- 统一用 PCA->t-SNE 得到 2D 坐标\n",
    "- 2x2 子图：分别画 E/I、S/N、T/F、J/P\n",
    "- 每个子图：两侧各高亮 Top-3（hits 降序），小圆点+黑边，标签显示命中次数\n",
    "- 导出 PDF/SVG/600dpi PNG\n",
    "\"\"\"\n",
    "\n",
    "import os, json, csv, re, random\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "matplotlib.use(\"Agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# ---------- 路径 ----------\n",
    "MERGED_JSON = \"case_meta_with_hits.json\"\n",
    "COUNTS_CSV  = \"case_usage_all.csv\"\n",
    "EMB_JSON    = \"casebank_A_train_80_with_embeddings.json\"\n",
    "TRAIN_JSON  = \"train.json\"\n",
    "TEXT_FIELD  = \"posts_cleaned\"\n",
    "\n",
    "# ---------- t-SNE 参数 ----------\n",
    "SEED = 0\n",
    "PCA_DIM = 50\n",
    "TSNE_PERPLEXITY = 30\n",
    "TSNE_EE = 12\n",
    "TSNE_LR = 200\n",
    "TSNE_ITER = 1500\n",
    "\n",
    "# ---------- 绘图样式 ----------\n",
    "BASE_SIZE = 12\n",
    "BASE_ALPHA = 0.25\n",
    "BASE_EDGE_LW = 0.0\n",
    "\n",
    "HL_SIZE = 80            # 高亮圆点大小\n",
    "HL_EDGE_LW = 0.9        # 黑色细边\n",
    "LABEL_FONTSZ = 9\n",
    "LABEL_BOX = True\n",
    "\n",
    "TOPK_PER_SIDE = 3       # 每侧 Top-K\n",
    "MIN_HITS_FOR_HL = 1     # 只高亮 hits >= 1，避免 0\n",
    "\n",
    "OUT_PDF = \"mbti_dims_tsne_top3.pdf\"\n",
    "OUT_SVG = \"mbti_dims_tsne_top3.svg\"\n",
    "OUT_PNG = \"mbti_dims_tsne_top3.png\"\n",
    "\n",
    "# ---------- 工具 ----------\n",
    "WS = re.compile(r\"\\s+\")\n",
    "def key_prefix80_space(s: str) -> str:\n",
    "    return WS.sub(\" \", (s or \"\").strip())[:80]\n",
    "\n",
    "def set_seeds(seed=0):\n",
    "    random.seed(seed); np.random.seed(seed)\n",
    "\n",
    "def vivid_pair_colors():\n",
    "    \"\"\"给四个维度各配两种对比色（可自定义）。\"\"\"\n",
    "    tab10 = list(plt.get_cmap(\"tab10\").colors)\n",
    "    pairs = {\n",
    "        \"EI\": (tab10[0], tab10[1]),   # E, I\n",
    "        \"SN\": (tab10[2], tab10[3]),   # S, N\n",
    "        \"TF\": (tab10[4], tab10[5]),   # T, F\n",
    "        \"JP\": (tab10[6], tab10[7]),   # J, P\n",
    "    }\n",
    "    return pairs\n",
    "\n",
    "def build_merged_from_sources():\n",
    "    # train\n",
    "    with open(TRAIN_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        train = json.load(f)\n",
    "    id2type, id2text, id2key = {}, {}, {}\n",
    "    for i, it in enumerate(train):\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\",\"\")\n",
    "        id2type[i] = it.get(\"type\",\"\")\n",
    "        id2text[i] = t\n",
    "        id2key[i]  = key_prefix80_space(t)\n",
    "    # embeddings\n",
    "    with open(EMB_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "        emb_rows = json.load(f)\n",
    "    key2emb = {}\n",
    "    for it in emb_rows:\n",
    "        t = it.get(TEXT_FIELD, \"\") or it.get(\"posts\",\"\")\n",
    "        if isinstance(t, str) and t and \"embedding\" in it:\n",
    "            key2emb[key_prefix80_space(t)] = it[\"embedding\"]\n",
    "    # hits\n",
    "    hits = {}\n",
    "    with open(COUNTS_CSV, \"r\", encoding=\"utf-8\") as f:\n",
    "        r = csv.DictReader(f)\n",
    "        for row in r:\n",
    "            hits[int(row[\"case_id\"])] = int(row[\"hits\"])\n",
    "    # merge\n",
    "    out, miss = [], 0\n",
    "    for cid in range(len(train)):\n",
    "        emb = key2emb.get(id2key.get(cid,\"\"))\n",
    "        if emb is None:\n",
    "            miss += 1; continue\n",
    "        out.append({\n",
    "            \"case_id\": cid,\n",
    "            \"type\": id2type.get(cid,\"\"),\n",
    "            \"text\": id2text.get(cid,\"\"),\n",
    "            \"hits\": hits.get(cid, 0),\n",
    "            \"embedding\": emb\n",
    "        })\n",
    "    with open(MERGED_JSON, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(out, f, ensure_ascii=False, indent=2)\n",
    "    print(f\"[merge] merged={len(out)}, miss_emb={miss} → {MERGED_JSON}\")\n",
    "    return out\n",
    "\n",
    "def load_merged():\n",
    "    if os.path.exists(MERGED_JSON):\n",
    "        with open(MERGED_JSON, \"r\", encoding=\"utf-8\") as f:\n",
    "            return json.load(f)\n",
    "    print(\"[load] not found merged, building…\")\n",
    "    return build_merged_from_sources()\n",
    "\n",
    "def type_to_4d(t: str):\n",
    "    t = (t or \"\").upper()\n",
    "    # 返回四个维度字符（E/I, S/N, T/F, J/P）\n",
    "    return t[0], t[1], t[2], t[3]\n",
    "\n",
    "# ---------- 主程序 ----------\n",
    "def main():\n",
    "    set_seeds(SEED)\n",
    "    data = load_merged()\n",
    "    assert data, \"没有数据\"\n",
    "\n",
    "    X = np.array([d[\"embedding\"] for d in data], dtype=np.float32)\n",
    "    H = np.array([d[\"hits\"]       for d in data], dtype=np.int32)\n",
    "    T = np.array([d[\"type\"]       for d in data])\n",
    "\n",
    "    # PCA -> t-SNE（一次坐标，四图共用）\n",
    "    X50 = PCA(n_components=min(PCA_DIM, max(2, X.shape[1]-1)), random_state=SEED).fit_transform(X)\n",
    "    Z = TSNE(n_components=2, perplexity=TSNE_PERPLEXITY, early_exaggeration=TSNE_EE,\n",
    "             learning_rate=TSNE_LR, n_iter=TSNE_ITER, init=\"pca\",\n",
    "             random_state=SEED, angle=0.5, verbose=1).fit_transform(X50)\n",
    "\n",
    "    # 解析四维字母\n",
    "    letters = np.array([type_to_4d(t) for t in T])  # shape (N,4)\n",
    "    EI = letters[:,0]   # 'E' or 'I'\n",
    "    SN = letters[:,1]   # 'S' or 'N'\n",
    "    TF = letters[:,2]   # 'T' or 'F'\n",
    "    JP = letters[:,3]   # 'J' or 'P'\n",
    "\n",
    "    DIM_SPECS = [\n",
    "        (\"EI\", (\"E\",\"I\"), EI),\n",
    "        (\"SN\", (\"S\",\"N\"), SN),\n",
    "        (\"TF\", (\"T\",\"F\"), TF),\n",
    "        (\"JP\", (\"J\",\"P\"), JP),\n",
    "    ]\n",
    "    pairs = vivid_pair_colors()\n",
    "\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(10, 8), dpi=300)\n",
    "    axes = axes.ravel()\n",
    "\n",
    "    for ax, (dim_tag, (a_char, b_char), side_arr) in zip(axes, DIM_SPECS):\n",
    "        cA, cB = pairs[dim_tag]\n",
    "        # 背景：两侧淡色\n",
    "        mA = (side_arr == a_char)\n",
    "        mB = (side_arr == b_char)\n",
    "        ax.scatter(Z[mA,0], Z[mA,1], s=BASE_SIZE, alpha=BASE_ALPHA,\n",
    "                   linewidths=BASE_EDGE_LW, edgecolors=\"none\", color=cA)\n",
    "        ax.scatter(Z[mB,0], Z[mB,1], s=BASE_SIZE, alpha=BASE_ALPHA,\n",
    "                   linewidths=BASE_EDGE_LW, edgecolors=\"none\", color=cB)\n",
    "\n",
    "        # 各侧 Top-3（仅 hits≥1）\n",
    "        handles = []\n",
    "        for side_char, color in [(a_char, cA), (b_char, cB)]:\n",
    "            m = (side_arr == side_char)\n",
    "            idx_all = np.where(m)[0]\n",
    "            idx_pos = idx_all[H[idx_all] >= MIN_HITS_FOR_HL]\n",
    "            if idx_pos.size == 0:\n",
    "                continue\n",
    "            order = idx_pos[np.argsort(-H[idx_pos])]\n",
    "            top_idx = order[:min(TOPK_PER_SIDE, len(order))]\n",
    "            # 高亮小圆点 + 黑色细边 + 命中数标签\n",
    "            for i in top_idx:\n",
    "                x, y = Z[i,0], Z[i,1]\n",
    "                ax.scatter(x, y, s=HL_SIZE, color=color, alpha=1.0,\n",
    "                           linewidths=HL_EDGE_LW, edgecolors=\"black\", zorder=5)\n",
    "                lbl = f\"{int(H[i])}\"\n",
    "                kw = dict(fontsize=LABEL_FONTSZ, zorder=6)\n",
    "                if LABEL_BOX:\n",
    "                    kw[\"bbox\"] = dict(facecolor=\"white\", alpha=0.95, lw=0, pad=0.4)\n",
    "                ax.text(x, y, lbl, **kw)\n",
    "\n",
    "            # 图例把两侧加进去\n",
    "            handles.append(Line2D([0],[0], marker='o', color='black',\n",
    "                                  markerfacecolor=color, markeredgewidth=HL_EDGE_LW,\n",
    "                                  markersize=7, label=f\"{side_char} (Top-3, label=hits)\"))\n",
    "\n",
    "        ax.legend(handles=handles, frameon=False, loc=\"upper left\", fontsize=8)\n",
    "        ax.set_title(f\"{dim_tag} 维度\", fontsize=11)\n",
    "        ax.set_xlabel(\"t-SNE-1\"); ax.set_ylabel(\"t-SNE-2\")\n",
    "        ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)\n",
    "\n",
    "    plt.suptitle(\"FEM Personality Embedding Space (t-SNE)\\nTop-3 per side for each MBTI dimension (label=hits)\", fontsize=13)\n",
    "    plt.tight_layout(rect=[0,0,1,0.95])\n",
    "    plt.savefig(OUT_PDF, bbox_inches=\"tight\")\n",
    "    plt.savefig(OUT_SVG, bbox_inches=\"tight\")\n",
    "    plt.savefig(OUT_PNG, dpi=600, bbox_inches=\"tight\")\n",
    "    print(f\"✅ 保存：{OUT_PDF} / {OUT_SVG} / {OUT_PNG}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "504cfee3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"type\": \"ISFJ\",\n",
      "  \"posts\": \"Hello everyone, and thank you so much for welcoming me here. It means a lot to find a place where I can share and learn alongside others who understand the quiet strength and careful kindness that often come with being an ISFJ. I’ve always found comfort in routines and in helping those around me, even when it means putting my own needs aside for a while. Sometimes it’s hard to speak up, especially when emotions run deep inside and I don’t want to burden anyone. But I’m learning that it’s okay to ask for support, too.  \\n\\nI’m currently studying veterinary medicine, which feels like a calling for me—it combines my love for caring with a practical path to make a difference. It’s a journey filled with late nights and lots of organization, but also moments of pure joy when I see the comfort a little animal can find through care and attention.  \\n\\nTo anyone else who sometimes feels overwhelmed by the weight of expectations or the chaos of the world, please know you’re not alone. It’s okay to step back, breathe, and take things one small step at a time. We ISFJs have a quiet resilience that doesn’t always show on the surface, but it’s there, steady and strong. And together, in this community, we can remind each other that we matter, that our efforts are seen, and that it’s perfectly fine to be gentle with ourselves along the way.  \\n\\nThank you for being here, and I look forward to growing with you all.\",\n",
      "  \"posts_cleaned\": \"hello everyone thank much welcoming mean lot find place share learn alongside others understand quiet strength careful kindness often come isfj always found comfort routine helping around even mean putting need aside sometimes hard speak especially emotion run deep inside want burden anyone learning okay ask support currently studying veterinary medicine feel like calling combine love caring practical path make difference journey filled late night lot organization also moment pure joy see comfort little animal find care attention anyone else sometimes feel overwhelmed weight expectation chaos world please know alone okay step back breathe take thing one small step time isfjs quiet resilience always show surface steady strong together community remind matter effort seen perfectly fine gentle along way thank look forward growing\",\n",
      "  \"semantic_view\": \"The post expresses gratitude for community support, shares personal reflections on being an ISFJ personality type, discusses challenges and growth in asking for help, and highlights the author's passion for veterinary medicine as a meaningful and fulfilling career path. It also offers reassurance to others feeling overwhelmed.\",\n",
      "  \"sentiment_view\": \"The emotions conveyed are warmth, gratitude, vulnerability, and hope. The author shows a gentle, caring attitude with moments of self-reflection and encouragement towards others experiencing similar feelings.\",\n",
      "  \"linguistic_view\": \"The writing style is informal, sincere, and emotionally expressive. It uses personal anecdotes and empathetic language to create a sense of connection and openness.\"\n",
      "}\n",
      "{\n",
      "  \"type\": \"INTP\",\n",
      "  \"posts\": \"I was indifferent to zombies before the likes of Warm Bodies and In The Flesh came along and humanised them a bit more. Then it got interesting. Perhaps because then I related more to the zombies...|||I was telling my INFx sister how when I was talking to my INFJ friend tonight he brought up the fact I 'enthusiastically' expressed interest in drawing comics last week when he told me he used to...|||Oh, we have discussed it. I dislike the idea of leading anybody on or people not knowing where they stand. We click in the sense we have a lot to talk about so I went on a date with him when he asked...|||I feel like all the INFJs I know want me to give them something I'm not capable of. For a certain amount of time I think there is an illusion that I'm meeting their needs but in reality we're just...|||lavendersnow  I find your reply really interesting and you touch a lot on something I've been thinking about a lot lately because this particular INFJ is interested in me romantically and I think...|||These are some really interesting answers and I definitely relate to a lot of them.   For me, when I care about something, I'm always approaching it from a 'how can I be productive' perspective,...|||I was talking to an INFJ about his grandfather. He has dementia and has just been taken into hospital after some falls. He was telling me about it and then apologised for talking about 'depressing...|||I have to say that my favourite people to chat to and the ones I have the most interesting and stimulating conversations with tend to be those that I can go even weeks without replying to. Sometimes...|||If I don't enjoy a person to some degree I probably won't respond. Which is a pretty clear message.  I do have to say there has been a lot of times when someone has told me something and I either...|||Obviously. Places like this are just a different medium. I use them differently. That is all.   I never said I was going to dismiss anyone based on their aversion to SM. I said I was mulling it...|||It's not that they never do but it's a different vibe to what I get from places like this. Though that could totally be in my head. When I talk in forums like this I'm more aware of others. My social...|||Ah. I like that social media can be more like shouting into the void than somewhere like this where people are more inclined to reply and then I have to interact. I can be less coherent and nobody is...|||I guess you weird me out then. :'D|||Oh, I should have specified I meant my generation.|||People that don't use social media weird me out. Where do they unload all the thoughts they don't want to tell any specific person?|||I don't accept that.|||Oh, don't be sorry at all. Venting is good. It gets a bad rap but it's actually very important. The reason I asked was because I wanted to know how best to advise you. Sometimes I can start of...|||The deciding if the date is worthy part.|||An INFJ has asked me out on a date and says he has some ideas for a stress-free date. I'll report back when he tells me what they are so we can decide if they are worthy. (Kidding)|||I'm like a unicorn. Touch my horn! :proud:|||TechFreak  Ah, so rather than advice you're really looking to vent? I understand how typing everything out can really be cathartic and help process things.   If I was to give advice I would say...|||Why am I irresistible...  :dry:|||All I get from this is that you both seem on completely different wavelengths. There's obviously some sort of issue with communication. 
It's almost like you're speaking different languages.   From...|||WELL, he just asked me out on a date. So I guess I did something right. Haha.   We've only been talking - on a dating website - for the last week or so but our conversations have been like essays....|||I ended up making a joke, apologising for the joke, telling them I'm the worst at taking compliments and saying thank you. Totally smooth. :cool: He seemed to take it well though. :blushed:|||I'm not gonna lie. I came here to see if I could work out if I like someone. :blushed:|||That does sound really interesting and fun! Visiting a stable in general would be something I'd love. Getting to pet and feed horses would be awesome.|||I've been thinking. I think I attract people with strong fe partly because I'm disabled. This is a new theory though, so I'm going to have to develop it.|||I think you have a valid point and I would find his disregard of a pretty standard trigger hard to get over. Maybe I would feel different if it was something very personal to him but I'd still be...|||I think in a way we shoot ourselves in the foot with it because we basically poke at someone's emotions and then have no idea - unless we've worked on it - how to deal with those emotions or help...|||I guess a big part of it is second guessing myself and worrying that they're feeling more than what I realise or not. So I'm not sure how much of a deal to make it of it. It doesn't help that I tend...|||I'm thinking about all the things I could have done today but didn't. All the things I could be doing every day but I'm not. And how I have little to no motivation to rectify the situation.|||Yeah... how do I break it to them that I don't have any? Haha.  That's how I feel too. It's not even that they won't continue to debate, they don't even start them. I mean, I know that I'm right A...|||https://c3.staticflickr.com/9/8051/29416174810_94c29cdcf2.jpg  It's a me. :cool:|||Maybe part of the issue with the particular INFJs I've been talking to it's under the guise of potential romance and to them, disagreeing with me doesn't make sense when you could potentially be...|||That makes a lot of sense and also probably why I'm also a little bit in awe of their intellect. I learn so much from them, even if I might not actually hold on to the information for very long. I...|||An INFJ just told me I was beautiful in a really sweet way (I don't usually find things sweet tbh) but I'm not great at taking compliments and I don't know how best to respond. My natural reaction is...|||https://www.youtube.com/watch?v=Rfo9VAQxmKI  <3|||I feel like people need to understand the difference between an emotional exchange and a factual/thinker exchange.   The two aren't always mutually exclusive but sometimes they are.   I can tell...|||Ah. In my local shelters you can't really volunteer like that unless you're properly on their books as a volunteer, if that makes sense. You can sometimes walk dogs though, so I guess there's that....|||I guess I would say it's not my favourite country. Just a country I want to visit most right now. Or somewhere I'd like to live for a bit.|||I took this yesterday when I was still under the impression I was an ISTP and low and behold I got ISTP. The results were clear, slight, moderate, clear... if I remember correctly. 
It wasn't a...|||If someone asked me to an animal shelter on a date I would ask them are they a masochist in all areas or...?|||When you're discussing how you lack the ability to form meaningful connections with the majority of people because you don't find the majority of people stimulating enough...   Me: The majority of...|||I'm Northern Irish, leaning towards British.   Canada! Home of ice hockey and generally just seems like a cool place, plus it's close enough to America without actually being America.|||Yeah, it's this for sure. My overall interest is in writing but what I enjoy writing about or certain topics that inspire me aren't constant. I guess that's where the issue comes in. Maybe I just...|||I think that makes sense and is probably why the blog posts I'm most motivated and inspired to write are usually the ones that benefit me in some way. Usually the ones where I get the warm fuzzies...|||I bet that's more common than you'd think.|||581906|||I'm still working on this. My room is full of unfinished projects that I may or may not get around to finishing some day. I also have a blog and find that I have spikes of motivation and inspiration...\",\n",
      "  \"posts_cleaned\": \"indifferent zombie like warm body flesh came along humanised bit got interesting perhaps related zombie telling infx sister talking infj friend tonight brought fact enthusiastically expressed interest drawing comic last week told used discussed dislike idea leading anybody people knowing stand click sense lot talk went date asked feel like infjs know want give something capable certain amount time think illusion meeting need reality lavendersnow find reply really interesting touch lot something thinking lot lately particular infj interested romantically think really interesting answer definitely relate lot care something always approaching productive perspective talking infj grandfather dementia taken hospital fall telling apologised talking depressing say favourite people chat one interesting stimulating conversation tend even week without replying sometimes enjoy person degree probably respond pretty clear message say lot time someone told something either obviously place like different medium use differently never said going dismiss anyone based aversion said mulling never different vibe get place like though could totally head talk forum like aware others social like social medium like shouting void somewhere like people inclined reply interact less coherent nobody guess weird specified meant generation people use social medium weird unload thought want tell specific person accept sorry venting good get bad rap actually important reason asked wanted know best advise sometimes start deciding date worthy part infj asked date say idea stress free date report back tell decide worthy kidding like unicorn touch horn proud techfreak rather advice really looking vent understand typing everything really cathartic help process thing give advice would say irresistible dry get seem completely different wavelength obviously sort issue communication almost like speaking different language well asked date guess something right haha talking dating website last week conversation like essay ended making joke apologising joke telling worst taking compliment saying thank totally smooth cool seemed take well though blushed gon lie came see could work like someone blushed sound really interesting fun visiting stable general would something love getting pet feed horse would awesome thinking think attract people strong partly disabled new theory though going develop think valid point would find disregard pretty standard trigger hard get maybe would feel different something personal still think way shoot foot basically poke someone emotion idea unless worked deal emotion help guess big part second guessing worrying feeling realise sure much deal make help tend thinking thing could done today thing could every day little motivation rectify situation yeah break haha feel even continue debate even start mean know right cool maybe part issue particular infjs talking guise potential romance disagreeing make sense could potentially make lot sense also probably also little bit awe intellect learn much even might actually hold information long infj told beautiful really sweet way usually find thing sweet tbh great taking compliment know best respond natural reaction feel like people need understand difference emotional exchange factual thinker exchange two always mutually exclusive sometimes tell local shelter really volunteer like unless properly book volunteer make sense sometimes walk dog though guess guess would say favourite country country want visit right somewhere like live bit 
took yesterday still impression istp low behold got istp result clear slight moderate clear remember correctly someone asked animal shelter date would ask masochist area discussing lack ability form meaningful connection majority people find majority people stimulating enough majority northern irish leaning towards british canada home ice hockey generally seems like cool place plus close enough america without actually america yeah sure overall interest writing enjoy writing certain topic inspire constant guess issue come maybe think make sense probably blog post motivated inspired write usually one benefit way usually one get warm fuzzies bet common think still working room full unfinished project may may get around finishing day also blog find spike motivation inspiration\",\n",
      "  \"semantic_view\": \"The post reflects on the author's evolving perspective on zombies influenced by media, shares personal interactions and reflections involving INFJ and INFx personality types, and discusses complexities in romantic and interpersonal relationships, particularly regarding expectations and communication.\",\n",
      "  \"sentiment_view\": \"The emotions conveyed include curiosity, introspection, cautiousness, and a mix of interest and frustration regarding personal connections and understanding others' expectations.\",\n",
      "  \"linguistic_view\": \"The writing style is informal, conversational, and reflective, combining narrative elements with personal insights and some vague or unfinished thoughts.\"\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "with open(\"train.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "print(json.dumps(data[0], indent=2, ensure_ascii=False))\n",
    "print(json.dumps(data[1], indent=2, ensure_ascii=False))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
