#!/usr/bin/env python3
"""
Enhanced benchmark runner with grid search over M and nprobe parameters.
This script:
1. Takes a dataset name as input
2. Finds all transformed and original files for that dataset
3. Runs grid search over M and nprobe parameters
4. Caps the total number of runs (default 100, configurable via --max-runs)
5. Appends each result to a CSV as it completes and prints the best settings per index type
"""

import os
import subprocess
import re
import pandas as pd
import argparse
from datetime import datetime
import itertools
from typing import List, Tuple

# Index variants swept by the grid search. Variants whose name contains
# "pano" are additionally given the transformed train/test files.
INDEX_TYPES = ["ivfpq", "ivfpq_pano"]

# Held constant across every run; forwarded to bench_ivf.py as --nlevels
# and --nlist respectively.
NLEVELS = 8
NLIST = 10

def get_fvec_dimension_and_count(fvec_path):
    """Return ``(dimension, vector_count)`` for a .fvec/.fvecs file.

    Each record is assumed to be a little-endian int32 dimension followed by
    that many 4-byte floats. Only the first record's header is read; the
    count is derived from the file size, so a trailing partial record is
    ignored.

    Returns:
        ``(dimension, vector_count)`` on success, or ``(0, 0)`` when the file
        cannot be read, is shorter than one header, or its first header is
        not a positive dimension.
    """
    try:
        with open(fvec_path, 'rb') as f:
            dim_bytes = f.read(4)
            if len(dim_bytes) != 4:
                return 0, 0
            # The header is a signed int32 in the fvecs layout.
            dimension = int.from_bytes(dim_bytes, byteorder='little', signed=True)
            if dimension <= 0:
                # Previously a zero dimension returned (0, file_size // 4);
                # report it as the same failure value as other bad inputs.
                print(f"Invalid dimension {dimension} in fvec header of {fvec_path}")
                return 0, 0

            f.seek(0, os.SEEK_END)
            file_size = f.tell()
    except OSError as e:
        print(f"Error reading fvec info from {fvec_path}: {e}")
        return 0, 0

    # Each vector: 4 bytes (dimension) + dimension * 4 bytes (floats)
    bytes_per_vector = 4 + dimension * 4
    return dimension, file_size // bytes_per_vector

def _find_prefixed_fvec_files(directory, prefixes, label):
    """Map each key of *prefixes* to the first matching .fvec/.fvecs file in *directory*.

    Files are scanned in sorted name order so the result is deterministic;
    when several files share a prefix the first one wins (so ``x.fvec`` is
    preferred over ``x.fvecs``) and the others are reported. Keys with no
    match map to None. Returns None if *directory* cannot be listed.
    """
    try:
        names = sorted(os.listdir(directory))
    except OSError as e:
        print(f"Error accessing {label} directory {directory}: {e}")
        return None

    found = dict.fromkeys(prefixes)
    for name in names:
        if not name.endswith(('.fvec', '.fvecs')):
            continue
        for key, prefix in prefixes.items():
            if not name.startswith(prefix):
                continue
            if found[key] is None:
                found[key] = os.path.join(directory, name)
            else:
                print(f"Warning: ignoring extra match {name} for {key}; "
                      f"using {os.path.basename(found[key])}")
            break
    return found


def find_dataset_files(dataset_name, transformed_dir, original_dir):
    """
    Find all files for a specific dataset.

    Transformed files must be named ``<dataset>_train_Cayley*`` and
    ``<dataset>_test_Cayley*``; original files ``<dataset>_base*`` and
    ``<dataset>_query*``. Both directories accept the ``.fvec`` and
    ``.fvecs`` extensions (previously only ``.fvec`` was accepted for the
    original files, so e.g. ``sift_base.fvecs`` was never found).

    Args:
        dataset_name: Name of the dataset (e.g., "sift")
        transformed_dir: Directory containing transformed files
        original_dir: Directory containing original files

    Returns:
        Dict with train_path, test_path, original_base_path, original_query_path
        (in that order), or None if a directory cannot be listed or any file
        is missing.
    """
    transformed = _find_prefixed_fvec_files(
        transformed_dir,
        {
            'train_path': f"{dataset_name}_train_Cayley",
            'test_path': f"{dataset_name}_test_Cayley",
        },
        "transformed",
    )
    if transformed is None:
        return None

    original = _find_prefixed_fvec_files(
        original_dir,
        {
            'original_base_path': f"{dataset_name}_base",
            'original_query_path': f"{dataset_name}_query",
        },
        "original",
    )
    if original is None:
        return None

    files = {**transformed, **original}

    # Check if all required files were found
    missing_files = [key for key, value in files.items() if value is None]
    if missing_files:
        print(f"Missing files for dataset '{dataset_name}': {missing_files}")
        return None

    return files

def _build_benchmark_command(train_path, test_path, original_base_path, original_query_path,
                             nb, nq, index_type, M, nprobe, epsilon):
    """Return the bench_ivf.py argv for one run (pano indexes also get the transformed files)."""
    cmd = [
        'python3', 'faiss/perf_tests/bench_ivf.py',
        '--nb', str(nb),
        '--nq', str(nq),
        '--nprobe', str(nprobe),
        '--nlist', str(NLIST),
        '--num-threads', '1',
        '--nlevels', str(NLEVELS),
        '--M', str(M),
        '--epsilon', str(epsilon),
        '--csv-path', original_base_path,
        '--csv-test-path', original_query_path,
    ]
    if 'pano' in index_type:
        cmd += ['--trans-csv-path', train_path, '--trans-csv-test-path', test_path]
    cmd += ['--index-type', index_type]
    return cmd


def _benchmark_env(work_dir):
    """Return a copy of os.environ whose PYTHONPATH lists the local Annoy and FAISS builds first."""
    env = os.environ.copy()
    search_path = [
        os.path.join(work_dir, 'annoy', 'build', 'lib.linux-x86_64-cpython-312'),
        os.path.join(work_dir, 'faiss', 'build', 'faiss', 'python', 'build', 'lib'),
    ]
    # Append the inherited PYTHONPATH only when it is non-empty: a trailing
    # separator creates an empty entry, which Python resolves to the current
    # directory (here: work_dir).
    inherited = env.get('PYTHONPATH')
    if inherited:
        search_path.append(inherited)
    env['PYTHONPATH'] = os.pathsep.join(search_path)
    return env


def run_benchmark(train_path, test_path, original_base_path, original_query_path, nb, nq, index_type, M, nprobe, epsilon, efSearch, work_dir):
    """Run a single bench_ivf.py invocation and parse its output.

    Index types containing ``pano`` receive both the original and the
    transformed files; other index types receive only the original files.
    The subprocess runs in ``work_dir`` with PYTHONPATH pointing at the
    locally built FAISS and Annoy.

    Returns:
        The dict produced by ``parse_benchmark_output`` with an extra
        ``'command'`` key holding the shell-quoted command line, or None if
        the benchmark exits non-zero or an exception occurs (best effort, so
        one failing configuration does not abort the grid search).
    """
    import shlex

    print(f"Running benchmark for {index_type} with nlevels={NLEVELS}, M={M}, nprobe={nprobe}")
    print(f"  Original Base:     {original_base_path}")
    print(f"  Original Query:    {original_query_path}")
    if 'pano' in index_type:
        print(f"  Transformed Train: {train_path}")
        print(f"  Transformed Test:  {test_path}")
    else:
        print(f"  (No transformed files for non-pano index)")

    try:
        cmd = _build_benchmark_command(
            train_path, test_path, original_base_path, original_query_path,
            nb, nq, index_type, M, nprobe, epsilon,
        )
        # shlex.join quotes arguments containing spaces or shell metacharacters,
        # so the printed/logged command can be pasted into a shell verbatim.
        cmd_string = shlex.join(cmd)

        print("---------------------------------------")
        print("Command to run manually:")
        print(cmd_string)
        print("---------------------------------------")

        # errors='replace' keeps undecodable bytes in the benchmark's output
        # from raising while the captured text is decoded.
        result = subprocess.run(cmd, capture_output=True, text=True, errors='replace',
                                cwd=work_dir, env=_benchmark_env(work_dir))

        if result.returncode != 0:
            print(f"Benchmark failed for {index_type} with nlevels={NLEVELS}, M={M}, nprobe={nprobe}: {result.stderr}")
            return None

        # NOTE(review): train_path is recorded as 'file_path' even for non-pano
        # runs, which never receive it; kept for CSV compatibility.
        parsed_result = parse_benchmark_output(result.stdout, index_type, train_path, nb, nq, M, nprobe, epsilon, efSearch)
        if parsed_result:
            # Add command to the results for logging
            parsed_result['command'] = cmd_string
        return parsed_result

    except Exception as e:
        print(f"Exception during benchmark for {index_type} with nlevels={NLEVELS}, M={M}, nprobe={nprobe}: {e}")
        return None

def parse_benchmark_output(output, index_type, file_path, nb, nq, M, nprobe, epsilon, efSearch):
    """Build a result row from the run parameters and the metrics in bench_ivf.py's stdout.

    Every metric that is not found in ``output`` stays None. For index types
    containing ``pano`` the ``avg_level_percent`` field is filled from up to
    three alternative output lines; when several are present the one checked
    last (``Avg level active``) wins.

    Returns:
        The result dict, or None if converting a matched value fails.
    """
    try:
        row = dict(
            index_type=index_type,
            file_path=file_path,
            nb=nb,
            nq=nq,
            nlevels=NLEVELS,  # hard-coded module value, not read from the output
            M=M,
            nprobe=nprobe,
            epsilon=epsilon,
            efSearch=efSearch,
            wall_time_ms=None,
            qps=None,
            queries=None,
            recall=None,
            avg_level_percent=None,
        )

        # (result key, pattern, converter) — evaluated in this order.
        scalar_metrics = (
            ('wall_time_ms', r'Wall time: ([\d.]+) ms', float),
            ('qps', r'QPS: ([\d.]+)', float),
            ('queries', r'Queries: (\d+)', int),
            ('recall', r'Recall: ([\d.]+)', float),
        )
        for key, pattern, convert in scalar_metrics:
            found = re.search(pattern, output)
            if found:
                row[key] = convert(found.group(1))

        if 'pano' in index_type:
            # (pattern, multiplier); a later match overwrites an earlier one.
            level_metrics = (
                (r'avg_level %: ([\d.]+)%', 1),
                (r'total_active: ([\d.]+)', 100),
                (r'Avg level active: ([\d.]+)', 100),
            )
            for pattern, scale in level_metrics:
                found = re.search(pattern, output)
                if found:
                    row['avg_level_percent'] = float(found.group(1)) * scale

        return row

    except Exception as e:
        print(f"Error parsing benchmark output for {index_type} with nlevels={NLEVELS}, M={M}, nprobe={nprobe}: {e}")
        return None

def write_result_incremental(results_csv: str, result_row: dict) -> None:
    """Append a single result row to ``results_csv``.

    A header line is written when the file does not exist yet *or* exists but
    is empty (e.g. pre-created with ``touch``); previously an empty existing
    file received data rows without a header. Errors are reported and
    swallowed so a CSV problem does not abort a long grid search.

    NOTE(review): columns are written in ``result_row``'s key order and are
    not aligned with an existing header; appending rows with different keys
    to an old file will misalign columns.
    """
    try:
        needs_header = (not os.path.exists(results_csv)
                        or os.path.getsize(results_csv) == 0)
        row_df = pd.DataFrame([result_row])
        row_df.to_csv(results_csv, mode='a', header=needs_header, index=False)
    except Exception as exc:
        print(f"Warning: failed to append results to {results_csv}: {exc}")

def get_valid_m_values(dimension):
    """Return, in ascending order, every multiple of 8 that divides ``dimension`` exactly.

    Candidates run 8, 16, ... up to ``dimension``; a dimension smaller than
    8 yields an empty list.
    """
    return [candidate for candidate in range(8, dimension + 1, 8)
            if dimension % candidate == 0]

def get_nprobe_values(nlist=10):
    """Return the nprobe values to sweep for a given ``nlist``.

    For ``nlist >= 10`` a fixed spread ``[1, 2, 3, 5, 8, 10]`` is used;
    otherwise every value from 1 to ``nlist`` inclusive.
    """
    if nlist >= 10:
        return [1, 2, 3, 5, 8, 10]
    return list(range(1, nlist + 1))

def _key_nprobe_values(nprobe_values: List[int]) -> List[int]:
    """Return the powers of two present in *nprobe_values* (ascending), or all values if none are."""
    upper = max(nprobe_values)
    key_nprobes = []
    power = 1
    while power <= upper:
        if power in nprobe_values:
            key_nprobes.append(power)
        power *= 2
    # Without this fallback, a list with no power of two led to a division by zero.
    return key_nprobes if key_nprobes else list(nprobe_values)


def _evenly_spaced(values: List[int], count: int) -> List[int]:
    """Pick *count* items from *values* at evenly spaced positions, always including first and last."""
    if len(values) <= count:
        return list(values)
    if count <= 1:
        return list(values[:1])
    # Integer arithmetic: a float step can round the final index down by one.
    last = len(values) - 1
    return [values[i * last // (count - 1)] for i in range(count)]


def select_parameter_combinations(valid_m_values: List[int], nprobe_values: List[int], 
                                index_types: List[str], max_runs: int = 100) -> List[Tuple]:
    """
    Select ``(index_type, M, nprobe)`` triples to run, using at most ``max_runs``.

    If the full grid fits, every combination is returned (ordered by index
    type, then M, then nprobe). Otherwise a reduced grid is sampled and
    shared by all index types:

    * nprobe: the values in ``nprobe_values`` that are powers of two (all of
      ``nprobe_values`` if none are);
    * M: ``max(2, int(sqrt(runs_per_index / n_nprobe)))`` values (capped at
      the number available), evenly spaced over ``valid_m_values`` and
      including the first and last.

    Each index type's share is capped at ceil(remaining runs / remaining
    types) so that a small ``max_runs`` cannot be consumed entirely by the
    first index type.

    Returns:
        List of ``(index_type, M, nprobe)`` tuples; empty when ``max_runs <= 0``.
    """
    # Count total possible combinations (simplified since nlevels is fixed)
    total_combos = len(index_types) * len(valid_m_values) * len(nprobe_values)

    print(f"Total possible combinations: {total_combos}")

    if total_combos <= max_runs:
        # Run all combinations
        combinations = [
            (index_type, M, nprobe)
            for index_type in index_types
            for M, nprobe in itertools.product(valid_m_values, nprobe_values)
        ]
        print(f"Running all {len(combinations)} combinations")
        return combinations

    if max_runs <= 0:
        print("max_runs <= 0: no combinations selected")
        return []

    # Here total_combos > max_runs > 0, so every input list is non-empty.
    print(f"Sampling {max_runs} combinations from {total_combos} possible")

    key_nprobes = _key_nprobe_values(nprobe_values)

    # Balance M against nprobe exploration: take roughly sqrt(runs / #nprobe)
    # M values, at least 2 and at most all available.
    runs_per_index = max_runs // len(index_types)
    target_m_count = int((runs_per_index / len(key_nprobes)) ** 0.5)
    target_m_count = max(2, min(target_m_count, len(valid_m_values)))
    selected_m = _evenly_spaced(valid_m_values, target_m_count)

    combinations = []
    for position, index_type in enumerate(index_types):
        remaining = max_runs - len(combinations)
        types_left = len(index_types) - position
        budget = -(-remaining // types_left)  # ceiling division

        planned = [(index_type, M, nprobe) for M in selected_m for nprobe in key_nprobes]
        print(f"  {index_type}: Will test {len(selected_m)} M values × {len(key_nprobes)} nprobe values = {len(planned)} combinations")
        if len(planned) > budget:
            print(f"    (capped at {budget} to leave room for the remaining index types)")
        combinations.extend(planned[:budget])

    print(f"Selected {len(combinations)} combinations for grid search")

    # Print sampling summary
    print("\nSampling summary:")
    for index_type in index_types:
        type_combos = [(i, m, p) for i, m, p in combinations if i == index_type]
        if type_combos:
            m_values = sorted({m for _, m, _ in type_combos})
            nprobe_values_used = sorted({p for _, _, p in type_combos})
            print(f"  {index_type}: {len(type_combos)} runs, M={m_values}, nprobe={nprobe_values_used}")

    return combinations

def _best_row(frame: pd.DataFrame, primary: str, secondary: str) -> pd.Series:
    """Return the row maximizing *primary*, breaking ties by larger *secondary*, then earliest row."""
    top = frame[frame[primary] == frame[primary].max()]
    return top.loc[top[secondary].idxmax()]


def _summarize_run(row: pd.Series) -> dict:
    """Reduce a result row to the fields reported as a best setting."""
    return {
        'M': int(row['M']),
        'nprobe': int(row['nprobe']),
        'nlevels': NLEVELS,
        'qps': float(row['qps']),
        'recall': float(row['recall']),
    }


def find_best_settings(results_df: pd.DataFrame) -> dict:
    """Find the highest-recall and highest-QPS run for each index type.

    Runs whose QPS or recall could not be parsed (None/NaN) are ignored; an
    index type with no fully-parsed runs is omitted from the result instead
    of crashing ``idxmax``. Ties on the primary metric are broken by the
    other metric (the fastest of equally accurate runs, the most accurate of
    equally fast runs), then by earliest row.

    Returns:
        ``{index_type: {'best_recall': {...}, 'best_qps': {...}}}`` where each
        inner dict holds ``M``, ``nprobe``, ``nlevels``, ``qps`` and
        ``recall``; ``{}`` for an empty frame.
    """
    if len(results_df) == 0:
        return {}

    best_settings = {}

    for index_type in results_df['index_type'].unique():
        type_results = results_df[results_df['index_type'] == index_type].copy()

        # Unparsed metrics arrive as None (object dtype) or NaN; normalise to
        # NaN floats and drop them so they cannot be ranked.
        for column in ('qps', 'recall'):
            type_results[column] = pd.to_numeric(type_results[column], errors='coerce')
        type_results = type_results.dropna(subset=['qps', 'recall'])

        if len(type_results) == 0:
            continue

        best_settings[index_type] = {
            'best_recall': _summarize_run(_best_row(type_results, 'recall', 'qps')),
            'best_qps': _summarize_run(_best_row(type_results, 'qps', 'recall')),
        }

    return best_settings

def _print_best_settings(best_settings: dict) -> None:
    """Print the best-recall and best-QPS configuration for each index type."""
    print("\n" + "="*60)
    print("BEST PARAMETER SETTINGS")
    print("="*60)

    for index_type, settings in best_settings.items():
        print(f"\n{index_type.upper()}:")
        print("-" * 40)

        for label, key in (("Best Recall:", 'best_recall'), ("Best QPS:", 'best_qps')):
            best = settings[key]
            print(label)
            print(f"  M={best['M']}, nprobe={best['nprobe']}, nlevels={best['nlevels']}")
            print(f"  QPS: {best['qps']:.1f}, Recall: {best['recall']:.3f}")


def _print_grid_summary(results_df: pd.DataFrame) -> None:
    """Print QPS/recall ranges and the M/nprobe values tested, per index type."""
    print(f"\n" + "="*60)
    print("GRID SEARCH SUMMARY")
    print("="*60)

    for index_type in INDEX_TYPES:
        type_results = results_df[results_df['index_type'] == index_type]
        if len(type_results) == 0:
            continue
        # Unparsed metrics are None; coerce so min/max yield NaN rather than failing.
        qps = pd.to_numeric(type_results['qps'], errors='coerce')
        recall = pd.to_numeric(type_results['recall'], errors='coerce')
        print(f"\n{index_type.upper()} - {len(type_results)} runs:")
        print(f"  QPS range: {qps.min():.1f} - {qps.max():.1f}")
        print(f"  Recall range: {recall.min():.3f} - {recall.max():.3f}")
        print(f"  M values tested: {sorted(type_results['M'].unique().tolist())}")
        print(f"  nprobe values tested: {sorted(type_results['nprobe'].unique().tolist())}")


def main():
    """Command-line entry point: grid-search M and nprobe for one dataset.

    Locates the dataset's original and transformed (``*_Cayley``) files and
    derives candidate M/nprobe values from the vector dimension. It then runs
    up to ``--max-runs`` benchmarks, appending each successful result to the
    results CSV as it completes. Finally it prints the best settings and a
    summary per index type.

    A bare ``--results-csv`` filename is placed inside ``--output-dir``. A
    path that already has a directory component is used as given.
    """
    # Generate timestamp for default filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    parser = argparse.ArgumentParser(description="Grid search over M and nprobe parameters for IVFPQ")
    parser.add_argument("--dataset-name", required=True,
                       help="Name of the dataset to benchmark (e.g., 'sift', 'glove')")
    parser.add_argument("--transformed-dir", default="/mnt/device/linear", 
                       help="Directory containing transformed fvec files")
    parser.add_argument("--original-dir", default="/mnt/device/datasets", 
                       help="Directory containing original fvec files")
    parser.add_argument("--output-dir", default="/mnt/device/linear", 
                       help="Directory for output files (holds --results-csv when it is a bare filename)")
    parser.add_argument("--results-csv", default=f"grid_search_results_{timestamp}.csv", 
                       help="Output CSV file for grid search results")
    parser.add_argument("--nq", type=int, default=1000, 
                       help="Number of query points for benchmarks")
    parser.add_argument("--work-dir", default="/home/ubuntu/panorama",
                       help="Working directory for running benchmark commands")
    parser.add_argument("--max-runs", type=int, default=100,
                       help="Maximum number of benchmark runs")

    args = parser.parse_args()

    # --output-dir was previously created but never used for the CSV.
    results_csv = args.results_csv
    if not os.path.dirname(results_csv):
        results_csv = os.path.join(args.output_dir, results_csv)

    print(f"Grid search for dataset: {args.dataset_name}")
    print(f"Max runs: {args.max_runs}")
    print(f"Results will be saved to: {results_csv}")
    print(f"Using fixed nlevels: {NLEVELS}")

    # Find all files for the specified dataset
    dataset_files = find_dataset_files(args.dataset_name, args.transformed_dir, args.original_dir)
    if dataset_files is None:
        print(f"Could not find all required files for dataset '{args.dataset_name}'")
        return

    print("\nFound dataset files:")
    for key, path in dataset_files.items():
        print(f"  {key}: {path}")

    # Dimension and count are read from the transformed train file.
    dimension, nb = get_fvec_dimension_and_count(dataset_files['train_path'])
    if dimension == 0 or nb == 0:
        print(f"Failed to read dimension/count from {dataset_files['train_path']}")
        return

    print(f"\nDataset info: {nb} vectors with dimension {dimension}")

    # Get valid parameter values
    valid_m_values = get_valid_m_values(dimension)
    nprobe_values = get_nprobe_values(nlist=NLIST)

    print(f"Valid M values for dimension {dimension}: {valid_m_values}")
    print(f"nprobe values to test: {nprobe_values}")

    if not valid_m_values:
        print(f"No valid M values found for dimension {dimension}")
        return

    # Select parameter combinations within max_runs limit
    combinations = select_parameter_combinations(valid_m_values, nprobe_values, INDEX_TYPES, args.max_runs)

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    all_results = []

    # Fixed parameters
    epsilon = 1.0
    efSearch = 1

    # Run grid search
    print(f"\n{'='*80}")
    print(f"Starting grid search with {len(combinations)} parameter combinations")
    print(f"Fixed parameters: epsilon={epsilon}, nlist={NLIST}, nlevels={NLEVELS}")
    print(f"{'='*80}")

    for i, (index_type, M, nprobe) in enumerate(combinations, 1):
        print(f"\n[{i}/{len(combinations)}] Testing {index_type} - nlevels={NLEVELS}, M={M}, nprobe={nprobe}")

        result = run_benchmark(
            dataset_files['train_path'], 
            dataset_files['test_path'], 
            dataset_files['original_base_path'], 
            dataset_files['original_query_path'], 
            nb, 
            args.nq, 
            index_type, 
            M, 
            nprobe,
            epsilon, 
            efSearch, 
            args.work_dir
        )

        if not result:
            print("✗ Failed")
            continue

        # dimension / M is the length of each PQ sub-vector (assuming M is the
        # number of sub-quantizers, as in FAISS IVFPQ). It shrinks as M grows:
        # a larger M stores more codes per vector, i.e. *less* compression.
        # The column name is kept for CSV compatibility.
        result['compression_ratio'] = dimension / M
        # Heuristic kept from the original script.
        # TODO(review): confirm against the PQ nbits bench_ivf.py actually uses.
        result['bits_per_vector'] = M * 8 if M <= 64 else M * 16

        all_results.append(result)
        write_result_incremental(results_csv, result)

        # Metrics missing from the benchmark output are stored as None;
        # formatting None with :.1f raised TypeError and aborted the search.
        qps = result.get('qps') or 0
        recall = result.get('recall') or 0
        wall_time_ms = result.get('wall_time_ms') or 0
        print(f"✓ Success - QPS: {qps:.1f}, Recall: {recall:.3f}, Wall time: {wall_time_ms:.1f}ms")

    # Analyze results and find best settings
    print(f"\n{'='*80}")
    print(f"Grid search complete for dataset: {args.dataset_name}")
    print(f"Successful runs: {len(all_results)}/{len(combinations)}")
    print(f"Results saved to: {results_csv}")
    print(f"{'='*80}")

    if not all_results:
        print("No successful benchmark runs!")
        return

    results_df = pd.DataFrame(all_results)
    _print_best_settings(find_best_settings(results_df))

    print(f"\nAll {len(all_results)} data points saved to: {results_csv}")
    print("Plot QPS vs Recall to see the full trade-off curve!")

    _print_grid_summary(results_df)

    print(f"\nComplete results saved to: {results_csv}")
    print("Use this data to plot QPS vs Recall curves and select optimal parameters!")

# Run the grid search only when executed as a script, never on import.
if __name__ == '__main__':
    main()