import pickle
import utils
import numpy as np
import timeit
from pynndescent import NNDescent
import FalconnLite
import csv
import os
import sys

# ==============================================================
# LOG SETUP (stdout + stderr -> log file)
# ==============================================================

# Log file that receives a copy of everything written to stdout and stderr.
log_name = "approx_join_k20_m_sweep.log"
# Explicit UTF-8 so the log's encoding does not depend on the platform locale
# and non-ASCII output cannot fail with UnicodeEncodeError.
log_file = open(log_name, "w", encoding="utf-8")

class Tee:
    """Write-only text stream that duplicates its output to several file objects.

    Sinks that have been closed are skipped rather than written to, so a Tee
    that is still installed as a standard stream keeps working after one of
    its sinks (for example a log file) has been closed.
    """

    def __init__(self, *files):
        """Store the sinks; each must provide ``write`` and ``flush``."""
        self.files = files

    def _open_files(self):
        """Return the sinks that are not closed."""
        # Objects without a ``closed`` attribute are treated as open.
        return [f for f in self.files if not getattr(f, "closed", False)]

    def write(self, data):
        """Write *data* to every open sink and flush each one immediately.

        Returns the number of characters in *data*, matching the contract of
        text streams.
        """
        for f in self._open_files():
            f.write(data)
            # Flush eagerly so output reaches each sink as soon as it is written.
            f.flush()
        return len(data)

    def flush(self):
        """Flush every open sink."""
        for f in self._open_files():
            f.flush()

# From here on, mirror both standard streams into the log file.
sys.stdout, sys.stderr = Tee(sys.stdout, log_file), Tee(sys.stderr, log_file)

print(f"[LOG] Writing output to {log_name}")

# ==============================================================
# Dataset configurations
# ==============================================================

DATA_DIR = "/path/to/dataset"
GT_DIR = "/path/to/groundtruth"

# One entry per dataset, keyed by dataset name:
#   bin_file - path to the vector file
#   gt_file  - path to the .npy file of ground-truth neighbour indices
# NOTE(review): the main loop also reads data_info["n"] and data_info["d"],
# which no entry defines here - add them before running.
dataset_paths = {
    "landmark": dict(
        bin_file=DATA_DIR + "/landmark-dino-768-cosine.bin",
        gt_file=GT_DIR + "/landmark_Cosine_k_100_indices.npy",
    ),
}


# ==============================================================
# Fixed experiment parameters
# ==============================================================

# Method label: written to the CSV and used in the output file name.
method = "approx_join"
dist = "Cosine"
verbose = True

D = 512  # passed to FalconnLite as n_proj
iProbe, qProbe = 3, 1
n_threads = 32

# Run a single trial only; this seed is NNDescent's random_state.
trial_seed = 100

# Fixed number of neighbours: k = 20.
k = 20

# top_m values to sweep over.
topm_list = [5, 10, 20, 25, 50, 100]

# NNDescent iteration budgets, from 10 down to 0.
iters = list(range(10, -1, -1))

# Values passed as approx_join's n_repeats argument.
n_repeats_list = [4]

# ==============================================================
# Main Loop: datasets × parameters
# ==============================================================

print(f"\n\n============ Running single seed {trial_seed} ============\n")

# For every dataset: load and L2-normalize the vectors, then for each
# (n_repeats, top_m) combination build an initial graph with
# FalconnLite.approx_join and refine it with NNDescent for each iteration
# budget in `iters`, appending one CSV row per refinement run.
for dataset_name, data_info in dataset_paths.items():

    print(f"\n\n======== Dataset: {dataset_name} | Seed={trial_seed} ========\n")

    # Fail early, naming what is missing, if the dataset entry is incomplete.
    missing_keys = [
        key for key in ("bin_file", "gt_file", "n", "d") if key not in data_info
    ]
    if missing_keys:
        raise KeyError(
            f"dataset_paths[{dataset_name!r}] is missing required keys: {missing_keys}"
        )

    bin_file = data_info["bin_file"]
    n = data_info["n"]
    d = data_info["d"]
    gt_file = data_info["gt_file"]

    # ---------------- Load dataset ----------------
    if dataset_name == "Word":
        X = np.load(bin_file)
    else:
        X = utils.mmap_bin(bin_file, n, d)

    # ---------------- Normalize (cosine only) ----------------
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    # Rows with zero norm are divided by 1, i.e. left unchanged.
    norms[norms == 0] = 1
    X = X / norms

    # ---------------- Output directory ----------------
    output_dir = f"{dataset_name}_result/Sensitivity"
    os.makedirs(output_dir, exist_ok=True)

    # Derive the file name from k so it cannot drift from the value actually used.
    csv_file = f"{output_dir}/{method}_k{k}_seed{trial_seed}.csv"

    # Write the header only when the file is new; later runs append to it.
    if not os.path.isfile(csv_file):
        with open(csv_file, mode="w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                "Dataset", "Method", "Init Time(ms)",
                "Construction distance cost per point",
                "NNDescent Time(ms)", "Distance computation",
                "Total Time(ms)", "NNDescent Iterations",
                "NNDescent Accuracy", "Number of random vector",
                "TopK", "iProbe", "top_m", "qProbe",
                "distance", "centering"
            ])

    # Keep ground-truth columns 1..k; column 0 is skipped
    # (presumably each point itself - TODO confirm).
    exact_kNN = np.load(gt_file)[:, 1:k+1]

    # ---------------- Parameter sweep ----------------
    for n_repeats in n_repeats_list:
        for top_m in topm_list:

            print(f"\n>> n_repeats={n_repeats}, top_m={top_m}")

            index = FalconnLite.FalconnLite(n, d)
            # NOTE(review): seed=-1 is passed here, not trial_seed - confirm
            # this is intended, since the CSV file name carries trial_seed.
            index.set_params(
                n_proj=D,
                iProbe=iProbe,
                top_m=top_m,
                qProbe=qProbe,
                distance=dist,
                verbose=verbose,
                n_threads=n_threads,
                seed=-1
            )
            index.centering = True

            # ---------------- Init phase ----------------
            t_init_start = timeit.default_timer()
            indices, _ = index.approx_join(X, topK=k, n_repeats=n_repeats)
            t_init_end = timeit.default_timer()
            init_time_ms = (t_init_end - t_init_start) * 1000

            # ---------------- NNDescent Refinement ----------------
            # Each iteration budget starts again from the same initial graph.
            for t in iters:
                t_nn_start = timeit.default_timer()

                index1 = NNDescent(
                    X,
                    n_neighbors=k,
                    random_state=trial_seed,
                    tree_init=False,
                    init_graph=indices,
                    metric="cosine",
                    n_iters=t,
                    n_jobs=n_threads
                )
                knn_indices, _ = index1.neighbor_graph

                t_nn_end = timeit.default_timer()
                nn_time_ms = (t_nn_end - t_nn_start) * 1000
                total_time_ms = init_time_ms + nn_time_ms

                acc = utils.getAcc(knn_indices, exact_kNN)

                # Append one row per run so finished results survive a later crash.
                with open(csv_file, mode="a", newline="") as f:
                    writer = csv.writer(f)
                    writer.writerow([
                        dataset_name,
                        f"{n_repeats} {method}",
                        f"{init_time_ms:.2f}",
                        0,  # "Construction distance cost per point" is not measured
                        f"{nn_time_ms:.2f}",
                        # Per-point value; 0 when index1 has no `counter` attribute.
                        getattr(index1, "counter", 0) / n,
                        f"{total_time_ms:.2f}",
                        t,
                        f"{acc:.4f}",
                        D,
                        k,
                        iProbe,
                        top_m,
                        qProbe,
                        dist,
                        index.centering
                    ])

    print(f"\n---- Finished {dataset_name} ----\n")

print("\n[LOG] Experiment finished.")

# Restore the interpreter's original streams before closing the log file.
# Otherwise the Tee objects installed as sys.stdout/sys.stderr still reference
# the closed file, and the flush Python performs at exit raises ValueError.
sys.stdout.flush()
sys.stderr.flush()
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
log_file.close()
