# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import resource
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Generator, List, Optional, Tuple

import faiss  # @manual=//faiss/python:pyfaiss
import numpy as np
from faiss.contrib.datasets import (  # @manual=//faiss/contrib:faiss_contrib
    Dataset,
    SyntheticDataset,
)

# Microseconds per second: converts the seconds stored in PerfCounters into
# the integer ``*_time_us`` counters.
US_IN_S = 1_000_000


@dataclass
class PerfCounters:
    """Elapsed times, in seconds, filled in by ``timed_execution``."""

    wall_time_s: float = 0.0  # wall-clock time measured with time.perf_counter
    user_time_s: float = 0.0  # user-mode CPU time (getrusage ru_utime delta)
    system_time_s: float = 0.0  # kernel-mode CPU time (getrusage ru_stime delta)


@contextmanager
def timed_execution() -> Generator[PerfCounters, None, None]:
    """Time the body of a ``with`` statement.

    Yields a fresh :class:`PerfCounters`. Once the body finishes normally, the
    yielded object holds the elapsed wall-clock time (``time.perf_counter``)
    and the user/system CPU time consumed by this process
    (``resource.getrusage(RUSAGE_SELF)``). If the body raises, nothing is
    recorded and every field stays at 0.0.
    """
    counters = PerfCounters()
    start_wall, start_usage = (
        time.perf_counter(),
        resource.getrusage(resource.RUSAGE_SELF),
    )
    yield counters
    end_wall, end_usage = (
        time.perf_counter(),
        resource.getrusage(resource.RUSAGE_SELF),
    )
    counters.wall_time_s = end_wall - start_wall
    counters.user_time_s = end_usage.ru_utime - start_usage.ru_utime
    counters.system_time_s = end_usage.ru_stime - start_usage.ru_stime


class CSVDataset(Dataset):
    """A dataset that loads vectors from a CSV file.

    The first ``nb`` vectors read become the database and the next ``nq``
    become the queries. There is no training set, and the metric is L2.
    """

    def __init__(self, csv_file: str, nb: int, nq: int, nc: Optional[int] = None):
        """
        Initialize CSV dataset.

        Args:
            csv_file: Path to the CSV file containing vectors
            nb: Number of database vectors
            nq: Number of query vectors
            nc: Number of columns to read (default: None, reads all columns)

        Raises:
            ValueError: if no numeric data is found, or the file yields fewer
                than ``nb + nq`` vectors.

        NOTE: the two loaders treat headers differently. pandas (used when it
        can be imported) parses the first line as a header and keeps only
        numeric columns. The built-in reader keeps every row whose selected
        fields all parse as floats and silently skips the others.
        """
        Dataset.__init__(self)
        nrows = nb + nq

        # Only the import is guarded: an ImportError raised while *parsing*
        # must not silently switch to the fallback reader.
        try:
            import pandas as pd
        except ImportError:
            pd = None

        if pd is not None:
            print(f"Loading data from CSV file: {csv_file} (reading up to {nrows} rows)")

            if nc is not None:
                print(f"Reading first {nc} columns")
                df = pd.read_csv(csv_file, nrows=nrows, usecols=list(range(nc)))
            else:
                df = pd.read_csv(csv_file, nrows=nrows)

            # Use all numeric columns as features
            numeric_cols = df.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) == 0:
                raise ValueError("No numeric columns found in CSV file")

            data = df[numeric_cols].to_numpy(dtype='float32')
        else:
            print(f"Loading data from CSV file: {csv_file} (using built-in CSV reader)")
            import csv

            rows = []
            # newline='' lets the csv module handle line endings itself, as its
            # documentation requires for file objects.
            with open(csv_file, 'r', newline='') as f:
                for row in csv.reader(f):
                    fields = row[:nc]  # row[:None] keeps every column
                    if not fields:
                        # Blank lines come back as []; keeping them would add
                        # zero-length vectors and break np.array below.
                        continue
                    try:
                        rows.append([float(x) for x in fields])
                    except ValueError:
                        # Skip rows that can't be converted (e.g., header rows)
                        continue
                    if len(rows) >= nrows:
                        # Mirror the pandas path's nrows limit instead of
                        # loading the whole file into memory.
                        break

            if not rows:
                raise ValueError("No valid numeric data found in CSV file")

            data = np.array(rows, dtype='float32')
        print(f"Loaded {data.shape[0]} vectors of dimension {data.shape[1]}")

        # Check if we have enough data
        if data.shape[0] < nrows:
            raise ValueError(f"Not enough data points. Need at least {nrows} vectors, got {data.shape[0]}")

        # Set dataset properties
        self.d = data.shape[1]
        self.metric = 'L2'
        self.nt = 0  # No training data needed
        self.nb = nb
        self.nq = nq

        # Split the data just like SyntheticDataset does
        self.xb = data[:nb]  # First nb vectors for database
        self.xq = data[nb:nrows]  # Next nq vectors for queries

        print(f"Dataset split: {nb} database, {nq} query vectors")

    def get_queries(self):
        """Return the query vectors (float32 rows ``nb .. nb + nq``)."""
        return self.xq

    def get_train(self, maxtrain=None):
        """Return an empty ``(0, d)`` float32 array: there is no training set.

        *maxtrain* is accepted for interface compatibility and ignored.
        """
        return np.empty((0, self.d), dtype='float32')

    def get_database(self):
        """Return the database vectors (the first ``nb`` float32 rows)."""
        return self.xb

    def get_groundtruth(self, k=100):
        """Return exact k-NN labels of the queries, computed by brute force."""
        from faiss.contrib.exhaustive_search import knn
        metric = faiss.METRIC_L2 if self.metric == 'L2' else faiss.METRIC_INNER_PRODUCT
        # knn returns (distances, labels); only the labels are the ground truth.
        return knn(self.xq, self.xb, k, metric)[1]


def is_perf_counter(key: str) -> bool:
    """Tell whether *key* names a timing counter, i.e. ends in ``_time_us``."""
    suffix = "_time_us"
    return key[-len(suffix):] == suffix


def accumulate_perf_counter(
    phase: str,
    t: PerfCounters,
    counters: Dict[str, int]
):
    """Store the timings of one benchmark phase in *counters*.

    Writes ``{phase}_wall_time_us``, ``{phase}_user_time_us`` and
    ``{phase}_system_time_us`` as integer microseconds (truncated toward
    zero). Despite the name, existing entries for the same phase are
    overwritten, not summed.
    """
    counters[f"{phase}_wall_time_us"] = int(t.wall_time_s * US_IN_S)
    counters[f"{phase}_user_time_us"] = int(t.user_time_s * US_IN_S)
    # Kernel-mode CPU time is measured by timed_execution as well; report it
    # next to user time instead of discarding it.
    counters[f"{phase}_system_time_us"] = int(t.system_time_s * US_IN_S)


def run_on_dataset(
    ds: Dataset,
    M: int,
    num_threads: int,
    num_add_iterations: int,
    num_search_iterations: int,
    efSearch: int = 16,
    efConstruction: int = 40,
    search_bounded_queue: bool = True,
    index_type: str = "panorama",
    only_query_idx: Optional[int] = None,
    levels: int = 16,
    epsilon: float = 1.0,
    return_index: bool = False,
) -> Tuple[Dict[str, int], Optional[object]]:
    """Build an HNSW index over *ds* and time its add and search phases.

    Args:
        ds: Dataset supplying database and query vectors.
        M: HNSW ``M`` parameter, passed to the index constructor.
        num_threads: Thread count handed to ``faiss.omp_set_num_threads``.
        num_add_iterations: How many times the full database is added. Each
            iteration appends all ``nb`` vectors again, so the index ends up
            holding ``nb * num_add_iterations`` vectors, although
            ``counters["nb"]`` reports ``nb``.
        num_search_iterations: How many times the query batch is searched.
        efSearch: Value assigned to ``index.hnsw.efSearch`` before searching.
        efConstruction: Value assigned to ``index.hnsw.efConstruction``
            before adding.
        search_bounded_queue: Value for ``index.hnsw.search_bounded_queue``.
        index_type: ``"panorama"`` (IndexHNSWPanorama) or ``"flat"``
            (IndexHNSWFlat).
        only_query_idx: If set, only this single query is searched.
        levels: Forwarded to the IndexHNSWPanorama constructor; unused for
            ``"flat"``.
        epsilon: Forwarded to the IndexHNSWPanorama constructor; unused for
            ``"flat"``.
        return_index: Also return the built index.

    Returns:
        ``(counters, index)``. ``index`` is ``None`` unless *return_index* is
        true. ``counters`` holds the ``add`` / ``search`` timing counters
        written by accumulate_perf_counter plus the run configuration; note
        that ``counters["index_type"]`` is a string.

    Raises:
        ValueError: for an unknown *index_type*, an out-of-range
            *only_query_idx*, or queries whose dimension differs from the
            database's.
    """
    if index_type not in ("panorama", "flat"):
        raise ValueError(
            f"Unknown index_type {index_type!r}; expected 'panorama' or 'flat'"
        )

    xq = ds.get_queries()
    xb = ds.get_database()
    nb, d = xb.shape
    nq, query_d = xq.shape
    if query_d != d:
        raise ValueError(
            f"Query dimension {query_d} does not match database dimension {d}"
        )

    # If only_query_idx is specified, filter to just that query
    if only_query_idx is not None:
        if only_query_idx < 0 or only_query_idx >= nq:
            raise ValueError(f"Query index {only_query_idx} is out of range [0, {nq-1}]")
        xq = xq[only_query_idx:only_query_idx+1]
        nq = 1
        print(f"Running only query {only_query_idx}")

    k = 10
    # pyre-ignore[16]: Module `faiss` has no attribute `omp_set_num_threads`.
    faiss.omp_set_num_threads(num_threads)

    if index_type == "panorama":
        index = faiss.IndexHNSWPanorama(d, M, levels, epsilon)
    else:
        index = faiss.IndexHNSWFlat(d, M)

    # Must be set before add(): efConstruction is used while building the graph.
    index.hnsw.efConstruction = efConstruction
    with timed_execution() as t:
        for _ in range(num_add_iterations):
            index.add(xb)
    counters = {}
    accumulate_perf_counter("add", t, counters)
    counters["nb"] = nb
    counters["num_add_iterations"] = num_add_iterations

    index.hnsw.efSearch = efSearch
    index.hnsw.search_bounded_queue = search_bounded_queue
    with timed_execution() as t:
        for _ in range(num_search_iterations):
            index.search(xq, k)  # results are discarded; only timing matters
    accumulate_perf_counter("search", t, counters)
    counters["nq"] = nq
    counters["efSearch"] = efSearch
    counters["efConstruction"] = efConstruction
    counters["M"] = M
    counters["d"] = d
    counters["num_search_iterations"] = num_search_iterations
    counters["index_type"] = index_type

    return counters, (index if return_index else None)


def compare_indices(
    ds: Dataset,
    M: int,
    num_threads: int,
    num_add_iterations: int,
    num_search_iterations: int,
    efSearch: int = 16,
    efConstruction: int = 40,
    search_bounded_queue: bool = True,
    only_query_idx: Optional[int] = None,
    levels: int = 16,
    epsilon: float = 1.0,
) -> None:
    """Run both IndexHNSWPanorama and IndexHNSWFlat and compare results.

    Each index is built and timed once with identical settings via
    ``run_on_dataset``. The function then prints add and search wall times
    and the search speedup of Panorama over Flat (Flat time divided by
    Panorama time). Finally it compares the two indices' search results
    with ``compare_search_results_with_indices``.
    """
    # Keyword arguments keep both runs in sync and immune to parameter reordering.
    shared = dict(
        ds=ds,
        M=M,
        num_threads=num_threads,
        num_add_iterations=num_add_iterations,
        num_search_iterations=num_search_iterations,
        efSearch=efSearch,
        efConstruction=efConstruction,
        search_bounded_queue=search_bounded_queue,
        only_query_idx=only_query_idx,
        levels=levels,
        epsilon=epsilon,
        return_index=True,
    )

    print("\nRunning IndexHNSWFlat...")
    flat_counters, flat_index = run_on_dataset(index_type="flat", **shared)

    print("\nRunning IndexHNSWPanorama...")
    panorama_counters, panorama_index = run_on_dataset(index_type="panorama", **shared)

    print("\nPerformance Comparison:")
    print(f"Add time - Panorama: {panorama_counters['add_wall_time_us']/1000:.1f}ms, Flat: {flat_counters['add_wall_time_us']/1000:.1f}ms")
    print(f"Search time - Panorama: {panorama_counters['search_wall_time_us']/1000:.1f}ms, Flat: {flat_counters['search_wall_time_us']/1000:.1f}ms")
    panorama_search_us = panorama_counters['search_wall_time_us']
    if panorama_search_us > 0:
        print(f"Speedup - Panorama: {flat_counters['search_wall_time_us']/panorama_search_us:.3f}x")
    else:
        # Counters are truncated to whole microseconds, so a very short (or
        # zero-iteration) search reports 0 and the ratio is undefined.
        print("Speedup - Panorama: n/a (Panorama search time rounded to 0us)")

    # Compare search results using the already built indices
    print("\nComparing search results...")
    compare_search_results_with_indices(ds, flat_index, panorama_index, only_query_idx)


def compare_search_results_with_indices(
    ds: Dataset,
    flat_index: object,
    panorama_index: object,
    only_query_idx: Optional[int] = None,
) -> None:
    """Compare search results between pre-built IndexHNSWFlat and IndexHNSWPanorama.

    Both indices are searched with k=10 and scored against an exact
    ``IndexFlatL2`` built over the dataset's database vectors. The mean
    recall@k of each is printed.

    Args:
        ds: Dataset providing the query and database vectors.
        flat_index: Built baseline index (searched as "Flat").
        panorama_index: Built candidate index (searched as "Panorama").
        only_query_idx: If set, only this query is evaluated.

    Raises:
        ValueError: if *only_query_idx* is outside ``[0, nq)``.
    """
    xq = ds.get_queries()
    xb = ds.get_database()

    nq, d = xq.shape
    k = 10

    # If only_query_idx is specified, filter to just that query
    if only_query_idx is not None:
        if only_query_idx < 0 or only_query_idx >= nq:
            raise ValueError(f"Query index {only_query_idx} is out of range [0, {nq-1}]")
        xq = xq[only_query_idx:only_query_idx+1]
        nq = 1
        print(f"Comparing only query {only_query_idx}")

    # Create real (exact) index for ground truth
    real_index = faiss.IndexFlatL2(d)
    real_index.add(xb)

    # Perform search on all indices
    real_distances, real_labels = real_index.search(xq, k)
    flat_distances, flat_labels = flat_index.search(xq, k)
    panorama_distances, panorama_labels = panorama_index.search(xq, k)

    # Compare results
    print(f"Comparing {nq} queries with k={k}")

    # Check if shapes match
    if panorama_distances.shape != real_distances.shape:
        print(f"❌ Shape mismatch: Panorama {panorama_distances.shape} vs Real {real_distances.shape}")
        return

    if real_labels.shape != panorama_labels.shape:
        print(f"❌ Label shape mismatch: Real {real_labels.shape} vs Panorama {panorama_labels.shape}")
        return

    # The Flat results are scored with the same per-position indexing, so
    # they need the same validation.
    if flat_distances.shape != real_distances.shape or flat_labels.shape != real_labels.shape:
        print(f"❌ Shape mismatch: Flat {flat_distances.shape} vs Real {real_distances.shape}")
        return

    def compute_metrics(ref_distances, ref_labels, test_distances, test_labels, k):
        """Score one k-NN result against the reference result.

        Returns percentages (exact position matches, label overlap, distance
        matches, mean recall, mean relative distance error) plus the raw
        per-query recall and per-entry relative error lists.
        """
        nq = ref_distances.shape[0]
        exact_matches = 0
        label_overlaps = 0
        distance_matches = 0
        total_comparisons = nq * k
        recall_at_k = []
        relative_distance_error = []
        for i in range(nq):
            # Faiss fills slots it has no result for with label -1. That is
            # not a neighbour, so it must not count as an overlap when both
            # lists are padded.
            ref_label_set = set(ref_labels[i].tolist()) - {-1}
            test_label_set = set(test_labels[i].tolist()) - {-1}
            overlap = len(ref_label_set & test_label_set)
            label_overlaps += overlap
            if ref_label_set:
                # Normalise by the true neighbours that exist: k, unless the
                # database holds fewer than k vectors.
                recall_at_k.append(overlap / len(ref_label_set))
            for j in range(k):
                if (abs(ref_distances[i, j] - test_distances[i, j]) < 1e-6 and
                    ref_labels[i, j] == test_labels[i, j]):
                    exact_matches += 1
                min_distance_diff = float('inf')
                for p in range(k):
                    diff = abs(ref_distances[i, j] - test_distances[i, p])
                    if diff < min_distance_diff:
                        min_distance_diff = diff
                if min_distance_diff < 1e-6:
                    distance_matches += 1
                if ref_distances[i, j] != 0:
                    rel_error = min_distance_diff / ref_distances[i, j]
                    relative_distance_error.append(rel_error)
        exact_match_rate = exact_matches / total_comparisons * 100
        label_overlap_rate = label_overlaps / total_comparisons * 100
        distance_match_rate = distance_matches / total_comparisons * 100
        # No scorable query (e.g. an empty database) leaves recall undefined.
        mean_recall = np.mean(recall_at_k) * 100 if recall_at_k else float("nan")
        mean_rel_error = np.mean(relative_distance_error) * 100 if relative_distance_error else 0
        return dict(
            exact_match_rate=exact_match_rate,
            label_overlap_rate=label_overlap_rate,
            distance_match_rate=distance_match_rate,
            mean_recall=mean_recall,
            mean_rel_error=mean_rel_error,
            recall_at_k=recall_at_k,
            relative_distance_error=relative_distance_error,
        )

    print("\nFlat vs Real:")
    flat_metrics = compute_metrics(real_distances, real_labels, flat_distances, flat_labels, k)
    print(f"Mean recall@{k}: {flat_metrics['mean_recall']:.2f}%")

    print("\nPanorama vs Real:")
    panorama_metrics = compute_metrics(real_distances, real_labels, panorama_distances, panorama_labels, k)
    print(f"Mean recall@{k}: {panorama_metrics['mean_recall']:.2f}%")


def compare_search_results(
    ds: Dataset,
    M: int,
    efSearch: int = 16,
    efConstruction: int = 40,
    search_bounded_queue: bool = True,
    only_query_idx: Optional[int] = None,
    levels: int = 16,
    epsilon: float = 1.0,
) -> None:
    """Compare search results between IndexHNSWFlat and IndexHNSWPanorama.

    Both indices are built over the dataset's database vectors with the same
    HNSW settings. The comparison itself (recall against exact L2 search)
    is delegated to ``compare_search_results_with_indices``, so the two entry
    points cannot drift apart.

    Raises:
        ValueError: if *only_query_idx* is outside ``[0, nq)``.
    """
    xb = ds.get_database()
    nq, d = ds.get_queries().shape

    # Validate up front so a bad index fails before two index builds.
    if only_query_idx is not None and not 0 <= only_query_idx < nq:
        raise ValueError(f"Query index {only_query_idx} is out of range [0, {nq-1}]")

    flat_index = faiss.IndexHNSWFlat(d, M)
    panorama_index = faiss.IndexHNSWPanorama(d, M, levels, epsilon)
    for index in (flat_index, panorama_index):
        # Parameters are set before add() so efConstruction applies to the build.
        index.hnsw.efConstruction = efConstruction
        index.hnsw.efSearch = efSearch
        index.hnsw.search_bounded_queue = search_bounded_queue
        index.add(xb)

    compare_search_results_with_indices(ds, flat_index, panorama_index, only_query_idx)


def run(
    d: int,
    nb: int,
    nq: int,
    M: int,
    num_threads: int,
    num_add_iterations: int = 1,
    num_search_iterations: int = 1,
    efSearch: int = 16,
    efConstruction: int = 40,
    search_bounded_queue: bool = True,
    only_query_idx: Optional[int] = None,
    levels: int = 16,
    epsilon: float = 1.0,
) -> Dict[str, int]:
    """Benchmark the default index type of ``run_on_dataset`` on synthetic data.

    A ``SyntheticDataset`` of dimension *d* with *nb* database and *nq* query
    vectors is generated (no training set, L2 metric, seed 1338). Every other
    argument is forwarded unchanged. Only the counters dict is returned; the
    built index is discarded.
    """
    synthetic = SyntheticDataset(d=d, nb=nb, nt=0, nq=nq, metric="L2", seed=1338)
    options = dict(
        M=M,
        num_add_iterations=num_add_iterations,
        num_search_iterations=num_search_iterations,
        num_threads=num_threads,
        efSearch=efSearch,
        efConstruction=efConstruction,
        search_bounded_queue=search_bounded_queue,
        only_query_idx=only_query_idx,
        levels=levels,
        epsilon=epsilon,
    )
    return run_on_dataset(synthetic, **options)[0]


def _accumulate_counters(
    element: Dict[str, int], accu: Optional[Dict[str, List[int]]] = None
) -> Dict[str, List[int]]:
    if accu is None:
        accu = {key: [value] for key, value in element.items()}
        return accu
    else:
        assert accu.keys() <= element.keys(), (
            "Accu keys must be a subset of element keys: "
            f"{accu.keys()} not a subset of {element.keys()}"
        )
        for key in accu.keys():
            accu[key].append(element[key])
        return accu


def main():
    """Parse command-line flags and run the HNSW benchmark.

    With ``--compare``, IndexHNSWFlat and IndexHNSWPanorama are built and
    compared once. Otherwise the ``run_on_dataset`` default index
    (IndexHNSWPanorama) is timed ``--num-repetitions`` times after an
    optional warm-up. The mean and standard deviation of every
    ``*_time_us`` counter are then printed.
    """
    parser = argparse.ArgumentParser(description="Benchmark HNSW")
    parser.add_argument("--M", type=int, default=32)
    parser.add_argument("--num-threads", type=int, default=5)
    parser.add_argument("--warm-up-iterations", type=int, default=0)
    parser.add_argument("--num-search-iterations", type=int, default=1)
    parser.add_argument("--num-add-iterations", type=int, default=1)
    parser.add_argument("--num-repetitions", type=int, default=1)
    parser.add_argument("--ef-search", type=int, default=16)
    parser.add_argument("--ef-construction", type=int, default=40)
    parser.add_argument("--search-bounded-queue", action="store_true")
    parser.add_argument("--levels", type=int, default=16)
    parser.add_argument("--epsilon", type=float, default=1.0)

    # Synthetic dataset arguments (used when --csv-file is not provided)
    parser.add_argument("--nb", type=int, default=5000)
    parser.add_argument("--nq", type=int, default=500)
    parser.add_argument("--d", type=int, default=128)

    # CSV dataset arguments
    parser.add_argument("--csv-file", type=str, help="Path to CSV file to use as dataset")
    parser.add_argument("--nc", type=int, help="Number of columns to read from CSV file (default: all columns)")
    parser.add_argument("--compare", action="store_true", help="Compare IndexHNSWPanorama vs IndexHNSWFlat")
    parser.add_argument("--only", type=int, help="If specified, only run the query for this index (disregards nq)")

    args = parser.parse_args()
    if args.num_repetitions < 1:
        # Zero repetitions would leave nothing to report.
        parser.error("--num-repetitions must be at least 1")

    # Create dataset based on input arguments
    if args.csv_file:
        ds = CSVDataset(args.csv_file, nb=args.nb, nq=args.nq, nc=args.nc)
    else:
        ds = SyntheticDataset(d=args.d, nb=args.nb, nt=0, nq=args.nq, metric="L2", seed=1338)

    # Run comparison if requested
    if args.compare:
        compare_indices(
            ds=ds,
            M=args.M,
            num_threads=args.num_threads,
            num_add_iterations=args.num_add_iterations,
            num_search_iterations=args.num_search_iterations,
            efSearch=args.ef_search,
            efConstruction=args.ef_construction,
            search_bounded_queue=args.search_bounded_queue,
            only_query_idx=args.only,
            levels=args.levels,
            epsilon=args.epsilon,
        )
        return

    def benchmark(num_add_iterations: int, num_search_iterations: int) -> Dict[str, int]:
        """Run one timed build + search on ``ds`` with the command-line settings."""
        # One code path for CSV and synthetic data: run() would only rebuild a
        # SyntheticDataset with the same arguments (seed=1338) as ``ds``.
        # Forwarding every flag here also keeps --levels in effect for CSV runs.
        counters, _ = run_on_dataset(
            ds,
            M=args.M,
            num_threads=args.num_threads,
            num_add_iterations=num_add_iterations,
            num_search_iterations=num_search_iterations,
            efSearch=args.ef_search,
            efConstruction=args.ef_construction,
            search_bounded_queue=args.search_bounded_queue,
            only_query_idx=args.only,
            levels=args.levels,
            epsilon=args.epsilon,
        )
        return counters

    if args.warm_up_iterations > 0:
        print(f"Warming up for {args.warm_up_iterations} iterations...")
        benchmark(args.warm_up_iterations, args.warm_up_iterations)

    print(
        f"Running benchmark with dataset(nb={ds.nb}, nq={ds.nq}, "
        f"d={ds.d}), M={args.M}, num_threads={args.num_threads}, "
        f"efSearch={args.ef_search}, efConstruction={args.ef_construction}"
    )

    result = None
    for _ in range(args.num_repetitions):
        counters = benchmark(args.num_add_iterations, args.num_search_iterations)
        result = _accumulate_counters(counters, result)

    assert result is not None
    for counter, values in result.items():
        if is_perf_counter(counter):
            print(
                "%s t=%.3f us (± %.4f)" %
                (counter, np.mean(values), np.std(values))
            )


# Run the command-line benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()