#!/usr/bin/env python3
"""
Comprehensive benchmark runner for all train fvec files.
This script:
1. Reads train fvec files directly
2. Runs benchmarks for all index types with varying nlevels
3. Collects results and saves them to a CSV file
"""

import os
import subprocess
import re
import pandas as pd
import time
from pathlib import Path
import argparse
from tqdm import tqdm
from datetime import datetime
import itertools
import time
import uuid


# Index types to benchmark. Types whose name contains 'pano' are additionally
# given the transformed train/test files and an nlevels sweep (see
# run_benchmark and main). The commented-out assignments are alternative
# subsets kept for quick manual switching.
# INDEX_TYPES = ['ivfpq', 'ivfpq_pano', 'ivf_flat', 'ivf_flat_pano', 'naive_pano', 'naive', 'annoy', 'annoy_pano', 'hnsw', 'hnsw_pano']
INDEX_TYPES = ['ivfpq', 'ivfpq_pano', 'ivf_flat', 'ivf_flat_pano', 'naive_pano', 'naive', 'annoy', 'annoy_pano', 'hnsw', 'hnsw_pano', 'mrpt', 'mrpt_pano']
# INDEX_TYPES = ['ivfpq', 'ivfpq_pano', 'ivf_flat', 'ivf_flat_pano', 'naive_pano', 'naive', 'annoy', 'annoy_pano']
# INDEX_TYPES = ['ivfpq', 'ivfpq_pano']
# INDEX_TYPES = ['mrpt', 'mrpt_pano']
# INDEX_TYPES = ['naive_pano', 'naive']
# INDEX_TYPES = ['ivf_flat', 'ivf_flat_pano', 'naive_pano', 'naive']
# INDEX_TYPES = ['annoy', 'annoy_pano', 'hnsw', 'hnsw_pano']


# IVF_FLAT & IVFPQ PARAMS
# Per-dataset values passed as --nlist; keys are the dataset name, i.e. the
# first '_'-separated token of the train file name.
N_LIST = {
    "cifar10": [10], # 3072 dims
    "deep": [20],      # 96 dims
    "fashionmnist": [256], # 784
    "glove2m300": [512],
    "nytimes": [1024],
    "sift100m": [2048],
    "gist1m": [128], # NOTE(review): was annotated "784 dims"; M_VALUES (320, 960) suggest 960 - confirm
}

# Per-dataset values passed as --nprobe (used for index types containing 'ivf').
N_PROBE = {
    "cifar10": [1, 5], # 3072 dims
    "deep": [1, 5],      # 96 dims
    "fashionmnist": [1, 10, 50], # 784
    "glove2m300": [1, 10, 200],
    "nytimes": [1, 10, 200],
    "sift100m": [1, 10, 100],
    "gist1m": [1, 5, 50], # NOTE(review): was annotated "784 dims"; M_VALUES (320, 960) suggest 960 - confirm
}

# IVFPQ PARAMS
# Per-dataset values passed as --M (used only for index types containing 'ivfpq').
M_VALUES = {
    "cifar10": [1024, 1536], # 3072 dims
    "deep": [48, 96],      # 96 dims
    "fashionmnist": [392, 784], # 784
    "glove2m300": [150, 300],
    "nytimes": [128, 256],
    "sift100m": [64, 128],
    "gist1m": [320, 960], # NOTE(review): was annotated "784 dims"; these values suggest 960 dims - confirm
}

# Per-dataset values passed as --nlevels (swept only for index types containing 'pano').
N_LEVELS = {
    "cifar10": [1, 4, 8, 16],
    "deep": [1, 4, 8],
    "fashionmnist": [1, 4, 8, 16],
    "glove2m300": [1, 5],
    "nytimes": [1, 4, 8, 16],
    "sift100m": [1, 4, 8],
    "gist1m": [1, 4, 8, 16],
}

# HNSW PARAMS (passed as --M-HNSW and --efSearch)
M_HNSW = [8, 16]
EF_SEARCH = [32, 64, 128, 256, 512, 1024]

# ANNOY PARAMS (passed as --n-trees and --search-k)
N_TREES = [100, 200]
# NOTE: the name is misspelled ("SEARCK"); it is kept because main() references it.
SEARCK_K = [1000, 5000]

# MRPT PARAMS (passed as --recall-target)
RECALL_TARGET = [0.9, 0.95, 0.98]

# GENERAL PARAMS (passed as --k)
K_VALUES = [10]



def get_fvec_info(fvec_path):
    """Return the size of an fvec file in bytes.

    The size is read from the filesystem directly instead of spawning
    ``wc -c``, so it works on platforms without coreutils and avoids a
    subprocess per call.

    Args:
        fvec_path: Path to the file.

    Returns:
        The file size in bytes, or 0 if the size cannot be determined
        (missing file, permission error, invalid path). The failure is
        reported on stdout.
    """
    try:
        return os.path.getsize(fvec_path)
    except (OSError, TypeError, ValueError) as e:
        print(f"Error getting info for {fvec_path}: {e}")
        return 0

def get_fvec_dimension_and_count(fvec_path):
    """Read the leading dimension field of an fvec file and derive the vector count.

    The first 4 bytes are decoded as a little-endian unsigned integer and
    taken as the dimension. The count is the file size divided (flooring)
    by the per-vector record size of ``4 + 4 * dimension`` bytes; only the
    first record's header is inspected.

    Returns:
        ``(dimension, vector_count)``, or ``(0, 0)`` when the file is shorter
        than 4 bytes or any error occurs while reading it.
    """
    try:
        with open(fvec_path, 'rb') as handle:
            header = handle.read(4)
            if len(header) != 4:
                return 0, 0

            dimension = int.from_bytes(header, byteorder='little')

            # Jump to the end of the file to learn its total size.
            handle.seek(0, 2)
            total_bytes = handle.tell()

            # One record = 4-byte dimension field + `dimension` 4-byte floats.
            record_bytes = 4 + dimension * 4
            return dimension, total_bytes // record_bytes
    except Exception as e:
        print(f"Error reading fvec info from {fvec_path}: {e}")
        return 0, 0



def get_valid_nlevels(dimension):
    """Return every n in 1..12 (inclusive, ascending) that divides ``dimension`` evenly."""
    return [n for n in range(1, 13) if dimension % n == 0]

def find_original_files(fvecs_filename, original_dir):
    """
    Find original base and query files for a given transformed fvecs file.

    The dataset name is the first '_'-separated token of the file's basename.
    Candidates are files in ``original_dir`` named ``{dataset}_base*.fvec`` and
    ``{dataset}_query*.fvec`` (singular ``.fvec`` only). The directory listing
    is sorted and the first match for each role is used, so the choice is
    deterministic when several files match.

    Args:
        fvecs_filename: e.g., "sift_train_panorama.fvecs"
        original_dir: Directory containing original fvec files

    Returns:
        Tuple of (base_file_path, query_file_path) or (None, None) if the
        dataset name cannot be parsed, the directory cannot be listed, or
        either file is missing.
    """
    basename = os.path.basename(fvecs_filename)
    parts = basename.split('_')
    if len(parts) < 2:
        print(f"Warning: Cannot parse dataset name from {basename}")
        return None, None

    dataset_name = parts[0]  # e.g., "sift"
    base_prefix = f"{dataset_name}_base"
    query_prefix = f"{dataset_name}_query"

    try:
        # os.listdir order is arbitrary; sort so repeated runs pick the same files.
        candidates = sorted(os.listdir(original_dir))
    except OSError as e:
        print(f"Error accessing original directory {original_dir}: {e}")
        return None, None

    base_file = None
    query_file = None
    for name in candidates:
        if not name.endswith('.fvec'):
            continue
        if base_file is None and name.startswith(base_prefix):
            base_file = os.path.join(original_dir, name)
        elif query_file is None and name.startswith(query_prefix):
            query_file = os.path.join(original_dir, name)

    if base_file and query_file:
        print(f"Found original files for {dataset_name}:")
        print(f"  Base:  {base_file}")
        print(f"  Query: {query_file}")
        return base_file, query_file

    print(f"Warning: Could not find original files for {dataset_name}")
    print(f"  Base file found: {base_file is not None}")
    print(f"  Query file found: {query_file is not None}")
    return None, None

def run_benchmark(train_path, test_path, original_base_path, original_query_path, nb, nq, index_type, nlevels, M, epsilon, efSearch, work_dir, M_hnsw, nlist, nprobe, n_trees, search_k, k, recall_target):
    """Launch faiss/perf_tests/bench_ivf.py once and return its parsed metrics.

    The command is run with ``work_dir`` as the working directory and with
    PYTHONPATH extended by the build directories under ``work_dir``. The
    transformed train/test files are passed only when ``index_type`` contains
    'pano'; ``--cached`` is added for annoy/hnsw index types and
    ``--cache-results`` always.

    Returns:
        The dict produced by parse_benchmark_output on success, or None when
        the process exits non-zero or any exception is raised.
    """
    print(f"Running benchmark for {index_type} with nlevels={nlevels}, M={M}, epsilon={epsilon}")
    print(f"  Transformed Train: {train_path}")
    print(f"  Transformed Test:  {test_path}")
    print(f"  Original Base:     {original_base_path}")
    print(f"  Original Query:    {original_query_path}")

    try:
        # (flag, value) pairs in the exact order they appear on the command line.
        options = [
            ('--nb', str(nb)),
            ('--nq', str(nq)),
            ('--nprobe', str(nprobe)),
            ('--nlist', str(nlist)),
            ('--num-threads', '1'),
            ('--nlevels', str(nlevels)),
            ('--M', str(M)),
            ('--epsilon', str(epsilon)),
            ('--n-trees', str(n_trees)),
            ('--search-k', str(search_k)),
            ('--csv-path', original_base_path),  # Original training data for ground truth
            ('--csv-test-path', original_query_path),  # Original test data for queries
        ]
        if 'pano' in index_type:
            options.append(('--trans-csv-path', train_path))  # Transformed training data for index
            options.append(('--trans-csv-test-path', test_path))  # Transformed test data for queries
        options.extend([
            ('--index-type', index_type),
            ('--efSearch', str(efSearch)),
            ('--M-HNSW', str(M_hnsw)),
            ('--recall-target', str(recall_target)),
            ('--k', str(k)),
        ])

        cmd = ['python3', 'faiss/perf_tests/bench_ivf.py']
        for flag, value in options:
            cmd.extend((flag, value))

        # Let annoy/hnsw runs reuse a cached index across parameter changes.
        if any(tag in index_type for tag in ('annoy', 'hnsw')):
            cmd.append('--cached')
        cmd.append('--cache-results')

        # Echo the command so it can be re-run by hand.
        command_line = " ".join(cmd)
        print("---------------------------------------")
        print("Command to run manually:")
        print(command_line)
        print("---------------------------------------")

        # Put the locally built FAISS and Annoy ahead of anything already on PYTHONPATH.
        env = os.environ.copy()
        local_faiss_path = work_dir + '/faiss/build/faiss/python/build/lib'
        local_annoy_path = work_dir + '/annoy/build/lib.linux-x86_64-cpython-312'
        env['PYTHONPATH'] = f"{local_annoy_path}:{local_faiss_path}:{env.get('PYTHONPATH', '')}"

        result = subprocess.run(cmd, capture_output=True, text=True, cwd=work_dir, env=env)

        if result.returncode != 0:
            print(f"Benchmark failed for {index_type} with nlevels={nlevels}: {result.stderr}")
            return None

        return parse_benchmark_output(result.stdout, index_type, train_path, nb, nq, nlevels, epsilon, efSearch, command_line, M_hnsw, M, nlist, nprobe, n_trees, search_k, k, recall_target)

    except Exception as e:
        print(f"Exception during benchmark for {index_type} with nlevels={nlevels}: {e}")
        return None

# Unsigned decimal number: digits with an optional fraction (or a bare
# fraction), followed by an optional exponent. Unlike the looser `[\d.]+`, this
# reads scientific notation in full and stops before a trailing sentence
# period, so float() never sees a malformed token such as "0.95.".
_NUMBER_RE = r'((?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'

# Timing samples at or above this value are treated as garbage and dropped.
_MAX_PLAUSIBLE_TIME = 1e10


def _find_number(prefix, output, suffix=''):
    """Return the first number following ``prefix`` in ``output`` as a float, or None."""
    match = re.search(prefix + _NUMBER_RE + suffix, output)
    return float(match.group(1)) if match else None


def _collect_times(label, output):
    """Return every plausible '<label> = <number>' value found in ``output``."""
    values = (float(token) for token in re.findall(label + ' = ' + _NUMBER_RE, output))
    return [value for value in values if value < _MAX_PLAUSIBLE_TIME]


def parse_benchmark_output(output, index_type, file_path, nb, nq, nlevels, epsilon, efSearch, command, M_hnsw, M, nlist, nprobe, n_trees, search_k, k, recall_target):
    """Parse the benchmark's stdout into a flat result row.

    Args:
        output: Captured stdout of the benchmark process.
        index_type .. recall_target: Run parameters, copied verbatim into the
            returned row (``file_path``, ``command`` and the numeric settings).

    Returns:
        A dict with the run parameters, a fresh ``experiment_id`` and
        ``timestamp``, and the metrics found in ``output``; a metric whose
        line is absent stays None. Returns None if parsing raises.

    Parsed lines: ``Wall time: <x> ms``, ``QPS: <x>`` (optionally
    ``± <std>``; ``qps_std`` is -1 when no std is printed), ``Queries: <n>``,
    ``Recall: <x>``, ``avg_level: <x>`` (only for 'pano' index types, stored
    as a percentage), and every ``Search time = <x>`` /
    ``Verification time = <x>`` sample. When search-time samples exist, the
    QPS fields are recomputed from them as ``nq / (sample / 1000)``, which
    treats the samples as milliseconds per batch of ``nq`` queries.
    Samples of 0 are excluded from that recomputation to avoid a division by
    zero; if none are positive, the printed QPS is kept.
    """
    try:
        results = {
            'index_type': index_type,
            'file_path': file_path,
            'nb': nb,
            'nq': nq,
            'nlevels': nlevels,
            'epsilon': epsilon,
            'wall_time_ms': None,
            'qps': None,
            'qps_mean': None,
            'qps_std': None,
            'search_mean_ms': None,
            'search_std_ms': None,
            'verification_mean_ms': None,
            'verification_std_ms': None,
            'queries': None,
            'recall': None,
            'avg_level_percent': None,
            'M': M,
            'nlist': nlist,
            'nprobe': nprobe,
            'n_trees': n_trees,
            'search_k': search_k,
            'k': k,
            'M_HNSW': M_hnsw,
            'efSearch': efSearch,
            'recall_target': recall_target,
            'command': command,
            'experiment_id': uuid.uuid4(),
            'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
        }

        results['wall_time_ms'] = _find_number(r'Wall time: ', output, r' ms')

        # QPS: prefer the "mean ± std" form, fall back to a bare value.
        qps_with_std = re.search(r'QPS: ' + _NUMBER_RE + r' ± ' + _NUMBER_RE, output)
        if qps_with_std:
            results['qps_mean'] = float(qps_with_std.group(1))
            results['qps_std'] = float(qps_with_std.group(2))
            results['qps'] = results['qps_mean']  # Keep for backward compatibility
        else:
            qps_only = _find_number(r'QPS: ', output)
            if qps_only is not None:
                results['qps'] = qps_only
                results['qps_mean'] = qps_only
                results['qps_std'] = -1

        queries_match = re.search(r'Queries: (\d+)', output)
        if queries_match:
            results['queries'] = int(queries_match.group(1))

        results['recall'] = _find_number(r'Recall: ', output)

        if 'pano' in index_type:
            avg_level = _find_number(r'avg_level: ', output)
            if avg_level is not None:
                # Reported as a fraction; store as a percentage.
                results['avg_level_percent'] = avg_level * 100

        search_times = _collect_times('Search time', output)
        verification_times = _collect_times('Verification time', output)

        if search_times:
            import numpy as np

            search_times_ms = np.array(search_times)
            results['search_mean_ms'] = float(np.mean(search_times_ms))
            results['search_std_ms'] = float(np.std(search_times_ms))

            # Per-sample QPS; skip zero-duration samples (division by zero).
            positive_ms = search_times_ms[search_times_ms > 0]
            if positive_ms.size:
                qps_values = nq / (positive_ms / 1000.0)
                results['qps_mean'] = float(np.mean(qps_values))
                results['qps_std'] = float(np.std(qps_values))
                # The recomputed value overrides whatever the QPS line said.
                results['qps'] = results['qps_mean']

            # Verification stats are only reported alongside search stats.
            if verification_times:
                verification_times_ms = np.array(verification_times)
                results['verification_mean_ms'] = float(np.mean(verification_times_ms))
                results['verification_std_ms'] = float(np.std(verification_times_ms))

        return results

    except Exception as e:
        print(f"Error parsing benchmark output for {index_type} with nlevels={nlevels}: {e}")
        return None

def write_result_incremental(results_csv: str, result_row: dict) -> None:
    """Append one result row to ``results_csv``.

    The header line is emitted only when the file does not exist yet. Any
    failure is reported as a warning on stdout and otherwise ignored, so a
    bad write never aborts the benchmark run.
    """
    try:
        single_row = pd.DataFrame([result_row])
        needs_header = not os.path.exists(results_csv)
        single_row.to_csv(
            results_csv,
            mode='a',
            header=needs_header,
            index=False,
        )
    except Exception as exc:
        print(f"Warning: failed to append results to {results_csv}: {exc}")



def _parse_epsilon_values(epsilon_list, fallback):
    """Return the epsilon values to sweep: the parsed comma list, else ``[fallback]``."""
    if not epsilon_list:
        return [fallback]
    try:
        values = [float(token.strip()) for token in epsilon_list.split(',') if token.strip() != '']
    except ValueError:
        print(f"Invalid --epsilon-list value: {epsilon_list}. Falling back to --epsilon {fallback}.")
        return [fallback]
    # A list made only of separators (e.g. ",") must not silently disable every run.
    return values or [fallback]


def _discover_train_test_pairs(transformed_dir):
    """Return sorted (train_path, test_path) pairs found in ``transformed_dir``."""
    pairs = []
    # Sorted so the processing order is reproducible across runs and machines.
    for name in sorted(os.listdir(transformed_dir)):
        if 'train' not in name or not name.endswith(('.fvec', '.fvecs')):
            continue
        train_path = os.path.join(transformed_dir, name)
        # The test file name is the train name with '_train_' replaced by '_test_'.
        test_name = name.replace('_train_', '_test_')
        if test_name == name:
            # Without this guard the train file would be paired with itself
            # and used as its own query set.
            print(f"Warning: Train file {train_path} does not contain '_train_'; cannot derive test file name")
            continue
        test_path = os.path.join(transformed_dir, test_name)

        if os.path.exists(test_path):
            pairs.append((train_path, test_path))
            print("Found train/test pair:")
            print(f"  Train: {train_path}")
            print(f"  Test:  {test_path}")
        else:
            print(f"Warning: Train file {train_path} found but corresponding test file {test_path} not found")
    return pairs


def _build_param_grid(dataset_name, index_type):
    """Return the value lists to sweep for one dataset/index type.

    Parameters that do not apply to ``index_type`` get the placeholder
    ``[-1]``. Raises KeyError if a required per-dataset table has no entry
    for ``dataset_name``.
    """
    return {
        'M': M_VALUES[dataset_name] if 'ivfpq' in index_type else [-1],
        'ef_search': EF_SEARCH if 'hnsw' in index_type else [-1],
        'M_hnsw': M_HNSW if 'hnsw' in index_type else [-1],
        'nlist': N_LIST[dataset_name] if 'ivf' in index_type else [-1],
        'nprobe': N_PROBE[dataset_name] if 'ivf' in index_type else [-1],
        'n_trees': N_TREES if 'annoy' in index_type else [-1],
        'search_k': SEARCK_K if 'annoy' in index_type else [-1],
        'recall_target': RECALL_TARGET if 'mrpt' in index_type else [-1],
        'k': K_VALUES,
        'nlevels': N_LEVELS[dataset_name] if 'pano' in index_type else [-1],
    }


def main():
    """Discover train/test pairs, sweep every parameter combination, and record results.

    For each train/test pair in --transformed-dir, the matching original
    base/query files are located in --original-dir, then every index type in
    INDEX_TYPES is benchmarked over its parameter grid and over every epsilon
    value from --epsilon-list (or the single --epsilon). Each successful run
    is appended to --results-csv immediately; a summary is printed at the end.
    """
    # Generate timestamp for default filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    default_results_filename = f"benchmark_M_epsilon_results_{timestamp}.csv"

    parser = argparse.ArgumentParser(description="Run comprehensive benchmarks for all train fvec files with varying nlevels")
    parser.add_argument("--transformed-dir", default="/mnt/device/linear",
                       help="Directory containing transformed fvec files")
    parser.add_argument("--original-dir", default="/mnt/device/datasets",
                       help="Directory containing original fvec files")
    parser.add_argument("--output-dir", default="/mnt/device/transformed/temp",
                       help="Directory for output files (currently unused)")
    parser.add_argument("--results-csv", default=default_results_filename,
                       help="Output CSV file for results")
    parser.add_argument("--nq", type=int, default=100,
                       help="Number of query points for benchmarks")
    parser.add_argument("--epsilon", type=float, default=1.0,
                       help="Pruning aggressiveness (passed through to bench_ivf)")
    parser.add_argument("--epsilon-list", type=str, default=None,
                       help="Comma-separated list of epsilon values to run sequentially (overrides --epsilon)")
    parser.add_argument("--work-dir", default="/home/name/panorama",
                       help="Working directory for running benchmark commands")

    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    train_test_pairs = _discover_train_test_pairs(args.transformed_dir)
    print(f"\nFound {len(train_test_pairs)} train/test pairs:")

    all_results = []

    # Epsilon values actually passed to the benchmark (previously parsed but
    # ignored in favour of a hard-coded 1.0).
    epsilon_values = _parse_epsilon_values(args.epsilon_list, args.epsilon)

    total_benchmarks = 0

    print("IF YOU CHANGED DATASETS, PLEASE DELETE GROUND TRUTH CACHE!!!")
    time.sleep(5)

    # Process each train/test pair
    for train_path, test_path in tqdm(train_test_pairs, desc="Processing train/test pairs"):
        print(f"\n{'='*80}")
        print(f"Processing: {os.path.basename(train_path)} + {os.path.basename(test_path)}")
        print(f"{'='*80}")

        # Find corresponding original files
        original_base_path, original_query_path = find_original_files(train_path, args.original_dir)
        if original_base_path is None or original_query_path is None:
            print(f"Skipping {train_path} due to missing original files")
            continue

        # Get dimension and vector count directly from train fvec file
        dimension, nb = get_fvec_dimension_and_count(train_path)
        if dimension == 0 or nb == 0:
            print(f"Skipping {train_path} due to file reading failure")
            continue

        dataset_name = os.path.basename(train_path).split('_')[0]

        print(f"fvec contains {nb} vectors with dimension {dimension}, using {args.nq} queries")
        for index_type in INDEX_TYPES:

            # The naive index types use a fixed nq=10; all others use args.nq.
            nq_to_use = 10 if index_type in ('naive_pano', 'naive') else args.nq

            try:
                grid = _build_param_grid(dataset_name, index_type)
            except KeyError:
                # Skip instead of aborting the whole multi-dataset run.
                print(f"Skipping {index_type} for {dataset_name}: dataset has no entry in the parameter tables")
                continue

            num_bench = len(epsilon_values)
            for values in grid.values():
                num_bench *= len(values)
            print("Running ", num_bench, " benchmarks for ", dataset_name, index_type, "total benchmarks: ", total_benchmarks)
            total_benchmarks += num_bench

            # annoy/hnsw/mrpt runs are capped at 10M base vectors.
            capped = any(tag in index_type for tag in ('annoy', 'hnsw', 'mrpt'))
            new_nb = min(nb, 10_000_000) if capped else nb

            # iterate over cache-killing params
            for (M_hnsw, n_trees, search_k) in itertools.product(grid['M_hnsw'], grid['n_trees'], grid['search_k']):
                # clear cache
                print("Clearing /tmp of .ann and .index cache files...")
                subprocess.run("rm -f /tmp/*.ann /tmp/*.index", shell=True)

                # iterate over cache-friendly params
                # NOTE(review): epsilon is swept for every index type and is
                # assumed not to invalidate the cached index - confirm.
                for (k, M, ef_search, nlist, nprobe, recall_target, nlevels, epsilon) in itertools.product(
                    grid['k'], grid['M'], grid['ef_search'], grid['nlist'], grid['nprobe'],
                    grid['recall_target'], grid['nlevels'], epsilon_values,
                ):
                    print(f"\n--- Benchmarking {index_type} with nlevels={nlevels}, epsilon={epsilon}, M={M}, ef_search={ef_search}, M_hnsw={M_hnsw}, nlist={nlist}, nprobe={nprobe}, n_trees={n_trees}, search_k={search_k}, k={k} ---")
                    result = run_benchmark(
                        train_path, test_path, original_base_path, original_query_path, new_nb, nq_to_use, index_type, nlevels, M, epsilon, ef_search, args.work_dir, M_hnsw, nlist, nprobe, n_trees, search_k, k, recall_target
                    )
                    if result:
                        all_results.append(result)
                        write_result_incremental(args.results_csv, result)
                        print(f"Results for {index_type} with nlevels={nlevels}, epsilon={epsilon}, M={M}:")
                        for key, value in result.items():
                            if key not in ['file_path']:
                                print(f"  {key}: {value}")
                    else:
                        print(f"Benchmark failed for {index_type} with nlevels={nlevels}, epsilon={epsilon}, M={M}")

    # Save all results to CSV (final write; file already has incremental rows)
    if all_results:
        results_df = pd.DataFrame(all_results)
        # Only written here if every incremental append failed to create the file.
        if not os.path.exists(args.results_csv):
            results_df.to_csv(args.results_csv, index=False)
        print(f"\n{'='*80}")
        print(f"All results saved to: {args.results_csv}")
        print(f"Total benchmark runs: {len(all_results)}")
        print(f"{'='*80}")

        # Print summary
        print("\nSummary of results:")
        print(results_df.groupby(['index_type', 'nlevels']).agg({
            'wall_time_ms': ['mean', 'std'],
            'qps': ['mean', 'std'],
            'recall': ['mean', 'std']
        }).round(2))
    else:
        print("No results to save!")

if __name__ == "__main__":
    main()
