import subprocess
import csv
import itertools
import os
import re

# Locations: the benchmark script that is launched for every grid point, the
# dataset file handed to it via --csv-file, and the CSV the sweep writes.
BENCH_SCRIPT = "perf_tests/bench_hnsw.py"
CSV_FILE = "/mnt/datasets/gist/wavelets-gist.csv"
OUTPUT_CSV = "benchmark_results.csv"

# Swept values; every (M, efSearch, levels) combination is benchmarked.
M_values = [16, 32, 64]
ef_search_values = [16, 32, 64, 128, 256, 512]
levels_values = [8, 10, 12, 14, 16, 18]
# Not swept: the same --epsilon is used for every run and recorded per row.
epsilon = 1.0

# Values forwarded unchanged to the benchmark script as --d / --nb / --nq.
d = 960
nb = 100000
nq = 10000

# Command-line arguments shared by every invocation of the benchmark script.
common_args = (
    ["--d", f"{d}", "--nb", f"{nb}", "--nq", f"{nq}"]
    + ["--csv-file", CSV_FILE]
    + ["--compare", "--search-bounded-queue"]
    + ["--epsilon", f"{epsilon}"]
)

# Metric name -> compiled regex. Each regex has exactly one capture group
# holding the numeric value of that metric in the benchmark's stdout.
# The two recall patterns use re.DOTALL so that ".*?" can span the line
# break(s) between the section header and its "Mean recall@10" line.
patterns = {
    name: re.compile(source, flags)
    for name, source, flags in (
        ("flat_add_time", r"Add time - Panorama: .*ms, Flat: ([\d.]+)ms", 0),
        ("panorama_add_time", r"Add time - Panorama: ([\d.]+)ms, Flat: .*ms", 0),
        ("flat_search_time", r"Search time - Panorama: .*ms, Flat: ([\d.]+)ms", 0),
        ("panorama_search_time", r"Search time - Panorama: ([\d.]+)ms, Flat: .*ms", 0),
        ("flat_recall", r"Flat vs Real:\n.*?Mean recall@10: ([\d.]+)%", re.DOTALL),
        ("panorama_recall", r"Panorama vs Real:\n.*?Mean recall@10: ([\d.]+)%", re.DOTALL),
    )
}


def parse_output(output):
    """Extract the benchmark metrics from captured benchmark output.

    Every regex in the module-level ``patterns`` table is searched against
    ``output``; the first capture group of a hit is converted to ``float``.

    Args:
        output: Text printed by the benchmark script.

    Returns:
        A dict with one entry per key of ``patterns``. The value is the
        parsed float, or ``None`` when that pattern did not match.
    """
    hits = ((name, regex.search(output)) for name, regex in patterns.items())
    return {
        name: None if hit is None else float(hit.group(1))
        for name, hit in hits
    }


def run_experiment(M, ef_search, levels):
    """Run the benchmark script once for one hyperparameter combination.

    Launches ``BENCH_SCRIPT`` via ``sudo python3`` with ``common_args`` plus
    the given ``--M``, ``--ef-search`` and ``--levels`` values, captures its
    output as text and parses stdout with ``parse_output``.

    Args:
        M: Value passed to the benchmark as ``--M``.
        ef_search: Value passed to the benchmark as ``--ef-search``.
        levels: Value passed to the benchmark as ``--levels``.

    Returns:
        A dict holding ``M``, ``efSearch``, ``levels``, ``epsilon`` and every
        metric returned by ``parse_output`` (a metric is ``None`` when it was
        not found in the output), or ``None`` when the benchmark process
        exited with a non-zero status.
    """
    cmd = ["sudo", "python3", BENCH_SCRIPT] + common_args + [
        "--M", str(M),
        "--ef-search", str(ef_search),
        "--levels", str(levels),
    ]
    print(f"Running: M={M}, efSearch={ef_search}, levels={levels}")
    # check=False: a failing run is reported and skipped, not raised.
    completed = subprocess.run(cmd, capture_output=True, text=True, check=False)

    if completed.returncode != 0:
        print(f"⚠️ Failed run: M={M}, efSearch={ef_search}, levels={levels}")
        print(completed.stderr)
        return None

    parsed = parse_output(completed.stdout)

    # A successful exit whose output does not match the expected format would
    # otherwise produce a row of blanks with no hint as to why.
    missing = [key for key, value in parsed.items() if value is None]
    if missing:
        print(
            f"⚠️ Could not parse {', '.join(missing)} from output: "
            f"M={M}, efSearch={ef_search}, levels={levels}"
        )

    return {
        "M": M,
        "efSearch": ef_search,
        "levels": levels,
        "epsilon": epsilon,
        **parsed,
    }


def main():
    """Benchmark every grid combination and write the results to a CSV.

    Iterates over the Cartesian product of ``M_values``, ``ef_search_values``
    and ``levels_values``, calls ``run_experiment`` for each combination and
    writes one row per successful run to ``OUTPUT_CSV`` (the file is
    overwritten). Runs for which ``run_experiment`` returns ``None`` are
    skipped and counted.
    """
    fieldnames = [
        "M", "efSearch", "levels", "epsilon",
        "flat_add_time", "panorama_add_time",
        "flat_search_time", "panorama_search_time",
        "flat_recall", "panorama_recall",
    ]

    failed = 0
    with open(OUTPUT_CSV, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        grid = itertools.product(M_values, ef_search_values, levels_values)
        for M, ef, lvl in grid:
            result = run_experiment(M, ef, lvl)
            if result is None:
                failed += 1
                continue
            writer.writerow(result)
            # Push each row out of the buffer right away so finished results
            # are not lost if the sweep is interrupted or killed.
            f.flush()

    if failed:
        print(
            f"\n⚠️ {failed} benchmark run(s) failed and were skipped. "
            f"Results saved to: {OUTPUT_CSV}"
        )
    else:
        print(f"\n✅ All benchmarks completed. Results saved to: {OUTPUT_CSV}")


if __name__ == "__main__":
    main()
