{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c586c0fa-9ba9-42f0-a36f-9d96628a5d89",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "from math import log2\n",
    "\n",
    "import faiss\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "aaf8bb5a-89f8-48cf-99cb-f51b632efb1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ====================================================\n",
    "# 0. Configuration: input paths, column names, filters\n",
    "# ====================================================\n",
    "\n",
    "# --- Input parquet files (relative to the notebook) ---\n",
    "infile_train = \"./dataset/finekb_cases_train_clustered.parquet\"  # clustered train cases -> index\n",
    "infile = \"./dataset/finekb_cases_test.parquet\"                   # test cases -> queries\n",
    "infile_kb = \"./dataset/finekb_kb.parquet\"                        # not read by the cells below\n",
    "\n",
    "# --- Column names ---\n",
    "CASE_ID_COL = \"case_id\"\n",
    "KB_ID_COL = \"kb_id\"\n",
    "CLUSTER_ID_COL = \"cluster_id\"\n",
    "\n",
    "# Embedding columns: the index side is averaged into per-cluster centroids,\n",
    "# the search side is used as the query vector. Both currently point to the\n",
    "# same column.\n",
    "CASE_INDEX_EMB_COL = CASE_SEARCH_EMB_COL = \"embed_summary\"\n",
    "\n",
    "# Clusters with fewer training rows than this are dropped before indexing.\n",
    "MIN_CLUSTER_SIZE = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "effcf7b8-84c9-4746-a7c6-5d18189273f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded case dataset: (11415, 5)\n"
     ]
    }
   ],
   "source": [
    "# ----------------------------------------------------\n",
    "# 1. Load the TRAIN dataframe (historical, clustered cases)\n",
    "# ----------------------------------------------------\n",
    "# Drop rows without a KB id: the index-building cell casts KB_ID_COL to int,\n",
    "# which fails on missing values.\n",
    "df_train = pd.read_parquet(infile_train).dropna(subset=[KB_ID_COL])\n",
    "\n",
    "print(\"Loaded case dataset:\", df_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "86e8dcdc-ce3e-4fb2-8c40-19f27749587b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>case_id</th>\n",
       "      <th>issue_type</th>\n",
       "      <th>kb_id</th>\n",
       "      <th>embed_summary</th>\n",
       "      <th>cluster_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>remote_access</td>\n",
       "      <td>11</td>\n",
       "      <td>[0.005332942120730877, -0.003665205556899309, ...</td>\n",
       "      <td>11_c0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>info</td>\n",
       "      <td>102</td>\n",
       "      <td>[0.00189633306581527, 0.012942066416144371, 0....</td>\n",
       "      <td>102_c5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>contract</td>\n",
       "      <td>35</td>\n",
       "      <td>[0.0068122390657663345, 0.0030071106739342213,...</td>\n",
       "      <td>35_c3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>fan</td>\n",
       "      <td>158</td>\n",
       "      <td>[0.013492287136614323, -0.013519312255084515, ...</td>\n",
       "      <td>158_c1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>6</td>\n",
       "      <td>contract</td>\n",
       "      <td>63</td>\n",
       "      <td>[-0.005820220801979303, 0.008195333182811737, ...</td>\n",
       "      <td>63_c1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   case_id     issue_type  kb_id  \\\n",
       "0        0  remote_access     11   \n",
       "1        1           info    102   \n",
       "2        3       contract     35   \n",
       "3        4            fan    158   \n",
       "4        6       contract     63   \n",
       "\n",
       "                                       embed_summary cluster_id  \n",
       "0  [0.005332942120730877, -0.003665205556899309, ...      11_c0  \n",
       "1  [0.00189633306581527, 0.012942066416144371, 0....     102_c5  \n",
       "2  [0.0068122390657663345, 0.0030071106739342213,...      35_c3  \n",
       "3  [0.013492287136614323, -0.013519312255084515, ...     158_c1  \n",
       "4  [-0.005820220801979303, 0.008195333182811737, ...      63_c1  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the training frame\n",
    "df_train.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "32374639-a9ed-48f4-86fe-01179b117c64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded case dataset: (2849, 4)\n"
     ]
    }
   ],
   "source": [
    "# ----------------------------------------------------\n",
    "# 2. Load the TEST dataframe (query cases)\n",
    "# ----------------------------------------------------\n",
    "# A case without a gold KB id has nothing to be scored against, and the\n",
    "# mapping cell casts KB_ID_COL to int, so drop those rows here.\n",
    "df_test = pd.read_parquet(infile).dropna(subset=[KB_ID_COL])\n",
    "\n",
    "print(\"Loaded case dataset:\", df_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1f44fab6-4973-4dce-9e58-d69f553969bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>case_id</th>\n",
       "      <th>issue_type</th>\n",
       "      <th>kb_id</th>\n",
       "      <th>embed_summary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>remote_access</td>\n",
       "      <td>53</td>\n",
       "      <td>[-0.0003725903807207942, -0.008335007354617119...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>contract</td>\n",
       "      <td>27</td>\n",
       "      <td>[0.002286486327648163, 0.00011503534915391356,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>software</td>\n",
       "      <td>38</td>\n",
       "      <td>[0.01094728522002697, 0.0060193841345608234, 0...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>contract</td>\n",
       "      <td>33</td>\n",
       "      <td>[-0.003893519751727581, -0.001837620628066361,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>memory</td>\n",
       "      <td>103</td>\n",
       "      <td>[-0.007071544881910086, 0.0033897554967552423,...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   case_id     issue_type kb_id  \\\n",
       "0        0  remote_access    53   \n",
       "1        1       contract    27   \n",
       "2        2       software    38   \n",
       "3        3       contract    33   \n",
       "5        5         memory   103   \n",
       "\n",
       "                                       embed_summary  \n",
       "0  [-0.0003725903807207942, -0.008335007354617119...  \n",
       "1  [0.002286486327648163, 0.00011503534915391356,...  \n",
       "2  [0.01094728522002697, 0.0060193841345608234, 0...  \n",
       "3  [-0.003893519751727581, -0.001837620628066361,...  \n",
       "5  [-0.007071544881910086, 0.0033897554967552423,...  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the test frame\n",
    "df_test.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4796ef57-a221-4363-a9b3-5f3b5d68aaa5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KB cluster centroid matrix shape: (462, 4096)\n",
      "FAISS index size (num KB clusters) = 462\n",
      "Unique test cases: 2201\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "940a9e623ade4b53a0f3de9f9a82d1b1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating FAISS KB-clusters:   0%|                                                                           …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===== FAISS(case_emb → KB_cluster_centroid_emb) =====\n",
      "Recall@3: 63.11\n",
      "Recall@5: 73.83\n",
      "MRR:      0.5222\n",
      "nDCG@5:   56.22\n",
      "\n",
      "Latency:\n",
      "avg_ms:   0.353\n",
      "p50_ms:   0.346\n",
      "p95_ms:   0.371\n",
      "QPS:      2833.09\n",
      "n_queries:2201\n",
      "\n",
      "Saved results to: ./results/cluster_embed_summary_search_embed_summary_faiss_case_cluster_results.csv\n",
      "               method   Recall@3   Recall@5       MRR     nDCG@5    avg_ms  \\\n",
      "0  FAISS_case_cluster  63.107678  73.830077  0.522175  56.220317  0.352971   \n",
      "\n",
      "     p50_ms    p95_ms          qps  n_queries  \n",
      "0  0.345883  0.371443  2833.091293       2201  \n"
     ]
    }
   ],
   "source": [
    "# -------------------------------------------------------------------\n",
    "# 1. Helper: L2-normalization\n",
    "# -------------------------------------------------------------------\n",
    "def l2_normalize(mat: np.ndarray, eps: float = 1e-12) -> np.ndarray:\n",
    "    \"\"\"Scale vectors to unit L2 norm along the last axis.\n",
    "\n",
    "    Accepts a single vector of shape (dim,) or a batch of shape (n, dim);\n",
    "    for 2-D input this normalizes each row. ``eps`` is added to every norm\n",
    "    so an all-zero vector stays zero instead of dividing by zero.\n",
    "\n",
    "    With unit-norm vectors, inner product equals cosine similarity, which is\n",
    "    what an inner-product FAISS index (IndexFlatIP) needs for cosine search.\n",
    "    \"\"\"\n",
    "    norms = np.linalg.norm(mat, axis=-1, keepdims=True) + eps\n",
    "    return mat / norms\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------------\n",
    "# 2. Build KB \"cluster\" index from TRAIN / historical cases\n",
    "#    Group by cluster_id, drop small clusters, avg CASE_INDEX_EMB_COL\n",
    "# -------------------------------------------------------------------\n",
    "# astype returns a new frame, so re-running this cell is idempotent.\n",
    "df_train = df_train.astype({KB_ID_COL: int})\n",
    "\n",
    "# 2.1 Filter out small clusters: each cluster must have >= MIN_CLUSTER_SIZE rows\n",
    "df_train_filtered = (\n",
    "    df_train\n",
    "    .groupby(CLUSTER_ID_COL)\n",
    "    .filter(lambda g: len(g) >= MIN_CLUSTER_SIZE)\n",
    ")\n",
    "assert not df_train_filtered.empty, (\n",
    "    f\"No cluster has >= {MIN_CLUSTER_SIZE} rows; nothing to index.\"\n",
    ")\n",
    "\n",
    "# Each cluster is indexed under a single kb_id (taken with \"first\" below),\n",
    "# so a cluster spanning several kb_ids would silently lose gold labels.\n",
    "kb_ids_per_cluster = df_train_filtered.groupby(CLUSTER_ID_COL)[KB_ID_COL].nunique()\n",
    "mixed_clusters = kb_ids_per_cluster[kb_ids_per_cluster > 1]\n",
    "assert mixed_clusters.empty, (\n",
    "    f\"{len(mixed_clusters)} cluster(s) map to several kb_ids, e.g. \"\n",
    "    f\"{list(mixed_clusters.index[:5])}\"\n",
    ")\n",
    "\n",
    "# 2.2 Compute centroids per remaining cluster_id (mean of the raw embeddings)\n",
    "df_kb_centroids = (\n",
    "    df_train_filtered\n",
    "    .groupby(CLUSTER_ID_COL)\n",
    "    .agg({\n",
    "        CASE_INDEX_EMB_COL: lambda rows: np.mean(np.vstack(rows.values), axis=0),\n",
    "        KB_ID_COL: \"first\",   # single kb_id per cluster, asserted above\n",
    "    })\n",
    "    .reset_index()\n",
    "    .rename(columns={CASE_INDEX_EMB_COL: \"avg_emb\"})\n",
    ")\n",
    "\n",
    "# Row i of kb_emb_matrix (= FAISS id i) corresponds to kb_ids[i]\n",
    "kb_ids = df_kb_centroids[KB_ID_COL].to_numpy()                     # (num_clusters,)\n",
    "kb_emb_list = df_kb_centroids[\"avg_emb\"].tolist()\n",
    "kb_emb_matrix = np.vstack(kb_emb_list).astype(\"float32\")           # (num_clusters, dim)\n",
    "\n",
    "dim = kb_emb_matrix.shape[1]\n",
    "print(\"KB cluster centroid matrix shape:\", kb_emb_matrix.shape)\n",
    "\n",
    "# Normalize centroids for cosine similarity via inner product\n",
    "kb_emb_matrix = l2_normalize(kb_emb_matrix)\n",
    "\n",
    "# FAISS cosine index (exact inner-product search over unit vectors)\n",
    "index = faiss.IndexFlatIP(dim)\n",
    "index.add(kb_emb_matrix)\n",
    "\n",
    "print(\"FAISS index size (num KB clusters) =\", index.ntotal)\n",
    "\n",
    "# FAISS internal idx → kb_id mapping (plain-int dict view of kb_ids)\n",
    "idx2kb_id = {i: int(kb_id) for i, kb_id in enumerate(kb_ids)}\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------------\n",
    "# 3. Build TEST mappings (one entry per distinct case_id)\n",
    "# -------------------------------------------------------------------\n",
    "df_test = df_test.astype({CASE_ID_COL: int, KB_ID_COL: int})\n",
    "cases_by_id = df_test.groupby(CASE_ID_COL)\n",
    "\n",
    "# case_id → list of gold KB ids (a case can span several rows)\n",
    "case_to_kbs = cases_by_id[KB_ID_COL].apply(list).to_dict()\n",
    "\n",
    "# case_id → query vector; embeddings of a case's rows are averaged\n",
    "case_to_emb = cases_by_id[CASE_SEARCH_EMB_COL].apply(\n",
    "    lambda vecs: np.vstack(vecs.to_numpy()).mean(axis=0)\n",
    ").to_dict()\n",
    "\n",
    "unique_case_ids = [*case_to_kbs]\n",
    "print(\"Unique test cases:\", len(unique_case_ids))\n",
    "\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------------\n",
    "# 4. Evaluation: case_emb → KB_cluster_centroid_emb (distinct KB IDs)\n",
    "# -------------------------------------------------------------------\n",
    "def _first_gold_rank(retrieved_kb_ids, gold_kbs):\n",
    "    \"\"\"Return the 1-based rank of the first gold KB among distinct retrieved ids, else None.\"\"\"\n",
    "    # Several clusters can share a kb_id: collapse repeats (keeping the first,\n",
    "    # i.e. best-ranked, occurrence) so ranks count distinct KB ids.\n",
    "    distinct_kb_ids = dict.fromkeys(retrieved_kb_ids)\n",
    "    for rank, kb in enumerate(distinct_kb_ids, start=1):\n",
    "        if kb in gold_kbs:\n",
    "            return rank\n",
    "    return None\n",
    "\n",
    "\n",
    "def _latency_summary(latencies_s):\n",
    "    \"\"\"Summarize per-query latencies given in seconds (ms stats, QPS, count).\"\"\"\n",
    "    lat = np.asarray(latencies_s, dtype=float)\n",
    "    n_queries = len(lat)\n",
    "    return {\n",
    "        \"avg_ms\": float(lat.mean() * 1000),\n",
    "        \"p50_ms\": float(np.percentile(lat, 50) * 1000),\n",
    "        \"p95_ms\": float(np.percentile(lat, 95) * 1000),\n",
    "        \"qps\": float(n_queries / lat.sum()),\n",
    "        \"n_queries\": int(n_queries),\n",
    "    }\n",
    "\n",
    "\n",
    "def evaluate_case_to_kbcentroid_faiss(k: int = 20):\n",
    "    \"\"\"Score case → KB retrieval against the FAISS index of KB-cluster centroids.\n",
    "\n",
    "    For each unique test case, the L2-normalized case embedding is searched\n",
    "    against ``index`` for the top-``k`` centroids. Each hit is mapped to its\n",
    "    kb_id through ``kb_ids`` and repeated kb_ids are collapsed, so ranks are\n",
    "    counted over distinct KB ids. Only the best-ranked gold KB of a case counts:\n",
    "\n",
    "    * Recall@3 / Recall@5: % of cases with any gold KB in the top 3 / top 5.\n",
    "    * MRR: mean of 1 / best_rank (0 if no gold KB is retrieved).\n",
    "    * nDCG@5: mean of 1 / log2(best_rank + 1) for best_rank <= 5, times 100.\n",
    "      NOTE: this is the gain of the first relevant hit only (no ideal-DCG over\n",
    "      all gold KBs), so it matches textbook nDCG only for single-gold cases.\n",
    "\n",
    "    Reads the notebook globals ``index``, ``kb_ids``, ``unique_case_ids``,\n",
    "    ``case_to_kbs`` and ``case_to_emb``. Only ``index.search`` is timed.\n",
    "\n",
    "    Args:\n",
    "        k: number of centroids requested from FAISS per query. Because several\n",
    "            clusters can share a kb_id, the distinct KB list may be shorter.\n",
    "\n",
    "    Returns:\n",
    "        (metrics, latency_stats): ``metrics`` has \"Recall@3\", \"Recall@5\",\n",
    "        \"MRR\", \"nDCG@5\"; ``latency_stats`` has \"avg_ms\", \"p50_ms\", \"p95_ms\",\n",
    "        \"qps\", \"n_queries\".\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``unique_case_ids`` is empty.\n",
    "    \"\"\"\n",
    "    if not unique_case_ids:\n",
    "        raise ValueError(\"No test cases to evaluate: unique_case_ids is empty.\")\n",
    "\n",
    "    recall3_list, recall5_list = [], []\n",
    "    mrr_list, ndcg5_list = [], []\n",
    "    latencies = []\n",
    "\n",
    "    for case_id in tqdm(unique_case_ids, desc=\"Evaluating FAISS KB-clusters\", ncols=150):\n",
    "        gold_kbs = set(case_to_kbs[case_id])\n",
    "\n",
    "        # Query embedding, normalized like the indexed centroids (cosine via IP)\n",
    "        q_emb = np.array(case_to_emb[case_id], dtype=\"float32\").reshape(1, -1)\n",
    "        q_emb = l2_normalize(q_emb)\n",
    "\n",
    "        t0 = time.perf_counter()\n",
    "        _scores, nn_idx = index.search(q_emb, k)\n",
    "        t1 = time.perf_counter()\n",
    "        latencies.append(t1 - t0)\n",
    "\n",
    "        # FAISS pads results with -1 when fewer than k vectors exist; indexing\n",
    "        # kb_ids[-1] would silently alias the last cluster, so skip those slots.\n",
    "        retrieved_kb_ids = [int(kb_ids[idx]) for idx in nn_idx[0] if idx >= 0]\n",
    "\n",
    "        best_rank = _first_gold_rank(retrieved_kb_ids, gold_kbs)\n",
    "        hit = best_rank is not None\n",
    "\n",
    "        recall3_list.append(1 if hit and best_rank <= 3 else 0)\n",
    "        recall5_list.append(1 if hit and best_rank <= 5 else 0)\n",
    "        mrr_list.append(1 / best_rank if hit else 0)\n",
    "        ndcg5_list.append(1 / log2(best_rank + 1) if hit and best_rank <= 5 else 0)\n",
    "\n",
    "    latency_stats = _latency_summary(latencies)\n",
    "\n",
    "    metrics = {\n",
    "        \"Recall@3\": np.mean(recall3_list) * 100,\n",
    "        \"Recall@5\": np.mean(recall5_list) * 100,\n",
    "        \"MRR\":      np.mean(mrr_list),\n",
    "        \"nDCG@5\":   np.mean(ndcg5_list) * 100,\n",
    "    }\n",
    "\n",
    "    return metrics, latency_stats\n",
    "\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------------\n",
    "# 5. Run evaluation and report\n",
    "# -------------------------------------------------------------------\n",
    "results, latency = evaluate_case_to_kbcentroid_faiss(k=20)\n",
    "\n",
    "# Retrieval quality: one print call, one output line per value\n",
    "print(\n",
    "    \"\\n===== FAISS(case_emb → KB_cluster_centroid_emb) =====\",\n",
    "    f\"Recall@3: {results['Recall@3']:.2f}\",\n",
    "    f\"Recall@5: {results['Recall@5']:.2f}\",\n",
    "    f\"MRR:      {results['MRR']:.4f}\",\n",
    "    f\"nDCG@5:   {results['nDCG@5']:.2f}\",\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "\n",
    "# Latency of the timed FAISS search calls\n",
    "print(\n",
    "    \"\\nLatency:\",\n",
    "    f\"avg_ms:   {latency['avg_ms']:.3f}\",\n",
    "    f\"p50_ms:   {latency['p50_ms']:.3f}\",\n",
    "    f\"p95_ms:   {latency['p95_ms']:.3f}\",\n",
    "    f\"QPS:      {latency['qps']:.2f}\",\n",
    "    f\"n_queries:{latency['n_queries']}\",\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "\n",
    "\n",
    "# -------------------------------------------------------------------\n",
    "# 6. Save CSV\n",
    "# -------------------------------------------------------------------\n",
    "row = {\"method\": \"FAISS_case_cluster\", **results, **latency}\n",
    "df_out = pd.DataFrame([row])\n",
    "csv_path = (\n",
    "    f\"./results/cluster_{CASE_INDEX_EMB_COL}_search_{CASE_SEARCH_EMB_COL}\"\n",
    "    \"_faiss_case_cluster_results.csv\"\n",
    ")\n",
    "# DataFrame.to_csv does not create missing directories; create ./results\n",
    "# so a fresh checkout does not fail after the evaluation has already run.\n",
    "os.makedirs(os.path.dirname(csv_path), exist_ok=True)\n",
    "df_out.to_csv(csv_path, index=False)\n",
    "\n",
    "print(\"\\nSaved results to:\", csv_path)\n",
    "print(df_out)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f48e73b0-57de-42ea-8a0c-caa66b7bf84d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "(gpu-env) Python",
   "language": "python",
   "name": "conda-env-gpu-env-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
