# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import pickle
import resource
import tempfile
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Generator, List, Optional, Tuple

import faiss  # @manual=//faiss/python:pyfaiss
print("Using", faiss.get_compile_options())
import numpy as np
import pandas as pd
from faiss.contrib.datasets import (  # @manual=//faiss/contrib:faiss_contrib
    Dataset,
    SyntheticDataset,
)

# Try to import Annoy with fallback
try:
    from annoy import AnnoyIndex
    ANNOY_AVAILABLE = True
except ImportError:
    print("Warning: Annoy not available. Install with ./build.sh from the annoy/ directory")
    ANNOY_AVAILABLE = False
    AnnoyIndex = None

try:
    import sys
    import os
    # Add the mrpt directory to Python path for import
    mrpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "mrpt")
    if os.path.exists(mrpt_path):
        sys.path.insert(0, mrpt_path)
    from mrpt import MRPTIndex
    MRPT_AVAILABLE = True
except ImportError:
    print("Warning: MRPT not available. Install with ./build.sh from the mrpt/ directory")
    MRPT_AVAILABLE = False
    MRPTIndex = None

# Number of microseconds in one second; used to convert timings measured in
# seconds into the integer "*_us" counters.
US_IN_S = 1_000_000


def _get_results_cache_key(csv_paths: List[str], nb: int, nq: int, k: int) -> str:
    """Build the file name under which ground-truth results are cached.

    ``None`` entries in ``csv_paths`` are ignored. The remaining paths are
    sorted (by full path) so that the key does not depend on argument order,
    and only each file's base name without its last extension is used.

    Args:
        csv_paths: Data file paths that identify the dataset; may contain None.
        nb: Number of database vectors.
        nq: Number of query vectors.
        k: Number of neighbours per query.

    Returns:
        A name of the form ``results_<stems>_nb<nb>_nq<nq>_k<k>.pkl``.
    """
    stems = []
    for path in sorted(filter(lambda p: p is not None, csv_paths)):
        # Keep only the file stem: drop the directory and the last extension.
        stem, _extension = os.path.splitext(os.path.basename(path))
        stems.append(stem)

    joined_stems = "_".join(stems)
    return "results_{}_nb{}_nq{}_k{}.pkl".format(joined_stems, nb, nq, k)


def _load_cached_results(cache_key: str, cache_dir: str = "/tmp") -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Load cached ground truth results if they exist.

    Args:
        cache_key: File name of the cache entry inside ``cache_dir``.
        cache_dir: Directory holding the cache files.

    Returns:
        The cached ``(D_real, I_real)`` pair, or ``None`` when there is no
        cache file or it could not be read. An unreadable cache file is
        deleted (best effort) so that it gets recomputed.
    """
    cache_file = os.path.join(cache_dir, cache_key)

    if not os.path.exists(cache_file):
        return None

    try:
        print(f"Loading cached results from {cache_file}")
        # SECURITY NOTE: unpickling can execute arbitrary code. Only point
        # cache_dir at a location that untrusted users cannot write to.
        with open(cache_file, 'rb') as f:
            D_real, I_real = pickle.load(f)
        print(f"Successfully loaded cached results with shape D: {D_real.shape}, I: {I_real.shape}")
        return D_real, I_real
    except Exception as e:
        # Any failure here (truncated file, wrong payload, ...) just means the
        # cache is unusable; the caller recomputes the results.
        print(f"Error loading cached results: {e}, will recompute...")
        try:
            os.remove(cache_file)
        except OSError as remove_error:
            # Best effort only, but do not hide interrupts or other
            # non-filesystem errors behind a bare except.
            print(f"Warning: could not remove corrupted cache file {cache_file}: {remove_error}")

    return None


def _save_cached_results(D_real: np.ndarray, I_real: np.ndarray, cache_key: str, cache_dir: str = "/tmp"):
    """Pickle the ground-truth ``(D_real, I_real)`` pair into the cache.

    Saving is best effort: any failure is reported with a warning on stdout
    and otherwise ignored.

    Args:
        D_real: Ground-truth distances.
        I_real: Ground-truth neighbour indices.
        cache_key: File name of the cache entry inside ``cache_dir``.
        cache_dir: Directory that receives the cache file.
    """
    target_path = os.path.join(cache_dir, cache_key)
    payload = (D_real, I_real)

    try:
        print(f"Saving results to cache: {target_path}")
        with open(target_path, 'wb') as handle:
            pickle.dump(payload, handle)
        print(f"Successfully cached results with shape D: {D_real.shape}, I: {I_real.shape}")
    except Exception as error:
        print(f"Warning: Failed to save results to cache: {error}")


@dataclass
class PerfCounters:
    """Timing measurements, in seconds, for one timed region.

    All fields default to 0.0; ``timed_execution`` fills them in when the
    timed region ends.
    """
    # Elapsed wall-clock time (difference of two time.perf_counter() readings).
    wall_time_s: float = 0.0
    # User-mode CPU time of this process (difference of getrusage ru_utime).
    user_time_s: float = 0.0
    # System-mode CPU time of this process (difference of getrusage ru_stime).
    system_time_s: float = 0.0


def _get_annoy_cache_key(dataset_name: str, nb: int, d: int, index_type: str, epsilon: float, n_trees: int) -> str:
    """Return the cache file name for an Annoy index.

    The name is derived from the dataset name, the database size and the
    build parameters (not from the data itself), and ends in ``.ann``.
    Note that ``index_type`` appears twice in the resulting name.
    """
    name_parts = (index_type, dataset_name, nb, d, index_type, epsilon, n_trees)
    return "_".join(str(part) for part in name_parts) + ".ann"


def _get_hnsw_cache_key(dataset_name: str, nb: int, d: int, index_type: str, M_HNSW: int, efConstruction: int, epsilon: float) -> str:
    """Return the cache file name for an HNSW index.

    The name is derived from the dataset name, the database size and the
    construction parameters (not from the data itself), and ends in
    ``.index``. Note that ``index_type`` appears twice in the resulting name.
    """
    build_params = "{}_{}_{}_{}_{}".format(d, index_type, M_HNSW, efConstruction, epsilon)
    return "{}_{}_{}_{}.index".format(index_type, dataset_name, nb, build_params)


def _get_or_build_annoy_index(
    dataset_name: str,
    nb: int,
    d: int,
    index_type: str,
    nlevels: int,
    epsilon: float,
    n_trees: int,
    data_loader_func: Optional[callable] = None,
    cache_dir: Optional[str] = None
) -> Tuple[AnnoyIndex, bool]:
    """Return an Annoy index, loading it from the on-disk cache when possible.

    Args:
        dataset_name: Dataset identifier used in the cache file name.
        nb: Database size used in the cache file name.
        d: Vector dimensionality.
        index_type: ``"annoy_pano"`` selects the ``euclidean_panorama``
            metric; any other value selects ``euclidean``.
        nlevels: Level count passed in the metric string and to
            ``set_nlevels`` after a cache load.
        epsilon: Value passed in the metric string.
        n_trees: Number of trees to build.
        data_loader_func: Zero-argument callable returning the database
            vectors; only called on a cache miss.
        cache_dir: Cache directory (created if missing); defaults to the
            system temporary directory.

    Returns:
        ``(index, loaded_from_cache)``.

    Raises:
        ValueError: On a cache miss when ``data_loader_func`` is None.
    """
    cache_root = tempfile.gettempdir() if cache_dir is None else cache_dir

    # Make sure the cache directory is there before touching any file in it.
    os.makedirs(cache_root, exist_ok=True)

    cache_name = _get_annoy_cache_key(dataset_name, nb, d, index_type, epsilon, n_trees)
    cache_file = os.path.join(cache_root, cache_name)
    if index_type == "annoy_pano":
        metric = "euclidean_panorama"
    else:
        metric = "euclidean"
    # The same metric specification is used for the cached and the fresh index.
    index_spec = f"{metric} {nlevels} {epsilon}"

    print(f"Cache file: {cache_file}")

    if os.path.exists(cache_file):
        print(f"Loading cached Annoy index from {cache_file}")
        cached_index = AnnoyIndex(d, index_spec)
        try:
            if cached_index.load(cache_file):
                print(f"Successfully loaded cached index with {cached_index.get_n_items()} items and {cached_index.get_n_trees()} trees")
                cached_index.set_nlevels(nlevels)
                print(f"Successfully restructured index with {nlevels} levels")
                return cached_index, True

            # load() reported failure: treat the file as corrupted.
            print("Failed to load cached index, rebuilding...")
            os.remove(cache_file)
        except Exception as e:
            print(f"Error loading cached index: {e}, rebuilding...")
            if os.path.exists(cache_file):
                os.remove(cache_file)

    # From here on we are on the cache-miss path and must build from data.
    print("Cache miss - loading data for new Annoy index...")
    if data_loader_func is None:
        raise ValueError("data_loader_func must be provided when cache miss occurs")

    vectors = data_loader_func()

    print("Creating new Annoy index...")
    fresh_index = AnnoyIndex(d, index_spec)

    print("Adding all items to index...")
    for item_id, vector in enumerate(vectors):
        fresh_index.add_item(item_id, vector)

    print("Building index with 64 threads...")
    fresh_index.build(n_trees, n_jobs=64)

    # Persisting the index is best effort; a failure only costs a rebuild later.
    try:
        fresh_index.save(cache_file)
        print(f"Saved Annoy index to cache: {cache_file}")
    except Exception as e:
        print(f"Warning: Failed to save index to cache: {e}")

    return fresh_index, False


def _get_or_build_hnsw_index(
    dataset_name: str,
    nb: int,
    d: int,
    M_HNSW: int,
    efConstruction: int,
    efSearch: int,
    nlevels: int,
    epsilon: float,
    index_type: str,
    data_loader_func: Optional[callable] = None,
    cache_dir: Optional[str] = None
) -> Tuple[faiss.Index, bool]:
    """Return an HNSW index, loading it from the on-disk cache when possible.

    Args:
        dataset_name: Dataset identifier used in the cache file name.
        nb: Database size used in the cache file name.
        d: Vector dimensionality.
        M_HNSW: HNSW ``M`` parameter passed to the index constructor.
        efConstruction: Value assigned to ``index.hnsw.efConstruction``.
        efSearch: Value assigned to ``index.hnsw.efSearch``; it is not part
            of the cache key and is applied after loading as well.
        nlevels: Level count for the panorama variant.
        epsilon: Epsilon for the panorama variant.
        index_type: ``"hnsw"`` or ``"hnsw_pano"``.
        data_loader_func: Zero-argument callable returning the database
            vectors; only called on a cache miss.
        cache_dir: Cache directory (created if missing); defaults to the
            system temporary directory.

    Returns:
        ``(index, loaded_from_cache)``.

    Raises:
        ValueError: On a cache miss when ``data_loader_func`` is None or
            ``index_type`` is not a known HNSW type.
    """
    cache_root = tempfile.gettempdir() if cache_dir is None else cache_dir

    # Make sure the cache directory is there before touching any file in it.
    os.makedirs(cache_root, exist_ok=True)

    cache_name = _get_hnsw_cache_key(dataset_name, nb, d, index_type, M_HNSW, efConstruction, epsilon)
    cache_file = os.path.join(cache_root, cache_name)

    print(f"HNSW Cache file: {cache_file}")

    if os.path.exists(cache_file):
        print(f"Loading cached HNSW index from {cache_file}")
        try:
            cached_index = faiss.read_index(cache_file)
            print(f"Successfully loaded cached HNSW index with {cached_index.ntotal} items")

            # efSearch is a search-time setting, so it is re-applied on load.
            if hasattr(cached_index, 'hnsw'):
                cached_index.hnsw.efSearch = efSearch
                cached_index.hnsw.search_bounded_queue = True

            if index_type == "hnsw_pano":
                if not hasattr(cached_index, 'set_nlevels'):
                    print(f"Warning: Loaded index is {type(cached_index)} and doesn't have set_nlevels method")
                    print(f"Index class: {cached_index.__class__.__name__}")
                    # Raising sends us to the handler below, which drops the
                    # cache file and falls through to a rebuild.
                    raise Exception(f"Cached index is not IndexHNSWPanorama, need to rebuild")
                cached_index.set_nlevels(nlevels)
                print(f"Successfully restructured HNSW index with {nlevels} levels")

            return cached_index, True
        except Exception as e:
            print(f"Error loading cached HNSW index: {e}, rebuilding...")
            if os.path.exists(cache_file):
                os.remove(cache_file)

    # From here on we are on the cache-miss path and must build from data.
    print("Cache miss - loading data for new HNSW index...")
    if data_loader_func is None:
        raise ValueError("data_loader_func must be provided when cache miss occurs")

    vectors = data_loader_func()

    print("Creating new HNSW index...")
    if index_type == "hnsw":
        fresh_index = faiss.IndexHNSWFlat(d, M_HNSW)
    elif index_type == "hnsw_pano":
        fresh_index = faiss.IndexHNSWPanorama(d, M_HNSW, nlevels, epsilon)
    else:
        raise ValueError(f"Invalid HNSW index type: {index_type}")
    # Both variants take the same graph settings.
    fresh_index.hnsw.efConstruction = efConstruction
    fresh_index.hnsw.efSearch = efSearch
    fresh_index.hnsw.search_bounded_queue = True

    print("Adding all items to HNSW index...")
    fresh_index.add(vectors)

    # Persisting the index is best effort; a failure only costs a rebuild later.
    try:
        faiss.write_index(fresh_index, cache_file)
        print(f"Saved HNSW index to cache: {cache_file}")
    except Exception as e:
        print(f"Warning: Failed to save HNSW index to cache: {e}")

    return fresh_index, False


class CSVDataset(Dataset):
    """Dataset backed by fvecs files.

    Despite the name, only files ending in ``.fvec``/``.fvecs`` are currently
    supported; any other extension terminates the process (see
    ``_load_file_data``).

    Queries come either from a separate test file or, when none is given, from
    rows of the training data selected by ``query_indices``. Loading the
    training data can be deferred (``_load_data(read_base=False)``) as long as
    a separate test file provides the queries.
    """

    def __init__(
        self,
        csv_path: str,
        nq: int,
        nb: Optional[int] = None,
        seed: int = 1338,
        query_indices: Optional[np.ndarray] = None,
        test_csv_path: Optional[str] = None,
    ):
        """
        Initialize dataset from an fvecs file.

        The training data itself is not loaded here unless it is needed to
        select the queries (i.e. when ``test_csv_path`` is None).

        Args:
            csv_path: Path to the fvecs file holding the training data
            nq: Number of query points to sample from the training data, or
                maximum number of rows read from the test file
            nb: Maximum number of database points to load (None = load all)
            seed: Random seed for reproducible sampling
            query_indices: Predetermined query indices to use (if None, will sample randomly)
            test_csv_path: Path to separate test file for queries
        """
        super().__init__()
        self.file_path = csv_path
        self.test_file_path = test_csv_path
        self.nq = nq
        self.nb = nb
        self.seed = seed
        self.query_indices = query_indices
        self._load_data(read_base=False)

    def _read_fvecs(self, filename: str) -> np.ndarray:
        """Read an fvecs file, loading at most ``self.nb`` vectors."""
        return self._read_fvecs_limited(filename, self.nb)

    @staticmethod
    def _read_fvecs_limited(filename: str, limit: Optional[int]) -> np.ndarray:
        """Read up to ``limit`` vectors (None = all) from an fvecs file.

        Each record is a 4-byte little-endian dimension followed by that many
        float32 values. Reading stops at the first incomplete record.
        """
        vectors = []
        with open(filename, 'rb') as f:
            while True:
                # Read dimension (4 bytes, little endian)
                dim_bytes = f.read(4)
                if len(dim_bytes) != 4:
                    break
                dim = int.from_bytes(dim_bytes, byteorder='little')

                # Read the vector (dim * 4 bytes for floats)
                vector_bytes = f.read(dim * 4)
                if len(vector_bytes) != dim * 4:
                    break

                # Decode explicitly as little-endian float32
                vectors.append(np.frombuffer(vector_bytes, dtype='<f4'))

                # Apply limit if specified
                if limit is not None and len(vectors) >= limit:
                    break

        return np.array(vectors, dtype=np.float32)

    def _load_data(self, read_base=True):
        """Load training data (optionally) and set up the queries.

        Args:
            read_base: Whether to load the training data. When False and a
                separate test file exists, ``self.data`` is reset to None.
                When False and there is no test file, the training data is
                loaded anyway (unless already present) because the queries are
                rows of it.

        Raises:
            ValueError: If predetermined query indices exceed the dataset size.
        """
        if read_base:
            print(f"Loading training data from {self.file_path}...")
            self.data = self._load_file_data(self.file_path, self.nb, "training")
        elif self.test_file_path is None:
            # Queries are taken from the training data, so it cannot be
            # deferred; previously this path failed on len(None).
            if getattr(self, "data", None) is None:
                print(f"Loading training data from {self.file_path}...")
                self.data = self._load_file_data(self.file_path, self.nb, "training")
        else:
            self.data = None

        if self.test_file_path is not None:
            # Queries come from the separate test file (at most nq rows).
            print(f"Loading test data from {self.test_file_path}...")
            self.query_data = self._load_file_data(self.test_file_path, self.nq, "test")
            print(f"Using all {len(self.query_data)} points from test file as queries")
            # No need for query_indices since we use all test data
            self.query_indices = None
            return

        if self.query_indices is not None:
            print(f"Using predetermined query indices: {len(self.query_indices)} points")
            # Validate indices are within bounds (an empty selection is trivially valid)
            if len(self.query_indices) > 0 and np.max(self.query_indices) >= len(self.data):
                raise ValueError(f"Query indices contain values >= dataset size ({len(self.data)})")
        else:
            # Seed the global NumPy RNG so the sampled queries are reproducible
            np.random.seed(self.seed)

            if self.nq >= len(self.data):
                print(f"Warning: nq ({self.nq}) >= dataset size ({len(self.data)}), using all data as queries")
                self.query_indices = np.arange(len(self.data))
            else:
                self.query_indices = np.random.choice(len(self.data), size=self.nq, replace=False)

            print(f"Selected {len(self.query_indices)} query points")

        # No separate query data, will use indices into training data
        self.query_data = None

    def _load_file_data(self, file_path: str, row_limit: Optional[int], data_type: str) -> np.ndarray:
        """Load at most ``row_limit`` vectors from a single fvecs file.

        Args:
            file_path: File to read; must end in ``.fvec`` or ``.fvecs``.
            row_limit: Maximum number of rows to read (None = all).
            data_type: Label used in log messages (e.g. "training", "test").

        Raises:
            SystemExit: With status 1 if the file is not an fvecs file.
        """
        if not file_path.lower().endswith(('.fvec', '.fvecs')):
            print("NON FVECS/FVEC FILE!!!")
            # Same effect as the interactive exit(1) helper, without relying
            # on the site module having injected it.
            raise SystemExit(1)

        print(f"Detected fvecs file format for {data_type} data")
        # The limit is passed explicitly instead of temporarily overwriting
        # self.nb, so an I/O error cannot leave self.nb modified.
        data = self._read_fvecs_limited(file_path, row_limit)
        print(f"Loaded {data_type} fvecs with shape: {data.shape}")
        return data

    def get_database(self) -> np.ndarray:
        """Return the database vectors (None while loading is deferred)."""
        return self.data

    def get_queries(self) -> np.ndarray:
        """Return query vectors (either from separate test file or sampled from training data)."""
        if self.query_data is not None:
            return self.query_data
        return self.data[self.query_indices]

    def get_groundtruth(self, k: int = 10) -> np.ndarray:
        """Compute the ground-truth neighbour indices with a brute-force L2 search.

        Loads the training data first if it has been deferred.
        """
        if self.data is None:
            self._load_data(read_base=True)
        print("Computing ground truth with brute force...")
        index = faiss.IndexFlatL2(self.data.shape[1])
        index.add(self.data)
        _, I = index.search(self.get_queries(), k)
        return I


@contextmanager
def timed_execution() -> Generator[PerfCounters, None, None]:
    """Context manager that measures the time spent in its ``with`` body.

    Yields a ``PerfCounters`` whose fields are zero while the body runs and
    are filled in when the body finishes: wall time from
    ``time.perf_counter()``, user and system CPU time from
    ``resource.getrusage(RUSAGE_SELF)``.

    The counters are also filled in when the body raises, so partial timings
    remain available to code that handles the exception.
    """
    pcounters = PerfCounters()
    wall_time_start = time.perf_counter()
    rusage_start = resource.getrusage(resource.RUSAGE_SELF)
    try:
        yield pcounters
    finally:
        wall_time_end = time.perf_counter()
        rusage_end = resource.getrusage(resource.RUSAGE_SELF)
        pcounters.wall_time_s = wall_time_end - wall_time_start
        pcounters.user_time_s = rusage_end.ru_utime - rusage_start.ru_utime
        pcounters.system_time_s = rusage_end.ru_stime - rusage_start.ru_stime


def is_perf_counter(key: str) -> bool:
    """Tell whether ``key`` names a timing counter, i.e. ends in ``_time_us``."""
    timing_suffix = "_time_us"
    return key.endswith(timing_suffix)


def accumulate_perf_counter(
    phase: str,
    t: PerfCounters,
    counters: Dict[str, int]
):
    """Store the wall and user time of ``t`` in ``counters`` as microseconds.

    Writes ``<phase>_wall_time_us`` and ``<phase>_user_time_us``, truncated to
    integers. Existing entries with those keys are overwritten (despite the
    function name, nothing is summed), and the system time is not recorded.
    """
    measured = (("wall", t.wall_time_s), ("user", t.user_time_s))
    for label, seconds in measured:
        counters[f"{phase}_{label}_time_us"] = int(seconds * US_IN_S)


def _ensure_database(ds: Dataset) -> np.ndarray:
    """Return the database vectors of ``ds``, loading them first if deferred.

    Only datasets exposing a ``_load_data`` loader can defer loading; for any
    other dataset ``get_database()`` is called directly.
    """
    if hasattr(ds, "_load_data") and getattr(ds, "data", None) is None:
        ds._load_data(read_base=True)
    return ds.get_database()


def _record_search_times(
    counters: Dict[str, int],
    search_times: List[float],
    num_search_iterations: int,
) -> None:
    """Store total, average, std-dev and per-iteration search times (in us)."""
    total_time = sum(search_times)
    avg_time = total_time / num_search_iterations
    std_time = np.std(search_times) if num_search_iterations > 1 else 0.0

    # Store both total time (for compatibility) and per-iteration stats
    counters["search_wall_time_us"] = int(total_time * US_IN_S)
    counters["search_user_time_us"] = int(total_time * US_IN_S)  # Approximation: wall time
    counters["search_avg_time_us"] = int(avg_time * US_IN_S)
    counters["search_std_time_us"] = int(std_time * US_IN_S)
    counters["search_times_us"] = [int(seconds * US_IN_S) for seconds in search_times]


def _timed_faiss_search(index, xq: np.ndarray, k: int, num_search_iterations: int):
    """Run ``index.search(xq, k)`` repeatedly, timing each run.

    Returns ``(D, I, search_times)`` where ``D``/``I`` come from the last run
    and ``search_times`` holds the wall time in seconds of every run.
    """
    search_times = []
    D = I = None
    for _ in range(num_search_iterations):
        with timed_execution() as t:
            D, I = index.search(xq, k)
        search_times.append(t.wall_time_s)
    return D, I, search_times


def run_on_dataset(
    ds: Dataset,
    nlist: int,
    M: int,
    nbits: int,
    nlevels: int,
    num_threads: int,
    efConstruction: int,
    efSearch: int,
    M_HNSW: int,
    num_add_iterations: int,
    num_search_iterations: int,
    recall_target: float,
    nprobe: int = 8,
    index_type: str = "ivfpq",
    k: int = 10,
    epsilon: float = 1,
    n_trees: int = 10,
    search_k: int = -1,
    batch_size: int = 128,
    use_cache: bool = False,
    dataset_name: str = "synthetic",
) -> Tuple[Dict[str, int], Tuple[np.ndarray, np.ndarray]]:
    """Build one index over ``ds`` and time its add and search phases.

    Args:
        ds: Dataset providing queries and database vectors.
        nlist, M, nbits, nprobe: IVF / PQ parameters (IVF-based index types).
        nlevels, epsilon, batch_size: Parameters of the panorama variants;
            ``nlevels``/``epsilon`` are also passed to Annoy and MRPT.
        num_threads: Currently unused; building uses 64 threads and FAISS
            searches use 1 thread.
        efConstruction, efSearch, M_HNSW: HNSW parameters.
        num_add_iterations: Only recorded in the counters; the add phase
            always runs once.
        num_search_iterations: Number of timed search repetitions (>= 1).
        recall_target: Target recall for MRPT autotuning.
        index_type: One of ``ivfpq``, ``ivfpq_pano``, ``ivf_flat``,
            ``ivf_flat_pano``, ``fs``, ``naive``, ``naive_pano``, ``hnsw``,
            ``hnsw_pano``, ``annoy``, ``annoy_pano``, ``mrpt``, ``mrpt_pano``.
        k: Number of neighbours per query.
        n_trees: Number of Annoy trees.
        search_k: Annoy ``search_k``; values <= 0 mean ``n_trees * k``.
        use_cache: For Annoy/HNSW, load the index from (or save it to) the
            on-disk cache and defer loading the database until a cache miss.
        dataset_name: Dataset identifier used in cache file names.

    Returns:
        ``(counters, (D, I))`` where ``counters`` holds timings and run
        parameters and ``D``/``I`` are the results of the last search run.

    Raises:
        ValueError: For an unknown ``index_type``, an unavailable Annoy/MRPT
            backend, or ``num_search_iterations < 1``.
    """
    if num_search_iterations < 1:
        # Fail before any expensive build; the averaging below would divide
        # by zero and no results would exist to return.
        raise ValueError("num_search_iterations must be at least 1")

    xq = ds.get_queries()
    nq, d = xq.shape

    # For cached indexes, delay loading the database until needed
    if use_cache and index_type.startswith(("annoy", "hnsw")):
        # Trust the dataset's configured nb; xb is loaded on a cache miss only
        nb = ds.nb
        xb = None
    else:
        xb = _ensure_database(ds)
        nb, _ = xb.shape

    # pyre-ignore[16]: Module `faiss` has no attribute `omp_set_num_threads`.
    faiss.omp_set_num_threads(64)

    # ---- Index creation -------------------------------------------------
    index = None
    if index_type in ("ivfpq", "ivfpq_pano", "ivf_flat", "ivf_flat_pano", "fs"):
        quantizer = faiss.IndexFlatL2(d)
        if index_type == "ivfpq":
            index = faiss.IndexIVFPQ(quantizer, d, nlist, M, nbits)
        elif index_type == "ivfpq_pano":
            index = faiss.IndexIVFPQPanorama(quantizer, d, nlist, M, nbits, nlevels, epsilon)
        elif index_type == "ivf_flat":
            index = faiss.IndexIVFFlat(quantizer, d, nlist)
        elif index_type == "ivf_flat_pano":
            index = faiss.IndexIVFFlatPanorama(quantizer, d, nlist, nlevels, epsilon, batch_size)
        else:  # "fs": fast-scan always uses 4-bit codes
            index = faiss.IndexIVFPQFastScan(quantizer, d, nlist, M, 4)
        index.train(xb)
        index.nprobe = nprobe
    elif index_type == "naive_pano":
        index = faiss.IndexFlatL2Panorama(d, nlevels, epsilon, batch_size)
    elif index_type == "naive":
        index = faiss.IndexFlatL2(d)
    elif index_type in ("hnsw", "hnsw_pano"):
        if index_type == "hnsw":
            index = faiss.IndexHNSWFlat(d, M_HNSW)
        else:
            index = faiss.IndexHNSWPanorama(d, M_HNSW, nlevels, epsilon)
        index.hnsw.efConstruction = efConstruction
        index.hnsw.efSearch = efSearch
        index.hnsw.search_bounded_queue = True
    elif index_type in ("annoy", "annoy_pano"):
        if not ANNOY_AVAILABLE:
            raise ValueError("Annoy not available. Install with ./build.sh from the annoy/ directory")
        metric = "euclidean" if index_type == "annoy" else "euclidean_panorama"
        index = AnnoyIndex(d, f"{metric} {nlevels} {epsilon}")
    elif index_type in ("mrpt", "mrpt_pano"):
        if not MRPT_AVAILABLE:
            raise ValueError("MRPT not available. Install with ./build.sh from the mrpt/ directory")
        # MRPT needs the data in its constructor; it is built in the add phase
    else:
        raise ValueError(f"Invalid index type: {index_type}")

    # ---- Add and search phases -------------------------------------------
    counters = {}
    if index_type.startswith("annoy"):
        with timed_execution() as t:
            if use_cache:
                # Replaces the index created above with the cached/built one
                index, _ = _get_or_build_annoy_index(
                    dataset_name=dataset_name,
                    nb=nb,
                    d=d,
                    index_type=index_type,
                    nlevels=nlevels,
                    epsilon=epsilon,
                    n_trees=n_trees,
                    data_loader_func=lambda: _ensure_database(ds),
                )
            else:
                for i, vec in enumerate(xb):
                    index.add_item(i, vec)

                # Build the index (this is the expensive part for Annoy)
                print("Building index with 64 threads...")
                index.build(n_trees, n_jobs=64)

        accumulate_perf_counter("add", t, counters)
        counters["nb"] = nb
        counters["num_add_iterations"] = num_add_iterations
        counters["n_trees"] = n_trees

        print("Running search...")
        search_k_param = search_k if search_k > 0 else n_trees * k

        # Queries are issued one at a time; only the last iteration's results are kept
        search_times = []
        for _ in range(num_search_iterations):
            with timed_execution() as t:
                I_iter = []
                D_iter = []
                for query in xq:
                    neighbors = index.get_nns_by_vector(query.tolist(), k,
                                                        search_k=search_k_param,
                                                        include_distances=True)
                    I_iter.append(neighbors[0])
                    D_iter.append(neighbors[1])
                index.get_level_reduction()
                # The array conversion stays inside the timed region so the
                # measurements remain comparable with earlier runs.
                I = np.array(I_iter)
                D = np.array(D_iter)
            search_times.append(t.wall_time_s)

        _record_search_times(counters, search_times, num_search_iterations)
        counters["search_k"] = search_k_param
    elif index_type.startswith("hnsw"):
        faiss.omp_set_num_threads(64)
        print("BUILDING IN PARALLEL")
        with timed_execution() as t:
            if use_cache:
                # Replaces the index created above with the cached/built one
                index, _ = _get_or_build_hnsw_index(
                    dataset_name=dataset_name,
                    nb=nb,
                    d=d,
                    M_HNSW=M_HNSW,
                    efConstruction=efConstruction,
                    efSearch=efSearch,
                    nlevels=nlevels,
                    epsilon=epsilon,
                    index_type=index_type,
                    data_loader_func=lambda: _ensure_database(ds),
                )
            else:
                index.add(xb)

        accumulate_perf_counter("add", t, counters)
        counters["nb"] = nb
        counters["num_add_iterations"] = num_add_iterations

        # Searches are timed single-threaded
        faiss.omp_set_num_threads(1)
        print("Running search...")
        D, I, search_times = _timed_faiss_search(index, xq, k, num_search_iterations)
        _record_search_times(counters, search_times, num_search_iterations)
    elif index_type.startswith("mrpt"):
        # Only construction and autotuning belong to the "add" phase. The
        # counters must be read after the timed block has exited, otherwise
        # they are still zero.
        with timed_execution() as t:
            is_pano = index_type == "mrpt_pano"
            index = MRPTIndex(xb, nlevels, epsilon, is_pano)
            index.build_autotune_sample(recall_target, k)

        accumulate_perf_counter("add", t, counters)
        counters["nb"] = nb
        counters["recall_target"] = recall_target

        print("Running search...")

        # Queries are issued one at a time; only the last iteration's results are kept
        search_times = []
        for _ in range(num_search_iterations):
            with timed_execution() as t:
                I_iter = []
                D_iter = []
                for query in xq:
                    neighbors = index.ann(query, k=k, return_distances=True)
                    I_iter.append(neighbors[0])
                    D_iter.append(neighbors[1])
                index.print_times()
                I = np.array(I_iter)
                D = np.array(D_iter)
            search_times.append(t.wall_time_s)

        _record_search_times(counters, search_times, num_search_iterations)
    else:
        # Other FAISS indexes (non-HNSW)
        with timed_execution() as t:
            index.add(xb)

        accumulate_perf_counter("add", t, counters)
        counters["nb"] = nb
        counters["num_add_iterations"] = num_add_iterations

        # Searches are timed single-threaded
        faiss.omp_set_num_threads(1)
        print("Running search...")
        D, I, search_times = _timed_faiss_search(index, xq, k, num_search_iterations)
        _record_search_times(counters, search_times, num_search_iterations)

    # Run parameters shared by every index type
    counters["nq"] = nq
    counters["nprobe"] = nprobe
    counters["nlist"] = nlist
    counters["M"] = M
    counters["nbits"] = nbits
    counters["d"] = d
    counters["num_search_iterations"] = num_search_iterations
    counters["index_type"] = index_type

    return counters, (D, I)


def run_csv(
    csv_path: str,
    nq: int,
    nb: Optional[int],
    recall_target: float,
    nlist: int,
    M: int,
    nbits: int,
    nlevels: int,
    num_threads: int,
    efConstruction: int,
    efSearch: int,
    M_HNSW: int,
    num_add_iterations: int = 1,
    num_search_iterations: int = 10,
    nprobe: int = 8,
    index_type: str = "ivfpq",
    seed: int = 1338,
    k: int = 10,
    epsilon: float = 1,
    trans_csv_path: Optional[str] = None,
    query_indices: Optional[np.ndarray] = None,
    n_trees: int = 10,
    search_k: int = -1,
    csv_test_path: Optional[str] = None,
    trans_csv_test_path: Optional[str] = None,
    batch_size: int = 4096,
    use_cache: bool = False,
) -> Tuple[Dict[str, int], Tuple[np.ndarray, np.ndarray]]:
    """Benchmark an index on a file-backed dataset.

    The database file is ``trans_csv_path`` when given, otherwise
    ``csv_path``; the query file is ``trans_csv_test_path`` when given,
    otherwise ``csv_test_path`` (which may itself be ``None``). The dataset
    name forwarded to ``run_on_dataset`` is the database file's basename
    without its extension.

    Returns whatever ``run_on_dataset`` returns for the constructed dataset.
    """
    # Transformed files take precedence over the original ones.
    if trans_csv_test_path is not None:
        query_file = trans_csv_test_path
    else:
        query_file = csv_test_path
    base_file = csv_path if trans_csv_path is None else trans_csv_path

    base_name = os.path.basename(base_file)
    dataset_name, _ext = os.path.splitext(base_name)

    # With caching enabled for an annoy*/hnsw* index, the base vectors are
    # not read up front; only the non-base part is loaded here.
    skip_base_load = use_cache and index_type.startswith(("annoy", "hnsw"))

    ds = CSVDataset(
        csv_path=base_file,
        nq=nq,
        nb=nb,
        seed=seed,
        query_indices=query_indices,
        test_csv_path=query_file,
    )
    if skip_base_load:
        print("Delaying base dataset loading for cached index...")
        ds._load_data(read_base=False)

    run_options = dict(
        nlist=nlist,
        M=M,
        nbits=nbits,
        nlevels=nlevels,
        num_add_iterations=num_add_iterations,
        num_search_iterations=num_search_iterations,
        recall_target=recall_target,
        num_threads=num_threads,
        efConstruction=efConstruction,
        efSearch=efSearch,
        M_HNSW=M_HNSW,
        nprobe=nprobe,
        index_type=index_type,
        k=k,
        epsilon=epsilon,
        n_trees=n_trees,
        search_k=search_k,
        batch_size=batch_size,
        use_cache=use_cache,
        dataset_name=dataset_name,
    )
    return run_on_dataset(ds, **run_options)


def run_synthetic(
    d: int,
    nb: int,
    nq: int,
    nlist: int,
    recall_target: float,
    M: int,
    nbits: int,
    nlevels: int,
    num_threads: int,
    efConstruction: int,
    efSearch: int,
    M_HNSW: int,
    num_add_iterations: int = 1,
    num_search_iterations: int = 10,
    nprobe: int = 8,
    index_type: str = "ivfpq",
    k: int = 10,
    epsilon: float = 1,
    n_trees: int = 10,
    search_k: int = -1,
    batch_size: int = 4096,
    use_cache: bool = False,
) -> Tuple[Dict[str, int], Tuple[np.ndarray, np.ndarray]]:
    """Benchmark an index on a generated ``SyntheticDataset``.

    The dataset is built with ``d`` dimensions, ``nb`` database vectors,
    ``nq`` queries, no training vectors, the L2 metric and a fixed seed of
    1338. All remaining arguments are forwarded unchanged to
    ``run_on_dataset``, whose result is returned.
    """
    run_options = dict(
        nlist=nlist,
        M=M,
        nbits=nbits,
        nlevels=nlevels,
        num_add_iterations=num_add_iterations,
        num_search_iterations=num_search_iterations,
        recall_target=recall_target,
        num_threads=num_threads,
        efConstruction=efConstruction,
        efSearch=efSearch,
        M_HNSW=M_HNSW,
        nprobe=nprobe,
        index_type=index_type,
        k=k,
        epsilon=epsilon,
        n_trees=n_trees,
        search_k=search_k,
        batch_size=batch_size,
        use_cache=use_cache,
    )
    synthetic = SyntheticDataset(d=d, nb=nb, nt=0, nq=nq, metric="L2", seed=1338)
    return run_on_dataset(
        synthetic,
        dataset_name=f"synthetic_d{d}_nb{nb}_nq{nq}",
        **run_options,
    )


def _accumulate_counters(
    element: Dict[str, int], accu: Optional[Dict[str, List[int]]] = None
) -> Dict[str, List[int]]:
    if accu is None:
        accu = {key: [value] for key, value in element.items()}
        return accu
    else:
        assert accu.keys() <= element.keys(), (
            "Accu keys must be a subset of element keys: "
            f"{accu.keys()} not a subset of {element.keys()}"
        )
        for key in accu.keys():
            accu[key].append(element[key])
        return accu


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the benchmark script.

    The ``--index-type`` choices include the Annoy and MRPT variants only
    when ``ANNOY_AVAILABLE`` / ``MRPT_AVAILABLE`` are true.
    """
    parser = argparse.ArgumentParser(description="Benchmark IVFPQ")
    parser.add_argument("--nlist", type=int, default=100)
    parser.add_argument("--M", type=int, default=8)
    parser.add_argument("--nbits", type=int, default=8)
    parser.add_argument("--num-threads", type=int, default=1)
    parser.add_argument("--nlevels", type=int, default=8)
    parser.add_argument("--warm-up-iterations", type=int, default=0)
    parser.add_argument("--num-search-iterations", type=int, default=5)
    parser.add_argument("--num-add-iterations", type=int, default=1)
    parser.add_argument("--num-repetitions", type=int, default=1)
    parser.add_argument("--nprobe", type=int, default=8)
    # Build choices list dynamically based on availability
    index_choices = ["ivfpq", "ivfpq_pano", "ivf_flat", "ivf_flat_pano", "fs", "naive_pano", "naive", "hnsw", "hnsw_pano"]
    if ANNOY_AVAILABLE:
        index_choices.extend(["annoy", "annoy_pano"])
    if MRPT_AVAILABLE:
        index_choices.extend(["mrpt", "mrpt_pano"])

    parser.add_argument("--index-type", type=str, default="ivfpq",
                        choices=index_choices,
                        help="Index type to benchmark")
    parser.add_argument("--seed", type=int, default=1338,
                        help="Random seed for query sampling")
    parser.add_argument("--k", type=int, default=10,
                        help="Number of nearest neighbors to retrieve (K)")
    parser.add_argument("--epsilon", type=float, default=1, help="pruning aggressiveness")

    # CSV/fvecs or synthetic data options
    parser.add_argument("--csv-path", type=str, default=None,
                        help="Path to CSV or fvecs file. If provided, will use file data instead of synthetic")
    parser.add_argument("--trans-csv-path", type=str, default=None,
                        help="Path to transformed CSV or fvecs file. If provided, will use this for the main index and --csv-path for validation")
    parser.add_argument("--csv-test-path", type=str, default=None,
                        help="Path to CSV or fvecs test file. If provided, will use these points as queries instead of sampling from training data")
    parser.add_argument("--trans-csv-test-path", type=str, default=None,
                        help="Path to transformed CSV or fvecs test file. If provided, will use these points as queries instead of sampling from training data")

    # Data parameters
    parser.add_argument("--nb", type=int, default=5000,
                        help="Number of database points (limits file rows/vectors when using file data)")
    parser.add_argument("--nq", type=int, default=500,
                        help="Number of query points")
    parser.add_argument("--d", type=int, default=128,
                        help="Dimension (only used for synthetic data)")

    # HNSW parameters
    parser.add_argument("--efConstruction", type=int, default=40,
                        help="efConstruction parameter for HNSW")
    parser.add_argument("--efSearch", type=int, default=16,
                        help="efSearch parameter for HNSW")
    parser.add_argument("--M-HNSW", type=int, default=32,
                        help="M parameter for HNSW")

    # Annoy parameters
    parser.add_argument("--n-trees", type=int, default=200,
                        help="Number of trees for Annoy index (more trees = better accuracy, larger index)")
    parser.add_argument("--search-k", type=int, default=-1,
                        help="Search-k parameter for Annoy (-1 for default: n_trees * k)")
    parser.add_argument("--batch-size", type=int, default=128,
                        help="Batch size for IndexFlatL2Panorama (default: 128)")

    # MRPT parameters
    parser.add_argument("--recall-target", type=float, default=0.9,
                        help="Recall target for MRPT index")

    # Caching option
    parser.add_argument("--cached", action="store_true",
                        help="Enable caching for Annoy and HNSW indexes to speed up repeated runs")

    parser.add_argument("--cache-results", action="store_true",
                        help="Enable caching for naive results")
    return parser


def main():
    """Command-line entry point: benchmark one index type and report QPS/recall.

    Steps, in order:
      1. Parse and validate the arguments.
      2. For file data with a transformed database but no test files,
         pre-select query indices from the original file so the benchmark
         and the ground truth use the same queries.
      3. Optionally run warm-up iterations.
      4. Run the benchmark ``--num-repetitions`` times, accumulating counters.
      5. Obtain ground-truth neighbours with the ``naive`` index on the
         original (non-transformed) data, optionally via the results cache.
      6. Print timing, QPS and recall (set overlap with the ground truth,
         using the neighbours returned by the last repetition).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate arguments
    if args.trans_csv_path is not None and args.csv_path is None:
        parser.error("--trans-csv-path requires --csv-path to be specified for validation")
    # Without at least one repetition and one search iteration there are no
    # search results or timings to report, so reject these up front.
    if args.num_repetitions < 1:
        parser.error("--num-repetitions must be at least 1")
    if args.num_search_iterations < 1:
        parser.error("--num-search-iterations must be at least 1")

    # Determine if using CSV or synthetic data
    use_csv = args.csv_path is not None

    # Pre-select query indices when using transformed data but no test files (for consistency)
    query_indices = None
    if use_csv and args.trans_csv_path is not None and args.csv_test_path is None and args.trans_csv_test_path is None:
        print("Pre-selecting query indices to ensure consistency between transformed and original datasets...")
        # Load original data to determine valid query indices
        temp_ds = CSVDataset(csv_path=args.csv_path, nq=args.nq, nb=args.nb, seed=args.seed)
        query_indices = temp_ds.query_indices
        print(f"Selected {len(query_indices)} query indices: {query_indices[:10]}{'...' if len(query_indices) > 10 else ''}")
    elif use_csv and (args.csv_test_path is not None or args.trans_csv_test_path is not None):
        print("Using separate test files for queries - no need to pre-select query indices")

    if use_csv:
        if args.trans_csv_path is not None:
            print(f"Using transformed data for index: {args.trans_csv_path}")
            print(f"Using original data for ground truth validation: {args.csv_path}")
        else:
            print(f"Using data from: {args.csv_path}")

        # Print test file information
        if args.trans_csv_test_path is not None:
            print(f"Using transformed test data for queries: {args.trans_csv_test_path}")
        elif args.csv_test_path is not None:
            print(f"Using test data for queries: {args.csv_test_path}")
        else:
            print(f"Will sample {args.nq} query points from training data")

        print(f"Will limit to {args.nb} database points")
    else:
        print(f"Using synthetic data: nb={args.nb}, nq={args.nq}, d={args.d}")

    # Index/search parameters passed to every run (warm-up, measured and
    # ground truth). Keeping them in one place guarantees no call site can
    # forget a required argument.
    common_kwargs = dict(
        nlist=args.nlist,
        recall_target=args.recall_target,
        M=args.M,
        nbits=args.nbits,
        nlevels=args.nlevels,
        num_threads=args.num_threads,
        efConstruction=args.efConstruction,
        efSearch=args.efSearch,
        M_HNSW=args.M_HNSW,
        nprobe=args.nprobe,
        k=args.k,
        epsilon=args.epsilon,
        n_trees=args.n_trees,
        search_k=args.search_k,
    )
    if use_csv:
        run_fn = run_csv
        data_kwargs = dict(
            csv_path=args.csv_path,
            nq=args.nq,
            nb=args.nb,
            seed=args.seed,
            query_indices=query_indices,
            csv_test_path=args.csv_test_path,
        )
        # Transformed files are used for the benchmarked index only; the
        # ground truth always runs on the original data.
        benchmark_only_kwargs = dict(
            trans_csv_path=args.trans_csv_path,
            trans_csv_test_path=args.trans_csv_test_path,
        )
    else:
        run_fn = run_synthetic
        data_kwargs = dict(d=args.d, nb=args.nb, nq=args.nq)
        benchmark_only_kwargs = {}

    def run_benchmark(num_search_iterations, num_add_iterations):
        """Run the user-selected index once with the given iteration counts."""
        return run_fn(
            num_search_iterations=num_search_iterations,
            num_add_iterations=num_add_iterations,
            index_type=args.index_type,
            batch_size=args.batch_size,
            use_cache=args.cached,
            **common_kwargs,
            **data_kwargs,
            **benchmark_only_kwargs,
        )

    if args.warm_up_iterations > 0:
        print(f"Warming up for {args.warm_up_iterations} iterations...")
        run_benchmark(args.warm_up_iterations, args.warm_up_iterations)

    if use_csv:
        print(f"Running benchmark with data from {args.csv_path} (limited to {args.nb} points)")
    else:
        print(
            f"Running benchmark with dataset(nb={args.nb}, nq={args.nq}, "
            f"d={args.d}), nlist={args.nlist}, M={args.M}, nbits={args.nbits}, "
            f"num_threads={args.num_threads}, nprobe={args.nprobe}, "
            f"index_type={args.index_type}..."
        )

    result = None
    for _ in range(args.num_repetitions):
        # D/I of the last repetition are the ones compared to the ground truth.
        counters, (D, I) = run_benchmark(
            args.num_search_iterations, args.num_add_iterations
        )
        result = _accumulate_counters(counters, result)

    # Get ground truth for recall calculation (always use original data, not transformed)
    # NOTE(review): 64 OpenMP threads are requested here, but run_on_dataset
    # may change the thread count again - confirm it takes effect.
    faiss.omp_set_num_threads(64)
    print("NAIVE IN PARALLEL")

    D_real, I_real = None, None
    cache_key = None

    # Try to load cached ground truth first if caching is enabled
    if args.cache_results:
        if use_csv:
            ground_truth_paths = [args.csv_path]  # Always use original csv_path
            if args.csv_test_path is not None:
                ground_truth_paths.append(args.csv_test_path)  # Add test path if provided
        else:
            # For synthetic data, use a synthetic identifier
            ground_truth_paths = [f"synthetic_d{args.d}"]

        cache_key = _get_results_cache_key(ground_truth_paths, args.nb, args.nq, args.k)
        cached_results = _load_cached_results(cache_key)
        if cached_results is not None:
            print(f"Sucessfully loaded cached results from {cache_key}")
            D_real, I_real = cached_results

    # If not cached or caching disabled, compute ground truth
    if D_real is None or I_real is None:
        print("Computing ground truth with naive index...")
        _, (D_real, I_real) = run_fn(
            num_search_iterations=1,
            num_add_iterations=1,
            index_type="naive",
            use_cache=False,  # Don't use cache for ground truth calculation
            **common_kwargs,
            **data_kwargs,
        )

        # Save results to cache if caching is enabled
        if args.cache_results:
            print(f"Saving ground truth results to cache at {cache_key}")
            _save_cached_results(D_real, I_real, cache_key)

    assert result is not None

    # Extract key metrics
    search_wall_time_us = result.get("search_wall_time_us", [0])
    search_avg_time_us = result.get("search_avg_time_us", [0])
    search_std_time_us = result.get("search_std_time_us", [0])
    # List of per-repetition lists; empty when the index did not record
    # per-iteration timings, which selects the fallback below.
    search_times_us = result.get("search_times_us", [])
    nq = result.get("nq", [0])[0]  # Number of queries is constant across runs

    search_avg_time_s = np.array(search_avg_time_us) / US_IN_S
    search_std_time_s = np.array(search_std_time_us) / US_IN_S

    # Flatten individual iteration times across all repetitions
    all_iteration_times = [
        t / US_IN_S for rep_times in search_times_us for t in rep_times
    ]
    # Timings are stored as whole microseconds, so a very fast search can be
    # recorded as 0; such entries cannot be turned into a QPS value.
    measurable_times = [t for t in all_iteration_times if t > 0]

    if measurable_times:
        qps_per_iteration = [nq / t for t in measurable_times]
        qps_mean = np.mean(qps_per_iteration)
        qps_std = np.std(qps_per_iteration)
    else:
        # Fallback: QPS from the total wall time of each repetition
        search_wall_time_s = np.array(search_wall_time_us) / US_IN_S
        qps_per_iteration = nq / search_wall_time_s
        qps_mean = np.mean(qps_per_iteration)
        qps_std = 0.0

    # Calculate recall against ground truth (set intersection, order-independent)
    recall = np.mean(
        [len(set(truth) & set(I[i])) / args.k for i, truth in enumerate(I_real)]
    )

    print(f"Search Performance (k={args.k}):")
    print(f"  Average time per search: {np.mean(search_avg_time_s)*1000:.2f} ± {np.mean(search_std_time_s)*1000:.2f} ms")
    print(f"  QPS: {qps_mean:.1f} ± {qps_std:.1f}")
    print(f"  Total iterations: {len(all_iteration_times) if all_iteration_times else args.num_search_iterations}")
    print(f"  Queries per iteration: {nq}")
    print(f"  Recall: {recall:.4f}")

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
