import numpy as np
import h5py
import subprocess
import struct
import os

def read_fvecs(filename):
    """Read a .fvecs file into a 2-D float32 array.

    The file is parsed as a sequence of records, each made of one int32
    dimension header followed by ``dim`` float32 components, in the
    machine's native byte order.

    Args:
        filename: Path of the .fvecs file to read.

    Returns:
        A float32 array of shape ``(num_vectors, dim)``. It is a view onto
        the buffer read from disk, with the header column sliced away.

    Raises:
        ValueError: If the file is empty, the leading dimension is not
            positive, the file length is not a whole number of records, or
            any record header disagrees with the first one.
    """
    with open(filename, "rb") as f:
        # Headers and components are both 4 bytes wide, so the whole file
        # can be loaded as int32 and the payload reinterpreted afterwards.
        raw = np.fromfile(f, dtype=np.int32)

    if raw.size == 0:
        raise ValueError("Invalid .fvecs file: file is empty")

    dim = int(raw[0])
    if dim <= 0:
        raise ValueError(f"Invalid .fvecs file: non-positive dimension {dim}")
    if raw.size % (dim + 1) != 0:
        raise ValueError("Invalid .fvecs file")

    records = raw.reshape(-1, dim + 1)
    # Every record carries its own header; a mismatch means the file is
    # corrupt or does not hold fixed-dimension vectors.
    if not np.all(records[:, 0] == dim):
        raise ValueError("Invalid .fvecs file: inconsistent dimension headers")

    # Same item size, so this reinterprets the bits without copying.
    return records[:, 1:].view(np.float32)

def write_fvecs(filename, data):
    """Write a 2-D array of vectors to a file in .fvecs format.

    Each row is stored as one int32 holding the vector dimension followed
    by the row's components as float32, in the machine's native byte order.

    Args:
        filename: Path of the file to create or overwrite.
        data: Array-like of shape ``(num_vectors, dim)``. Values are
            converted to float32 before being written.

    Raises:
        ValueError: If ``data`` is not two-dimensional.
    """
    vectors = np.ascontiguousarray(data, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(
            f"write_fvecs expects a 2-D array, got {vectors.ndim} dimension(s)"
        )
    num_vecs, dim = vectors.shape

    # Build the header + payload records in one buffer so the file is
    # written with a single call instead of two packs per row. This holds
    # one extra copy of the data in memory while writing.
    records = np.empty((num_vecs, dim + 1), dtype=np.int32)
    records[:, 0] = dim
    # Reinterpret the float32 bits as int32 so they are stored unchanged.
    records[:, 1:] = vectors.view(np.int32)

    with open(filename, 'wb') as f:
        records.tofile(f)

def convert_fvecs_to_hdf5(fvecs_path, hdf5_path, dataset_name="train"):
    """Load a .fvecs file and store its vectors in a new HDF5 file.

    Args:
        fvecs_path: Path of the .fvecs file to read.
        hdf5_path: Path of the HDF5 file to create; opened in "w" mode, so
            an existing file at that path is replaced.
        dataset_name: Name of the float32 dataset that receives the vectors.
    """
    vectors = read_fvecs(fvecs_path)
    n_vectors = vectors.shape[0]
    n_dims = vectors.shape[1]
    print(f"[INFO] Loaded {n_vectors} vectors of dimension {n_dims}")

    with h5py.File(hdf5_path, "w") as out_file:
        out_file.create_dataset(dataset_name, data=vectors, dtype='float32')

    print(f"[INFO] Written HDF5 to: {hdf5_path}")

def generate_queries_single(hdf5_input, hdf5_output, k, rc, n_queries):
    """Run the external ``hephaestus`` command once to produce a query file.

    Args:
        hdf5_input: Path passed to the tool as ``--dataset``.
        hdf5_output: Path passed to the tool as ``--output``.
        k: Value passed as ``-k``.
        rc: RC value; combined with ``n_queries`` into the ``--queries``
            argument as ``"<n_queries>:<rc>"``.
        n_queries: Number of queries requested in the ``--queries`` argument.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    query_spec = f"{n_queries}:{rc}"
    cmd = ["hephaestus"]
    cmd += ["--dataset", hdf5_input]
    cmd += ["--output", hdf5_output]
    cmd += ["--distance", "euclidean"]
    cmd += ["-k", str(k)]
    cmd += ["--queries", query_spec]

    print(f"[INFO] Running Hephaestus with RC={rc}...")
    # check=True turns a failing exit status into an exception.
    subprocess.run(cmd, check=True)
    print(f"[INFO] Queries saved to: {hdf5_output}")

def extract_queries_to_fvecs(hdf5_query_file, output_fvecs_file):
    """Copy the "test" dataset of an HDF5 file into a .fvecs file.

    Args:
        hdf5_query_file: Path of the HDF5 file to read.
        output_fvecs_file: Path of the .fvecs file to write.
    """
    with h5py.File(hdf5_query_file, 'r') as query_file:
        # Slice with [:] to pull the whole dataset into memory before the
        # file is closed.
        queries = f_queries = query_file["test"][:]

    n_queries, n_dims = f_queries.shape[0], f_queries.shape[1]
    print(f"[INFO] Extracted {n_queries} queries of dim {n_dims}")

    write_fvecs(output_fvecs_file, queries)
    print(f"[INFO] Written .fvecs to: {output_fvecs_file}")

# === Main ===
if __name__ == "__main__":
    # Pipeline: convert the base .fvecs dataset to HDF5 (cached on disk),
    # then, for each difficulty, generate queries with Hephaestus and export
    # them to .fvecs.
    input_fvecs = "/mnt/device/datasets/gist_base.fvecs"  # Input
    hdf5_dataset = "dataset.hdf5"
    k = 10
    n_queries = 1000

    # Difficulty configs: (RC value, output file)
    difficulties = [
        (1.0, "queries_easy.fvecs"),
        (2.0, "queries_medium.fvecs"),
        (3.0, "queries_hard.fvecs"),
    ]

    # Step 1: Convert input FVECS to HDF5 if not already done
    if not os.path.exists(hdf5_dataset):
        convert_fvecs_to_hdf5(input_fvecs, hdf5_dataset)
    else:
        print(f"[INFO] Using cached dataset: {hdf5_dataset}")

    # Step 2: Generate queries per difficulty
    for rc_value, output_fvecs in difficulties:
        hdf5_queries_file = f"tmp_queries_rc{rc_value}.hdf5"
        try:
            generate_queries_single(hdf5_dataset, hdf5_queries_file, k, rc_value, n_queries)
            extract_queries_to_fvecs(hdf5_queries_file, output_fvecs)
        finally:
            # Clean up the temp file even when generation or extraction
            # fails. The existence check keeps a missing file from masking
            # the original error.
            if os.path.exists(hdf5_queries_file):
                os.remove(hdf5_queries_file)
