#!/usr/bin/env python3
"""
MRPT Example: QPS and Recall Measurement

This example demonstrates MRPT usage with random data and measures:
- QPS (Queries Per Second) for approximate nearest neighbor search
- Recall by comparing approximate results with exact search

The script generates random data, builds an autotuned MRPT index, 
runs queries, and reports performance metrics.
"""

import time
import numpy as np
import mrpt

def calculate_recall(exact_neighbors, approx_neighbors):
    """
    Calculate recall of approximate neighbors against ground truth.

    Recall for one query is the fraction of its distinct ground-truth
    neighbors that also appear among the approximate neighbors. For several
    queries the per-query recalls are averaged.

    Args:
        exact_neighbors: Ground truth neighbors, either a single query
            (1-D, length k) or several queries (n_queries x k). Any
            array-like accepted by ``np.asarray`` works.
        approx_neighbors: Approximate neighbors with the same layout; in the
            multi-query case row ``i`` is matched with ground-truth row ``i``.

    Returns:
        recall: Average recall across all queries as a float in [0, 1].
            Returns 0.0 when there are no queries, and counts a query with
            no ground-truth neighbors as 0.0, instead of dividing by zero.
    """
    exact_neighbors = np.asarray(exact_neighbors)

    def _query_recall(exact_row, approx_row):
        # Recall of a single query; an empty ground truth contributes 0.0.
        exact_set = set(exact_row)
        if not exact_set:
            return 0.0
        return len(exact_set.intersection(approx_row)) / len(exact_set)

    if exact_neighbors.ndim == 1:
        # Single query case
        return _query_recall(exact_neighbors, approx_neighbors)

    # Multiple queries case
    n_queries = exact_neighbors.shape[0]
    if n_queries == 0:
        return 0.0

    total_recall = sum(
        (
            _query_recall(exact_row, approx_neighbors[i])
            for i, exact_row in enumerate(exact_neighbors)
        ),
        0.0,
    )
    return total_recall / n_queries

def main():
    """
    Run the MRPT benchmark end to end and print a report.

    Steps: generate random float32 data and queries from a fixed NumPy seed,
    build an autotuned MRPT index, time approximate and exact search over all
    queries, compute recall of the approximate results against the exact
    ones, and print a summary followed by the index's own timing output.

    All intervals are measured with ``time.perf_counter()``, a monotonic
    high-resolution clock, so the reported QPS is not affected by system
    clock adjustments.
    """
    print("MRPT Example: QPS and Recall Measurement")
    print("=" * 50)

    # Dataset parameters
    n_samples = 10000  # Number of data points
    dim = 128          # Dimension of each data point
    k = 10             # Number of nearest neighbors to find
    n_queries = 1000   # Number of test queries
    target_recall = 0.95  # Target recall for autotuning

    print(f"Dataset: {n_samples} points, {dim} dimensions")
    print(f"Queries: {n_queries} test queries, k={k}")
    print(f"Target recall: {target_recall}")
    print()

    # Generate random data (seeded so the dataset is the same on every run)
    print("Generating random data...")
    np.random.seed(42)
    data = np.random.rand(n_samples, dim).astype(np.float32)
    queries = np.random.rand(n_queries, dim).astype(np.float32)

    print("Building MRPT index...")
    start_time = time.perf_counter()

    # Create MRPT index
    # NOTE(review): the positional arguments (4, 1, True) are passed straight
    # to the MRPTIndex constructor; their meaning depends on the MRPT build in
    # use -- confirm against that build's documentation.
    index = mrpt.MRPTIndex(data, 4, 1, True)

    # Build autotuned index for the requested recall and k.
    # NOTE(review): no seed is passed here even though an earlier comment
    # mentioned one; confirm whether this MRPT build exposes a seed parameter
    # if reproducible index construction is required.
    index.build_autotune_sample(target_recall, k)

    build_time = time.perf_counter() - start_time
    print(f"Index built in {build_time:.2f} seconds")

    # Get index parameters
    params = index.parameters()
    print(f"Index parameters: {params}")
    print()

    # Measure approximate search QPS (one query at a time)
    print("Measuring approximate search performance...")
    start_time = time.perf_counter()
    approx_results = [index.ann(query, return_distances=False) for query in queries]
    approx_time = time.perf_counter() - start_time
    qps = n_queries / approx_time

    print(f"Approximate search: {1000 * approx_time:.4f} milliseconds for {n_queries} queries")
    print(f"QPS (Queries Per Second): {qps:.2f}")
    print()

    # Measure exact search time for comparison; its results double as the
    # ground truth for the recall computation below.
    print("Running exact search for recall calculation...")
    start_time = time.perf_counter()
    exact_results = [
        index.exact_search(query, k, return_distances=False) for query in queries
    ]
    exact_time = time.perf_counter() - start_time
    exact_qps = n_queries / exact_time

    print(f"Exact search: {1000 * exact_time:.4f} milliseconds for {n_queries} queries")
    print(f"Exact search QPS: {exact_qps:.2f}")
    print()

    # Calculate recall
    print("Calculating recall...")
    approx_results = np.array(approx_results)
    exact_results = np.array(exact_results)

    recall = calculate_recall(exact_results, approx_results)

    # Print results summary
    print("=" * 50)
    print("RESULTS SUMMARY")
    print("=" * 50)
    print(f"Dataset size: {n_samples:,} points × {dim} dimensions")
    print(f"Number of queries: {n_queries:,}")
    print(f"k (neighbors): {k}")
    print(f"Target recall: {target_recall:.1%}")
    print()
    print(f"Index build time: {build_time:.2f} seconds")
    print("Index parameters:")
    for key, value in params.items():
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")
    print()
    print("PERFORMANCE METRICS:")
    print(f"  QPS (Approximate): {qps:.2f}")
    print(f"  QPS (Exact):       {exact_qps:.2f}")
    print(f"  Speedup:           {qps/exact_qps:.1f}x")
    print(f"  Actual Recall:     {recall:.1%}")
    print(f"  Recall vs Target:  {recall/target_recall:.1%}")
    print()

    # The run counts as a success when recall reaches at least 95% of the
    # target (i.e. a 5% relative tolerance), not only when it meets it exactly.
    if recall >= target_recall * 0.95:
        print("✅ SUCCESS: Achieved target recall!")
    else:
        print("⚠️  WARNING: Recall below target")

    print(f"💡 MRPT achieved {qps:.0f} QPS with {recall:.1%} recall")

    # Print the index's own detailed timing statistics.
    # NOTE(review): print_times() is called unconditionally; an earlier
    # comment suggested it is only meaningful in "panorama mode" -- confirm.
    print("\n" + "="*50)
    print("DETAILED TIMING STATISTICS")
    print("="*50)
    index.print_times()

# Run the benchmark only when this file is executed as a script, so importing
# the module (e.g. to reuse calculate_recall) has no side effects.
if __name__ == "__main__":
    main()
