{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2a53137",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import re\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "712f98d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import logging\n",
    "from typing import Dict, List, Tuple\n",
    "\n",
    "import numpy as np\n",
    "import pytrec_eval\n",
    "\n",
    "def evaluate(\n",
    "    qrels: Dict[str, Dict[str, int]],\n",
    "    results: Dict[str, Dict[str, float]],\n",
    "    k_values: List[int],\n",
    ") -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:\n",
    "    \"\"\"\n",
    "    仿照 beir.retrieval.evaluation.EvaluateRetrieval.evaluate 编写的评估函数。\n",
    "    \"\"\"\n",
    "    ndcg = {}\n",
    "    _map = {}\n",
    "    recall = {}\n",
    "    precision = {}\n",
    "\n",
    "    for k in k_values:\n",
    "        ndcg[f\"NDCG@{k}\"] = 0.0\n",
    "        _map[f\"MAP@{k}\"] = 0.0\n",
    "        recall[f\"Recall@{k}\"] = 0.0\n",
    "        precision[f\"P@{k}\"] = 0.0\n",
    "\n",
    "    map_string = \"map_cut.\" + \",\".join([str(k) for k in k_values])\n",
    "    ndcg_string = \"ndcg_cut.\" + \",\".join([str(k) for k in k_values])\n",
    "    recall_string = \"recall.\" + \",\".join([str(k) for k in k_values])\n",
    "    precision_string = \"P.\" + \",\".join([str(k) for k in k_values])\n",
    "    evaluator = pytrec_eval.RelevanceEvaluator(\n",
    "        qrels, {map_string, ndcg_string, recall_string, precision_string}\n",
    "    )\n",
    "    scores = evaluator.evaluate(results)\n",
    "\n",
    "    for query_id in scores.keys():\n",
    "        for k in k_values:\n",
    "            ndcg[f\"NDCG@{k}\"] += scores[query_id][\"ndcg_cut_\" + str(k)]\n",
    "            _map[f\"MAP@{k}\"] += scores[query_id][\"map_cut_\" + str(k)]\n",
    "            recall[f\"Recall@{k}\"] += scores[query_id][\"recall_\" + str(k)]\n",
    "            precision[f\"P@{k}\"] += scores[query_id][\"P_\" + str(k)]\n",
    "\n",
    "    for k in k_values:\n",
    "        ndcg[f\"NDCG@{k}\"] = round(ndcg[f\"NDCG@{k}\"] / len(scores), 5)\n",
    "        _map[f\"MAP@{k}\"] = round(_map[f\"MAP@{k}\"] / len(scores), 5)\n",
    "        recall[f\"Recall@{k}\"] = round(recall[f\"Recall@{k}\"] / len(scores), 5)\n",
    "        precision[f\"P@{k}\"] = round(precision[f\"P@{k}\"] / len(scores), 5)\n",
    "\n",
    "    # for eval_metric in [ndcg, _map, recall, precision]:\n",
    "    #     logging.info(\"\\n\")\n",
    "    #     for k, v in eval_metric.items():\n",
    "    #         logging.info(f\"{k}: {v:.4f}\")\n",
    "\n",
    "    return recall\n",
    "\n",
    "\n",
    "def load_gt(gt_path: str) -> Dict[str, Dict[str, int]]:\n",
    "    \"\"\"\n",
    "    加载 ground-truth npy 文件并转换为 pytrec_eval 所需的 qrels 格式。\n",
    "    查询ID将使用其在npy文件中的索引（0, 1, 2, ...）。\n",
    "    \"\"\"\n",
    "    gt_data = np.load(gt_path, allow_pickle=True)\n",
    "    qrels = {}\n",
    "    for i, gt_list in enumerate(gt_data):\n",
    "        query_id = str(i)\n",
    "        qrels[query_id] = {}\n",
    "        for passage_id in gt_list:\n",
    "            qrels[query_id][str(passage_id)] = 1  # 假设相关性得分为1\n",
    "    return qrels\n",
    "\n",
    "\n",
    "def load_results(results_path: str) -> Dict[str, Dict[str, float]]:\n",
    "    \"\"\"\n",
    "    加载检索结果的tsv文件并转换为 pytrec_eval 所需的 results 格式。\n",
    "    - 如果文件有4列 (query_id, passage_id, rank, score)，则使用第四列的分数。\n",
    "    - 如果文件只有3列 (query_id, passage_id, rank)，则使用 1/rank 作为分数。\n",
    "    \"\"\"\n",
    "    results = {}\n",
    "    with open(results_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        reader = csv.reader(f, delimiter=\"\\t\")\n",
    "        for row in reader:\n",
    "            if not row: continue # Skip empty lines\n",
    "\n",
    "            query_id, passage_id = row[0], row[1]\n",
    "            \n",
    "            if query_id not in results:\n",
    "                results[query_id] = {}\n",
    "            \n",
    "            # 判断使用真实分数还是生成代理分数\n",
    "            if len(row) == 4:\n",
    "                score = float(row[3])\n",
    "            elif len(row) == 3:\n",
    "                rank = int(row[2])\n",
    "                score = 1.0 / rank\n",
    "            else:\n",
    "                logging.warning(f\"Skipping malformed line with {len(row)} columns: {row}\")\n",
    "                continue\n",
    "            \n",
    "            results[query_id][passage_id] = score\n",
    "            \n",
    "    return results"
   ]
  },
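  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e1a9f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: `qrels_data` and `K_VALUES` are used by the cells below but are not\n",
    "# defined anywhere in this notebook. A minimal sketch, assuming the ground\n",
    "# truth lives in a .npy file loadable by load_gt; GT_PATH is a placeholder,\n",
    "# not the original path, and K_VALUES is inferred from the Recall@10 usage below.\n",
    "GT_PATH = \"/path/to/ground_truth.npy\"  # placeholder: point this at the real file\n",
    "\n",
    "K_VALUES = [10]\n",
    "qrels_data = load_gt(GT_PATH)"
   ]
  },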
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3dcf3b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# HELLO PROFESSOR THIS IS FOR READING THE RESULTS OF THE MVRHNSW\n",
    "import os\n",
    "import re\n",
    "import pandas as pd\n",
    "directory = \"/data/lijunlin/sigmod2025-results/multi-hnsw-result\"\n",
    "keyword = \"clip-multi-clustering\"\n",
    "\n",
    "# ----------------------------\n",
    "# File pairing\n",
    "# ----------------------------\n",
    "all_files = os.listdir(directory)\n",
    "txt_files = [\n",
    "    f for f in all_files\n",
    "    if f.endswith(\".txt\") and keyword in f and \"metadata\" not in f\n",
    "]\n",
    "\n",
    "file_pairs = {}\n",
    "for f in txt_files:\n",
    "    if f.endswith(\"_summary.txt\"):\n",
    "        base_name = f[:-len(\"_summary.txt\")]\n",
    "        file_pairs.setdefault(base_name, {})[\"summary\"] = os.path.join(directory, f)\n",
    "    else:\n",
    "        base_name = f[:-len(\".txt\")]\n",
    "        file_pairs.setdefault(base_name, {})[\"results\"] = os.path.join(directory, f)\n",
    "\n",
    "# ----------------------------\n",
    "# Helpers\n",
    "# ----------------------------\n",
    "def load_results_tsv(path: str) -> pd.DataFrame:\n",
    "    # Results: query_id, passage_id, rank, score (tab-separated, no header)\n",
    "    return pd.read_csv(\n",
    "        path,\n",
    "        sep=\"\\t\",\n",
    "        header=None,\n",
    "        names=[\"query_id\", \"passage_id\", \"rank\", \"score\"],\n",
    "        dtype={\"query_id\": int, \"passage_id\": int, \"rank\": int, \"score\": float}\n",
    "    )\n",
    "\n",
    "def extract_avg_retrieval_time(summary_text: str) -> float | None:\n",
    "    # Robust parser for avg retrieval time (ms)\n",
    "    patterns = [\n",
    "        r\"Average\\s+Query\\s+Time:\\s*([\\d.]+)\\s*ms\",\n",
    "        r\"Avg\\s*retrieval\\s*time\\s*[:=]\\s*([\\d.]+)\\s*ms\",\n",
    "        r\"retrieval_time_single_query_average\\(ms\\)\\s*[:=]\\s*([\\d.]+)\",\n",
    "        r\"avg.*?ms\\s*[:=]\\s*([\\d.]+)\",\n",
    "    ]\n",
    "    for pat in patterns:\n",
    "        m = re.search(pat, summary_text, flags=re.IGNORECASE)\n",
    "        if m:\n",
    "            try:\n",
    "                return float(m.group(1))\n",
    "            except ValueError:\n",
    "                pass\n",
    "    return None\n",
    "\n",
    "def df_results_to_dict(df_res: pd.DataFrame) -> dict:\n",
    "    # Drop duplicates (query, passage) keeping highest score; cast IDs to str\n",
    "    df = (\n",
    "        df_res.sort_values('score', ascending=False)\n",
    "              .drop_duplicates(['query_id', 'passage_id'])\n",
    "              .astype({'query_id': str, 'passage_id': str})\n",
    "    )\n",
    "    return (\n",
    "        df.groupby('query_id')\n",
    "          .apply(lambda g: dict(zip(g['passage_id'], g['score'])))\n",
    "          .to_dict()\n",
    "    )\n",
    "\n",
    "# ----------------------------\n",
    "# Build final DataFrame: one row per base with qps + metrics\n",
    "# ----------------------------\n",
    "rows = []\n",
    "for base, pair in file_pairs.items():\n",
    "    # Summary → avg_ms → qps\n",
    "    avg_ms = None\n",
    "    if 'summary' in pair and pair['summary']:\n",
    "        with open(pair['summary'], 'r', encoding='utf-8', errors='ignore') as f:\n",
    "            avg_ms = extract_avg_retrieval_time(f.read())\n",
    "    qps = (1000.0 / avg_ms) if (avg_ms is not None and avg_ms > 0) else None\n",
    "\n",
    "    # Results → metrics via evaluate\n",
    "    metrics = {}\n",
    "    if 'results' in pair and pair['results']:\n",
    "        df_res = load_results_tsv(pair['results'])\n",
    "        results_dict = df_results_to_dict(df_res)\n",
    "        metrics = evaluate(qrels=qrels_data, results=results_dict, k_values=K_VALUES) or {}\n",
    "\n",
    "    row = {'Algorithm': \"MVRHNSW\", 'QPS': qps, \"Dataset\": \"DBpedia-entity\", \"Recall\": metrics[\"Recall@10\"]}\n",
    "    # row.update(metrics)  # adds e.g. 'Recall@10', etc.\n",
    "    rows.append(row)\n",
    "\n",
    "mvrhnsw_df = pd.DataFrame(rows).reset_index(drop=True)\n",
    "mvrhnsw_df"
   ]
  },
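  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7b30c15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check of the summary-time parser on synthetic lines; the real\n",
    "# summary file format is assumed to match one of the regex patterns above.\n",
    "assert extract_avg_retrieval_time(\"Average Query Time: 1.25 ms\") == 1.25\n",
    "assert extract_avg_retrieval_time(\"retrieval_time_single_query_average(ms): 3.4\") == 3.4\n",
    "assert extract_avg_retrieval_time(\"no timing information here\") is None"
   ]
  },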
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3017d8c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_limit(row):\n",
    "    identifier = row['Identifier']\n",
    "    name = row['Algorithm']\n",
    "    \n",
    "    # Default to None\n",
    "    match = None\n",
    "\n",
    "    if name == \"IGP\":\n",
    "        # Extract probe_topk\n",
    "        match = re.search(r'probe_topk_(\\d+)', identifier)\n",
    "    \n",
    "    if match:\n",
    "        return int(match.group(1))\n",
    "    else:\n",
    "        return 0  # If pattern not found, assume safe to keep\n"
   ]
  },
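  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f2e3d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: the next cell calls `find_files_with_keyword`, `pair_json_tsv`,\n",
    "# `load_json_file`, `JSON_ROOT_DIR`, and `TSV_ROOT_DIR`, none of which are\n",
    "# defined in this notebook. Minimal sketches under assumed conventions:\n",
    "# files are named '<keyword>-<identifier>.json' / '<keyword>-<identifier>.tsv',\n",
    "# and both root directories are placeholders, not the original paths.\n",
    "import json\n",
    "import logging\n",
    "import os\n",
    "\n",
    "JSON_ROOT_DIR = \"/path/to/json-results\"  # placeholder\n",
    "TSV_ROOT_DIR = \"/path/to/tsv-results\"    # placeholder\n",
    "\n",
    "def find_files_with_keyword(root_dir, keyword, suffix):\n",
    "    # Recursively collect files whose names contain `keyword` and end with `suffix`.\n",
    "    matches = []\n",
    "    for dirpath, _, filenames in os.walk(root_dir):\n",
    "        for fname in filenames:\n",
    "            if keyword in fname and fname.endswith(suffix):\n",
    "                matches.append(os.path.join(dirpath, fname))\n",
    "    return matches\n",
    "\n",
    "def load_json_file(path):\n",
    "    # Return the parsed JSON, or None if the file cannot be read or parsed.\n",
    "    try:\n",
    "        with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "            return json.load(f)\n",
    "    except (OSError, json.JSONDecodeError) as e:\n",
    "        logging.warning(f\"Could not load {path}: {e}\")\n",
    "        return None\n",
    "\n",
    "def pair_json_tsv(json_files, tsv_files, json_keyword, tsv_keyword):\n",
    "    # Pair JSON and TSV files that share an identifier, assuming the naming\n",
    "    # convention '<keyword>-<identifier>.<ext>' for both kinds of file.\n",
    "    def ident(path, keyword, ext):\n",
    "        name = os.path.basename(path)\n",
    "        return name[len(keyword) + 1:-len(ext)]  # strip '<keyword>-' and the extension\n",
    "    tsv_by_id = {ident(t, tsv_keyword, \".tsv\"): t for t in tsv_files}\n",
    "    pairs = []\n",
    "    for j in json_files:\n",
    "        identifier = ident(j, json_keyword, \".json\")\n",
    "        if identifier in tsv_by_id:\n",
    "            pairs.append((identifier, j, tsv_by_id[identifier]))\n",
    "    return pairs"
   ]
  },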
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "334f8ca8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Define multiple keyword pairs with names ---\n",
    "keyword_pairs = [\n",
    "    (\"clip-multi-clustering-retrieval-IGP\", \"clip-multi-clustering-IGP\", \"IGP\"),\n",
    "    (\"clip-multi-clustering-retrieval-dessert\", \"clip-multi-clustering-dessert\", \"Dessert\"),\n",
    "    (\"clip-multi-clustering-retrieval-plaid\", \"clip-multi-clustering-plaid\", \"Plaid\"),\n",
    "]\n",
    "\n",
    "all_results = []\n",
    "\n",
    "for KEYWORD_JSON, KEYWORD_TSV, name in keyword_pairs:\n",
    "    # match = re.search(r'top\\d+-.*topk_(\\d+)', identifier)\n",
    "    # if match:\n",
    "    #     topk_value = int(match.group(1))\n",
    "    #     if topk_value > 10000:\n",
    "    #         logging.info(f\"Skipping {identifier} because topk={topk_value} > 10000\")\n",
    "    #         continue\n",
    "    logging.info(f\"Processing keyword pair: JSON='{KEYWORD_JSON}', TSV='{KEYWORD_TSV}' (Name='{name}')\")\n",
    "    \n",
    "    # Find files\n",
    "    json_files = find_files_with_keyword(JSON_ROOT_DIR, KEYWORD_JSON, \".json\")\n",
    "    tsv_files = [\n",
    "        os.path.join(TSV_ROOT_DIR, f)\n",
    "        for f in os.listdir(TSV_ROOT_DIR)\n",
    "        if f.startswith(KEYWORD_TSV + \"-\") and f.endswith(\".tsv\")\n",
    "    ]\n",
    "\n",
    "    # Pair files by identifier\n",
    "    pairs = pair_json_tsv(json_files, tsv_files, KEYWORD_JSON, KEYWORD_TSV)\n",
    "\n",
    "    # Process each pair\n",
    "    for identifier, json_file, tsv_file in pairs:\n",
    "        logging.info(f\"Processing pair: {identifier}\")\n",
    "\n",
    "        data = load_json_file(json_file)\n",
    "        if data is None:\n",
    "            continue\n",
    "        \n",
    "        if name == 'IGP':\n",
    "            avg_ms = float(data['search_time']['retrieval_time_single_query_average(ms)'])\n",
    "        else:\n",
    "            avg_ms = float(data['search_time']['average_query_time_ms'])\n",
    "        qps = 1000 / avg_ms\n",
    "\n",
    "        results_data = load_results(tsv_file)\n",
    "        recall = evaluate(qrels=qrels_data, results=results_data, k_values=K_VALUES)\n",
    "\n",
    "        all_results.append({\n",
    "            \"Identifier\":identifier,\n",
    "            'Recall': recall['Recall@10'],\n",
    "            'QPS': qps,\n",
    "            'Algorithm': name,  # <-- Add name here\n",
    "            \"Dataset\": \"DBpedia-entity\"\n",
    "        })\n",
    "\n",
    "# Convert to DataFrame\n",
    "df_all_results = pd.DataFrame(all_results)\n",
    "\n",
    "df_all_results['limit_value'] = df_all_results.apply(extract_limit, axis=1)\n",
    "\n",
    "# Filter rows where limit_value <= 10000\n",
    "df_filtered = df_all_results[df_all_results['limit_value'] <= 1000].copy()\n",
    "\n",
    "# Drop the temporary column\n",
    "df_filtered.drop(columns=['limit_value'], inplace=True)\n",
    "\n",
    "\n",
    "df_sorted = df_filtered.groupby(\"Algorithm\", group_keys=False).apply(\n",
    "    lambda x: x.sort_values(by=\"Recall\", ascending=True)\n",
    ")\n",
    "\n",
    "df_sorted\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55854b1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "my_sorted_recall_list = [0.845,0.862,0.876,0.874,0.877,0.874,0.875,0.856,0.857,0.867,0.88]\n",
    "my_sorted_qps_list = [774.494,745.388,719.843,695.595,719.843,695.595,646.601,566.891,489.971,422.421,334.579]\n",
    "df_temp = pd.DataFrame({\n",
    "    'Recall': my_sorted_recall_list,\n",
    "    'QPS': my_sorted_qps_list,\n",
    "    'Algorithm': 'Multi-HNSW',\n",
    "    'Dataset': 'DBpedia-entity'\n",
    "})\n",
    "my_sorted_recall_list = [0.259,0.4,0.472,0.533,0.569,0.667,0.719,0.742,0.773,0.769,0.772]\n",
    "my_sorted_qps_list = [11179.2,8459.66,7024,7007.71,6994.57,4917.48,4624.73,3934.13,3012.67,3141.93,3082.59]\n",
    "df_temp2 = pd.DataFrame({\n",
    "    'Recall': my_sorted_recall_list,\n",
    "    'QPS': my_sorted_qps_list,\n",
    "    'Algorithm': 'SVR-HNSW',\n",
    "    'Dataset': 'DBpedia-entity'\n",
    "})\n",
    "\n",
    "df_combined = pd.concat([df_filtered, df_temp], ignore_index=True)\n",
    "df_combined = pd.concat([df_combined, df_temp2], ignore_index=True)\n",
    "\n",
    "df_combined = df_combined.drop(columns=['Identifier'])\n",
    "df_combined = pd.concat([df_combined, mvrhnsw_df], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ac0eebd",
   "metadata": {},
   "outputs": [],
   "source": [
    "df =df_combined[df_combined['QPS'] <= 4000] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2de5c452",
   "metadata": {},
   "outputs": [],
   "source": [
    "algorithms = [\"IGP\", \"Dessert\", \"Plaid\", \"Multi-HNSW\", \"MVRHNSW\", 'SVR-HNSW']"
   ]
  },
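  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f2c5b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: `datasets` and `subplot_labels` are used by the plotting cell but are\n",
    "# not defined in this notebook. A minimal sketch: derive the dataset list from\n",
    "# the DataFrame and label panels (a), (b), ... (the labeling scheme is an assumption).\n",
    "datasets = sorted(df['Dataset'].unique())\n",
    "subplot_labels = {ds: f\"({chr(ord('a') + i)}) {ds}\" for i, ds in enumerate(datasets)}"
   ]
  },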
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3547bb5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Set a clean style for the plots\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "# Define a consistent color/style map for algorithms\n",
    "markers = ['o', 's', 'D', 'v', '^', '>', '<', 'p', '*', 'h', 'H', '+', 'x']\n",
    "\n",
    "# Use a high-contrast qualitative colormap (e.g., 'tab10')\n",
    "from matplotlib import colormaps as cm\n",
    "cmap = cm.get_cmap('tab10')  # tab10 is great for distinct colors\n",
    "color_list = list(getattr(cmap, 'colors', [cmap(i / 10.0) for i in range(10)]))\n",
    "\n",
    "algorithm_map = {algo: {'marker': markers[i % len(markers)],\n",
    "                        'color': color_list[i % len(color_list)]}\n",
    "                 for i, algo in enumerate(algorithms)}\n",
    "\n",
    "# Assuming you want to display up to 6 plots (2 rows, 3 cols)\n",
    "fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(15, 10))\n",
    "axes = axes.flatten()\n",
    "\n",
    "# Ensure QPS column exists (assuming 'df' is loaded outside this block)\n",
    "if 'QPS' not in df.columns and 'Time (msec)' in df.columns:\n",
    "    df['QPS'] = 1000.0 / df['Time (msec)']\n",
    "\n",
    "for i, ds in enumerate(datasets):\n",
    "    if i >= len(axes): break # Safety break\n",
    "    ax = axes[i]\n",
    "    df_subset = df[df['Dataset'] == ds]\n",
    "    \n",
    "    for algo in algorithms:\n",
    "        data = df_subset[df_subset['Algorithm'] == algo].copy()\n",
    "        \n",
    "        # Sort by Recall to make lines smooth\n",
    "        data = data.sort_values(by='Recall')\n",
    "        style = algorithm_map[algo]\n",
    "        ax.plot(\n",
    "            data['Recall'],\n",
    "            data['QPS'],\n",
    "            label=algo,\n",
    "            color=style['color'],\n",
    "            marker=style['marker'],\n",
    "            linestyle='-',\n",
    "            markersize=5,   # INCREASED MARKER SIZE\n",
    "            linewidth=3     # INCREASED LINE WIDTH\n",
    "        )\n",
    "    \n",
    "    # Logarithmic Y-axis\n",
    "    # ax.set_yscale('log')\n",
    "    \n",
    "    # Manually set X-axis limits to 0 to 1.05 for better clarity in the relevant range\n",
    "    ax.set_xlim(0, 1.05) \n",
    "    \n",
    "    # Bold axis labels\n",
    "    if i % 3 == 0:\n",
    "        ax.set_ylabel('QPS', fontsize=14, fontweight='bold')\n",
    "    if i >= 3:\n",
    "        ax.set_xlabel('Recall', fontsize=14, fontweight='bold')\n",
    "    \n",
    "    # Bold tick labels\n",
    "    ax.tick_params(axis='both', labelsize=12, width=1.5)\n",
    "    \n",
    "    # Subplot label\n",
    "    ax.text(\n",
    "        0.5, -0.25,\n",
    "        subplot_labels.get(ds, ds),\n",
    "        transform=ax.transAxes,\n",
    "        ha='center',\n",
    "        fontsize=14,\n",
    "        fontweight='bold'\n",
    "    )\n",
    "\n",
    "# Legend at top (adjusted location)\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "fig.legend(\n",
    "    handles,\n",
    "    labels,\n",
    "    loc='upper center',\n",
    "    bbox_to_anchor=(0.5, 1.02), # Slightly lower for tighter layout\n",
    "    ncol=len(algorithms),\n",
    "    fontsize=13, # Slightly larger font\n",
    "    frameon=False\n",
    ")\n",
    "\n",
    "# Main title\n",
    "# Moved title to a safer place to avoid collision with the legend\n",
    "fig.suptitle('Query Performance Comparison', fontsize=18, y=1.06, fontweight='bold') \n",
    "\n",
    "# Hide unused subplots\n",
    "for j in range(len(datasets), len(axes)):\n",
    "    axes[j].axis('off')\n",
    "\n",
    "# Use standard tight_layout\n",
    "plt.tight_layout(rect=[0, 0, 1, 1]) # Standard tight_layout, title is outside\n",
    "fig.savefig('query_performance_recreation_better.png', dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
