#!/usr/bin/env python3
"""
eval_generated_bug_embeddings.py

Generate bugs from a model and score their code embedding similarity to
target (human) bugs vs negative (model-generated) bugs.

This helps evaluate whether a bug generator produces bugs that are stylistically
similar to human-written bugs.

Usage:
    python -m examples.bugs.eval_generated_bug_embeddings \
        --generator_model Qwen/Qwen2.5-Coder-7B-Instruct \
        --generator_base_url http://localhost:30000/v1 \
        --source_dataset bigcodebench \
        --target_dataset bugbench_human \
        --negative_datasets bugbench_qwen7b_sampled \
        --n_tasks 100
"""

from __future__ import annotations

import argparse
import asyncio
import os
import random
import time
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np

from rllm.data.dataset import DatasetRegistry
from rllm.engine import OpenAIEngine

from examples.bugs.code_embedding import (
    CodeEmbeddingConfig,
    CodeEmbedder,
    KNNBugSimilarity,
    ReferencePool,
)
from examples.bugs_refactor.components import BugGenerator, BugGeneratorConfig


# ---------------------------
# Task schema helpers
# ---------------------------

def _get_problem(task: Dict[str, Any]) -> str:
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in (
            "question",
            "instruct_prompt",
            "complete_prompt",
            "prompt",
            "text",
            "problem",
            "description",
            "code_prompt",
        ):
            v = extra_info.get(key)
            if isinstance(v, str) and v.strip():
                return v
    for key in (
        "question",
        "instruct_prompt",
        "complete_prompt",
        "prompt",
        "text",
        "problem",
        "description",
        "code_prompt",
    ):
        v = task.get(key)
        if isinstance(v, str) and v.strip():
            return v
    return ""


def _get_reference_solution(task: Dict[str, Any]) -> str:
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("reference_solution", "canonical_solution", "solution", "code", "correct_code"):
            v = extra_info.get(key)
            if isinstance(v, str) and v.strip():
                return v
    for key in ("reference_solution", "canonical_solution", "solution", "code", "correct_code"):
        v = task.get(key)
        if isinstance(v, str) and v.strip():
            return v
    return ""


def _get_buggy_solution(task: Dict[str, Any]) -> Optional[str]:
    extra_info = task.get("extra_info", {})
    if isinstance(extra_info, dict):
        for key in ("buggy_solution", "buggy_sampled_solution", "buggy", "buggy_code", "bug"):
            v = extra_info.get(key)
            if isinstance(v, str) and v.strip():
                return v
    for key in ("buggy_solution", "buggy_sampled_solution", "buggy", "buggy_code", "bug"):
        v = task.get(key)
        if isinstance(v, str) and v.strip():
            return v
    return None


def _get_task_id(task: Dict[str, Any]) -> Optional[str]:
    extra_info = task.get("extra_info", {})
    candidates = []
    if isinstance(extra_info, dict):
        candidates.extend([
            extra_info.get("task_id"),
            extra_info.get("uid"),
            extra_info.get("id"),
            extra_info.get("problem_id"),
            extra_info.get("instance_id"),
        ])
    candidates.extend([
        task.get("task_id"),
        task.get("uid"),
        task.get("id"),
        task.get("problem_id"),
        task.get("instance_id"),
    ])
    for v in candidates:
        if isinstance(v, (str, int)) and str(v).strip():
            return str(v).strip()
    return None


# ---------------------------
# AUROC (no sklearn dependency)
# ---------------------------

def auc_roc(y_true: List[int], y_score: List[float]) -> float:
    """Compute AUROC from the Mann–Whitney U rank statistic.

    Labels equal to ``1`` are positives; everything else is a negative. Tied
    scores share the average of the ranks they span. Returns ``nan`` when the
    inputs are empty, differ in length, or contain only one class.
    """
    total = len(y_true)
    if total == 0 or total != len(y_score):
        return float("nan")

    pos_count = sum(1 for label in y_true if label == 1)
    neg_count = total - pos_count
    if pos_count == 0 or neg_count == 0:
        return float("nan")

    # Assign 1-based ranks in ascending score order.
    scores = np.array(y_score)
    order = np.argsort(scores)
    sorted_scores = scores[order]
    ranks = np.empty(total, dtype=float)
    ranks[order] = np.arange(1, total + 1, dtype=float)

    # Walk runs of equal scores and replace their ranks with the run average.
    start = 0
    while start < total:
        stop = start + 1
        while stop < total and sorted_scores[stop] == sorted_scores[start]:
            stop += 1
        if stop - start > 1:
            tied = order[start:stop]
            ranks[tied] = float(np.mean(ranks[tied]))
        start = stop

    pos_rank_sum = float(sum(rank for rank, label in zip(ranks, y_true) if label == 1))
    u_statistic = pos_rank_sum - (pos_count * (pos_count + 1)) / 2.0
    return float(u_statistic / (pos_count * neg_count))


# ---------------------------
# Data loading
# ---------------------------

def load_tasks_with_solutions(dataset_name: str, split: str) -> List[Dict[str, Any]]:
    """Load a registered dataset split and keep only tasks with a reference solution.

    Bug generation needs a correct solution to corrupt, so tasks for which
    ``_get_reference_solution`` returns an empty string are dropped.

    Raises:
        RuntimeError: if the registry returns ``None`` for the dataset/split.
    """
    dataset = DatasetRegistry.load_dataset(dataset_name, split)
    if dataset is None:
        raise RuntimeError(f"Could not load dataset={dataset_name!r} split={split!r}")
    return [task for task in dataset.get_data() if _get_reference_solution(task)]


def load_buggy_tasks(dataset_name: str, split: str) -> List[Dict[str, Any]]:
    """Load a registered dataset split and keep only tasks with a buggy solution.

    These tasks are the raw material for reference pools; tasks for which
    ``_get_buggy_solution`` returns ``None`` are dropped.

    Raises:
        RuntimeError: if the registry returns ``None`` for the dataset/split.
    """
    dataset = DatasetRegistry.load_dataset(dataset_name, split)
    if dataset is None:
        raise RuntimeError(f"Could not load dataset={dataset_name!r} split={split!r}")
    return [task for task in dataset.get_data() if _get_buggy_solution(task)]


def sample_tasks(tasks: List[Dict[str, Any]], n: int, rng: random.Random) -> List[Dict[str, Any]]:
    """Return up to *n* tasks drawn without replacement using *rng*.

    When *n* covers the whole population, a shallow copy of *tasks* is
    returned in its original order and *rng* is not consulted.
    """
    if not tasks:
        return []
    return list(tasks) if n >= len(tasks) else rng.sample(tasks, n)


def collect_task_ids(tasks: List[Dict[str, Any]]) -> Set[str]:
    """Return the set of task identifiers found in *tasks*.

    Tasks for which ``_get_task_id`` yields ``None`` contribute nothing.
    """
    resolved_ids = (_get_task_id(task) for task in tasks)
    return {task_id for task_id in resolved_ids if task_id is not None}


def filter_to_task_ids(tasks: List[Dict[str, Any]], allowed: Set[str]) -> List[Dict[str, Any]]:
    """Return the tasks whose identifier is a member of *allowed*.

    Input order is preserved. Tasks with no resolvable identifier are dropped.
    """
    kept: List[Dict[str, Any]] = []
    for task in tasks:
        task_id = _get_task_id(task)
        if task_id is None or task_id not in allowed:
            continue
        kept.append(task)
    return kept


# ---------------------------
# Bug generation
# ---------------------------

async def generate_bugs_batch(
    generator: BugGenerator,
    tasks: List[Dict[str, Any]],
    n_parallel: int = 16,
) -> List[Tuple[Dict[str, Any], str, bool]]:
    """
    Generate one bug per task, running at most ``n_parallel`` requests at once.

    Args:
        generator: Object exposing ``await generate_bug(task, uid)``; the
            returned trajectory's first step ``action`` is taken as the buggy
            code.
        tasks: Tasks to generate bugs for.
        n_parallel: Maximum number of in-flight requests. Values below 1 are
            treated as 1.

    Returns: List of (task, buggy_code, success) tuples, in the same order as
    ``tasks``. A failed generation is reported as ``(task, "", False)`` rather
    than raised; a trajectory with no steps (or a ``None`` action) yields
    ``(task, "", True)``.
    """
    # Semaphore(0) would block every request forever and a negative value
    # raises ValueError, so always allow at least one in-flight request.
    semaphore = asyncio.Semaphore(max(1, n_parallel))

    async def generate_one(task: Dict[str, Any], idx: int) -> Tuple[Dict[str, Any], str, bool]:
        async with semaphore:
            try:
                traj = await generator.generate_bug(task, f"gen_{idx}")
                # Coerce a missing/None action to "" so callers can safely
                # treat the second tuple element as a string.
                buggy_code = (traj.steps[0].action or "") if traj.steps else ""
                return (task, buggy_code, True)
            except Exception as e:
                # Best effort: one failed request must not abort the batch.
                print(f"  [WARN] Failed to generate bug for task {idx}: {e}")
                return (task, "", False)

    coros = [generate_one(task, i) for i, task in enumerate(tasks)]
    # gather preserves input order, so results line up with ``tasks``.
    results = await asyncio.gather(*coros)
    return list(results)


# ---------------------------
# Main
# ---------------------------

def main():
    """Command-line entry point for the generated-bug embedding evaluation.

    Pipeline:
      1. Parse CLI arguments and resolve the generator API key.
      2. Load source tasks plus target/negative buggy tasks (unless a
         pre-computed pool path is supplied for the latter two).
      3. Optionally restrict every loaded dataset to the task IDs they share.
      4. Build the kNN similarity scorer with its target (and optional
         negative) reference pool.
      5. Sample source tasks, generate bugs with the model, and score them.
      6. Optionally score pregenerated bugs from the source dataset and
         compare distributions (mean difference and AUROC).
      7. Print summary statistics and optionally save the generated bugs as a
         reference pool.

    Raises:
        RuntimeError: if a dataset cannot be loaded or no target pool can be
            obtained.
    """
    ap = argparse.ArgumentParser("Generate bugs from a model and score embedding similarity")

    # Generator model config
    ap.add_argument("--generator_model", type=str, required=True,
                    help="Model to use for bug generation (e.g., Qwen/Qwen2.5-Coder-7B-Instruct)")
    ap.add_argument("--generator_base_url", type=str, default="http://localhost:30000/v1",
                    help="Base URL for generator model API")
    ap.add_argument("--generator_api_key", type=str, default=None,
                    help="API key for generator (default: OPENAI_API_KEY or 'EMPTY')")
    ap.add_argument("--generator_temperature", type=float, default=0.6)
    ap.add_argument("--generator_top_p", type=float, default=0.95)
    ap.add_argument("--generator_system_prompt", type=str, default=None)

    # Source dataset (where to sample tasks for bug generation)
    ap.add_argument("--source_dataset", type=str, default="bigcodebench",
                    help="Dataset to sample tasks from for bug generation")
    ap.add_argument("--source_split", type=str, default="train",
                    help="Split of source dataset")
    ap.add_argument("--n_tasks", type=int, default=100,
                    help="Number of tasks to generate bugs for")

    # Target dataset (human bugs to compare against)
    ap.add_argument("--target_dataset", type=str, default="bugbench_human",
                    help="Dataset with target (human) bugs for similarity comparison")
    ap.add_argument("--target_split", type=str, default="train",
                    help="Split for target pool")
    ap.add_argument("--target_pool_path", type=str, default=None,
                    help="Path to pre-computed target pool (optional)")

    # Negative datasets (model-generated bugs)
    ap.add_argument("--negative_datasets", type=str, nargs="*", default=None,
                    help="Datasets with negative (model) bugs (default: none)")
    ap.add_argument("--negative_split", type=str, default="train",
                    help="Split for negative pool")
    ap.add_argument("--negative_pool_path", type=str, default=None,
                    help="Path to pre-computed negative pool (optional)")

    # Embedding config
    ap.add_argument("--embed_model", type=str, default="voyage-code-3",
                    help="Embedding model to use")
    ap.add_argument("--embed_mode", type=str, default="buggy", choices=["diff", "buggy"],
                    help="Embed mode: 'diff' embeds unified diff, 'buggy' embeds raw buggy code")
    ap.add_argument("--include_problem", action="store_true", default=False,
                    help="Include problem description in embedding")
    ap.add_argument("--top_k", type=int, default=20,
                    help="Number of nearest neighbors for similarity scoring")
    ap.add_argument("--use_margin", action="store_true", default=True,
                    help="Use margin-based scoring (target - negative)")
    ap.add_argument("--no_use_margin", action="store_false", dest="use_margin")
    ap.add_argument("--margin_temperature", type=float, default=10.0,
                    help="Temperature for margin sigmoid")

    # Execution
    ap.add_argument("--n_parallel", type=int, default=16,
                    help="Number of parallel bug generation requests")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--device", type=str, default="cuda",
                    help="Device for local embedding models")
    ap.add_argument("--voyage_api_key", type=str, default=None,
                    help="Voyage API key (optional)")

    # Task ID matching
    ap.add_argument("--match_task_ids", action="store_true", default=True,
                    help="Match task IDs across source, target, and negative datasets for fair comparison")
    ap.add_argument("--no_match_task_ids", action="store_false", dest="match_task_ids")

    # Output
    ap.add_argument("--save_pool_path", type=str, default=None,
                    help="Path to save generated bugs as a pool (for future use)")
    ap.add_argument("--compare_to_pregenerated", action="store_true", default=False,
                    help="Also load pregenerated bugs from source dataset and compare scores")

    args = ap.parse_args()
    # A single seeded RNG drives all sampling so runs are reproducible.
    rng = random.Random(args.seed)

    # Resolve API key: explicit flag > OPENAI_API_KEY env var > "EMPTY"
    # placeholder (local OpenAI-compatible servers typically ignore the key).
    api_key = args.generator_api_key
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY", "")
    if not api_key.strip():
        api_key = "EMPTY"

    print("=" * 80)
    print("Generated Bug Embedding Similarity Evaluation")
    print("=" * 80)
    print(f"Generator: {args.generator_model}")
    print(f"  base_url: {args.generator_base_url}")
    print(f"  temperature: {args.generator_temperature}, top_p: {args.generator_top_p}")
    print(f"Source: {args.source_dataset}:{args.source_split} (n={args.n_tasks})")
    print(f"Target pool: {args.target_dataset}:{args.target_split}")
    if args.negative_datasets:
        print(f"Negative pool: {args.negative_datasets}:{args.negative_split}")
    elif args.negative_pool_path:
        print(f"Negative pool: {args.negative_pool_path}")
    print(f"Embedding: {args.embed_model}, mode={args.embed_mode}, top_k={args.top_k}")
    print(f"Use margin: {args.use_margin}, temperature={args.margin_temperature}")
    print(f"Match task IDs: {args.match_task_ids}")
    print()

    # ---------------------------
    # Load all datasets first (for task ID matching)
    # ---------------------------
    print("Loading datasets...")

    # Load source tasks
    print(f"  Source: {args.source_dataset}:{args.source_split}...")
    source_tasks = load_tasks_with_solutions(args.source_dataset, args.source_split)
    print(f"    Found {len(source_tasks)} tasks with reference solutions")
    source_ids = collect_task_ids(source_tasks)
    print(f"    Task IDs: {len(source_ids)}")

    # Load target tasks (for pool building); skipped when a pre-computed pool
    # is supplied, in which case the target takes no part in ID matching.
    target_tasks: List[Dict[str, Any]] = []
    target_ids: Set[str] = set()
    if not args.target_pool_path:
        print(f"  Target: {args.target_dataset}:{args.target_split}...")
        target_tasks = load_buggy_tasks(args.target_dataset, args.target_split)
        print(f"    Found {len(target_tasks)} buggy tasks")
        target_ids = collect_task_ids(target_tasks)
        print(f"    Task IDs: {len(target_ids)}")

    # Load negative tasks (for pool building). A dataset that fails to load is
    # reported and skipped so the remaining negatives can still be used.
    neg_tasks_all: List[Dict[str, Any]] = []
    neg_ids: Set[str] = set()
    if args.use_margin and args.negative_datasets and not args.negative_pool_path:
        for dn in args.negative_datasets:
            try:
                nt = load_buggy_tasks(dn, args.negative_split)
                neg_tasks_all.extend(nt)
                print(f"    Loaded {len(nt)} from {dn}:{args.negative_split}")
            except Exception as e:
                print(f"    [WARN] Failed to load {dn}:{args.negative_split}: {e}")
        neg_ids = collect_task_ids(neg_tasks_all)
        print(f"    Negative pool task IDs: {len(neg_ids)}")

    print()

    # ---------------------------
    # Task ID matching (optional)
    # ---------------------------
    matched_task_ids: Optional[Set[str]] = None
    if args.match_task_ids:
        print("Computing task ID intersection...")

        # Start with source task IDs
        id_sets = [source_ids]
        id_set_names = [f"source ({len(source_ids)})"]

        # Add target task IDs if available
        if target_tasks:
            id_sets.append(target_ids)
            id_set_names.append(f"target ({len(target_ids)})")

        # Add negative task IDs if available
        if neg_tasks_all:
            id_sets.append(neg_ids)
            id_set_names.append(f"negative ({len(neg_ids)})")

        # Compute intersection
        if len(id_sets) > 1:
            matched_task_ids = set.intersection(*id_sets)
            print(f"  Intersection of {', '.join(id_set_names)}: {len(matched_task_ids)} task IDs")

            if len(matched_task_ids) == 0:
                # Datasets with disjoint ID schemes cannot be matched; fall
                # back to using each dataset in full rather than emptying them.
                print("  [WARN] Task ID intersection is empty! Disabling matching.")
                matched_task_ids = None
            else:
                # Filter all datasets to matched task IDs
                source_before = len(source_tasks)
                source_tasks = filter_to_task_ids(source_tasks, matched_task_ids)
                print(f"  Filtered source: {len(source_tasks)}/{source_before} tasks")

                if target_tasks:
                    target_before = len(target_tasks)
                    target_tasks = filter_to_task_ids(target_tasks, matched_task_ids)
                    print(f"  Filtered target: {len(target_tasks)}/{target_before} tasks")

                if neg_tasks_all:
                    neg_before = len(neg_tasks_all)
                    neg_tasks_all = filter_to_task_ids(neg_tasks_all, matched_task_ids)
                    print(f"  Filtered negative: {len(neg_tasks_all)}/{neg_before} tasks")
        else:
            print("  [INFO] Only source dataset available, skipping task ID matching")

        print()

    # ---------------------------
    # Build embedding scorer
    # ---------------------------
    # Margin scoring needs a negative pool that will actually exist: either a
    # pre-computed pool file or at least one successfully loaded negative task.
    # (A pool path alone is sufficient; --negative_datasets is not required.)
    use_negative_pool = bool(
        args.use_margin and (args.negative_pool_path or neg_tasks_all)
    )

    print("Building embedding scorer...")
    cfg = CodeEmbeddingConfig(
        model_name=args.embed_model,
        include_problem=bool(args.include_problem),
        embed_mode=str(args.embed_mode),
        top_k=int(args.top_k),
        device=str(args.device),
        voyage_api_key=args.voyage_api_key,
        use_relative_score=use_negative_pool,
        margin_temperature=float(args.margin_temperature),
    )
    embedder = CodeEmbedder(cfg)
    knn = KNNBugSimilarity(embedder, top_k=args.top_k)

    # Build target pool
    if args.target_pool_path:
        print(f"Loading target pool from {args.target_pool_path}...")
        target_pool = ReferencePool.load(args.target_pool_path)
        knn.target_pool = target_pool
    elif target_tasks:
        print(f"Building target pool from {len(target_tasks)} tasks...")
        knn.build_target_pool(target_tasks)
    else:
        raise RuntimeError("No target pool path or target tasks available")

    print(f"  Target pool size: {len(knn.target_pool) if knn.target_pool else 0}")

    # Build negative pool (optional)
    if use_negative_pool:
        if args.negative_pool_path:
            print(f"Loading negative pool from {args.negative_pool_path}...")
            neg_pool = ReferencePool.load(args.negative_pool_path)
            knn.negative_pool = neg_pool
        else:
            print(f"Building negative pool from {len(neg_tasks_all)} tasks...")
            knn.build_negative_pool(neg_tasks_all)
        print(f"  Negative pool size: {len(knn.negative_pool) if knn.negative_pool else 0}")
    elif args.use_margin and args.negative_datasets:
        # Negatives were requested but none could be loaded.
        print("  [WARN] No negative tasks available; scoring without a negative pool.")

    print()

    # ---------------------------
    # Sample source tasks for bug generation
    # ---------------------------
    print(f"Sampling {args.n_tasks} tasks for bug generation...")
    sampled_tasks = sample_tasks(source_tasks, args.n_tasks, rng)
    print(f"  Sampled {len(sampled_tasks)} tasks")
    print()

    # ---------------------------
    # Initialize generator
    # ---------------------------
    print("Initializing bug generator...")
    generator_engine = OpenAIEngine(
        model=args.generator_model,
        tokenizer=None,
        base_url=args.generator_base_url,
        api_key=api_key,
        max_prompt_length=8192,
        max_response_length=4096,
        sampling_params={
            "temperature": args.generator_temperature,
            "top_p": args.generator_top_p,
        },
        verbose=False,
    )
    generator = BugGenerator(
        generator_engine,
        BugGeneratorConfig(system_prompt=args.generator_system_prompt),
    )

    # ---------------------------
    # Generate bugs
    # ---------------------------
    print(f"\n🐛 Generating bugs for {len(sampled_tasks)} tasks (parallel={args.n_parallel})...")
    start_time = time.time()

    results = asyncio.run(generate_bugs_batch(generator, sampled_tasks, args.n_parallel))

    gen_time = time.time() - start_time
    # A generation only counts when it succeeded AND produced non-blank code.
    successful = [(t, b) for t, b, s in results if s and b.strip()]
    failed = len(results) - len(successful)

    print(f"  Generated {len(successful)} bugs in {gen_time:.1f}s ({failed} failed)")
    print()

    if not successful:
        print("No bugs generated successfully. Exiting.")
        return

    # ---------------------------
    # Score generated bugs
    # ---------------------------
    print("🔢 Scoring generated bugs...")
    generated_scores: List[float] = []
    generated_margins: List[float] = []
    generated_target_sims: List[float] = []
    generated_negative_sims: List[float] = []

    for task, buggy_code in successful:
        problem = _get_problem(task)
        correct_code = _get_reference_solution(task)

        score, meta = knn.score_similarity(problem, buggy_code, correct_code=correct_code)
        generated_scores.append(float(score))

        # These metadata keys are optional; collect whichever are reported.
        if "margin" in meta:
            generated_margins.append(float(meta["margin"]))
        if "target_sim" in meta:
            generated_target_sims.append(float(meta["target_sim"]))
        if "negative_sim" in meta:
            generated_negative_sims.append(float(meta["negative_sim"]))

    # ---------------------------
    # Optionally compare to pregenerated bugs
    # ---------------------------
    pregenerated_scores: List[float] = []
    pregenerated_margins: List[float] = []

    if args.compare_to_pregenerated:
        print("\n📊 Comparing to pregenerated bugs from source dataset...")
        pregen_tasks = [t for t in source_tasks if _get_buggy_solution(t)]
        # Match the sample size to the number of generated bugs where possible.
        pregen_sampled = sample_tasks(pregen_tasks, min(len(successful), len(pregen_tasks)), rng)

        print(f"  Found {len(pregen_sampled)} pregenerated bugs to compare")

        for task in pregen_sampled:
            problem = _get_problem(task)
            buggy_code = _get_buggy_solution(task) or ""
            correct_code = _get_reference_solution(task)

            score, meta = knn.score_similarity(problem, buggy_code, correct_code=correct_code)
            pregenerated_scores.append(float(score))
            if "margin" in meta:
                pregenerated_margins.append(float(meta["margin"]))

    # ---------------------------
    # Print statistics
    # ---------------------------
    print("\n" + "=" * 80)
    print("📊 GENERATED BUG EMBEDDING SCORES")
    print("=" * 80)

    arr = np.array(generated_scores)
    print(f"\nGenerated bugs (n={len(arr)}):")
    print(f"  Mean:   {arr.mean():.4f}")
    print(f"  Std:    {arr.std():.4f}")
    print(f"  Median: {np.median(arr):.4f}")
    print(f"  Min:    {arr.min():.4f}")
    print(f"  Max:    {arr.max():.4f}")

    if generated_margins:
        marr = np.array(generated_margins)
        print(f"\n  Margins (target_sim - negative_sim):")
        print(f"    Mean:   {marr.mean():+.4f}")
        print(f"    Std:    {marr.std():.4f}")
        print(f"    Median: {np.median(marr):+.4f}")

    if generated_target_sims:
        tarr = np.array(generated_target_sims)
        print(f"\n  Target similarities:")
        print(f"    Mean: {tarr.mean():.4f}")

    if generated_negative_sims:
        narr = np.array(generated_negative_sims)
        print(f"\n  Negative similarities:")
        print(f"    Mean: {narr.mean():.4f}")

    # Distribution buckets: five equal-width bins over [0, 1]. Scores are
    # clamped into the first/last bin; without the lower clamp a negative
    # score would index the list from the end (wrong bin) or raise IndexError.
    print(f"\n  Score distribution:")
    n_buckets = 5
    buckets = [0] * n_buckets
    for score in generated_scores:
        bucket_idx = max(0, min(int(score * n_buckets), n_buckets - 1))
        buckets[bucket_idx] += 1
    n_scored = len(generated_scores)
    for idx, count in enumerate(buckets):
        lo = idx / n_buckets
        hi = (idx + 1) / n_buckets
        print(f"    [{lo:.1f}-{hi:.1f}]: {count:3d}  ({100*count/n_scored:5.1f}%)")

    if pregenerated_scores:
        print("\n" + "-" * 40)
        prearr = np.array(pregenerated_scores)
        print(f"\nPregenerated bugs from source (n={len(prearr)}):")
        print(f"  Mean:   {prearr.mean():.4f}")
        print(f"  Std:    {prearr.std():.4f}")
        print(f"  Median: {np.median(prearr):.4f}")

        if pregenerated_margins:
            pmarr = np.array(pregenerated_margins)
            print(f"  Mean margin: {pmarr.mean():+.4f}")

        # Compare
        print("\n📈 Comparison:")
        diff = arr.mean() - prearr.mean()
        print(f"  Generated mean - Pregenerated mean = {diff:+.4f}")

        # AUROC: can we distinguish generated from pregenerated?
        # Only reported when both groups have more than 10 samples.
        if len(arr) > 10 and len(prearr) > 10:
            labels = [1] * len(arr) + [0] * len(prearr)
            all_scores = list(generated_scores) + list(pregenerated_scores)
            auc = auc_roc(labels, all_scores)
            print(f"\n  AUROC (generated=1 vs pregenerated=0): {auc:.4f}")
            print(f"    (>0.5 means generated bugs score HIGHER than pregenerated)")

    # ---------------------------
    # Save pool if requested
    # ---------------------------
    if args.save_pool_path and successful:
        print(f"\n💾 Saving generated bugs as pool to {args.save_pool_path}...")
        # Build a pool from generated bugs: copy each task and attach the
        # generated code under the key the pool builder looks for.
        gen_tasks_for_pool = []
        for task, buggy_code in successful:
            t = dict(task)
            t["buggy_solution"] = buggy_code
            gen_tasks_for_pool.append(t)

        # Build and save, using a separate scorer so the evaluation pools
        # above are left untouched.
        save_knn = KNNBugSimilarity(embedder, top_k=args.top_k)
        save_knn.build_target_pool(gen_tasks_for_pool)
        if save_knn.target_pool:
            # dirname is "" for a bare filename; fall back to the current dir.
            os.makedirs(os.path.dirname(args.save_pool_path) or '.', exist_ok=True)
            save_knn.target_pool.save(args.save_pool_path)
            print(f"  Saved {len(save_knn.target_pool)} embeddings")

    print("\n✅ Done!")


# Script entry point: run the evaluation only when this file is executed
# directly (or via ``python -m``), not when it is imported as a module.
if __name__ == "__main__":
    main()
