import pickle
import utils
import numpy as np
import timeit
from pynndescent import NNDescent
import FalconnLite
import csv
import os
import sys

# ==============================================================
# LOG SETUP (mirror stdout + stderr into a log file)
# ==============================================================

log_name = "approx_join_k20_D_sweep.log"
# Explicit UTF-8 so non-ASCII text mirrored into the log (tracebacks, library
# messages) cannot fail with UnicodeEncodeError under a non-UTF-8 locale.
# Kept open for the whole run; it is closed at the end of the script.
log_file = open(log_name, "w", encoding="utf-8")

class Tee:
    """File-like object that duplicates every write to several streams.

    The script wraps ``sys.stdout`` / ``sys.stderr`` with it so console output
    is also captured in the log file. The first wrapped stream is treated as
    the primary one: attributes not defined here (``isatty``, ``encoding``,
    ``fileno``, ...) are looked up on it.
    """

    def __init__(self, *files):
        self.files = files

    def write(self, data):
        """Write *data* to every open wrapped stream, flushing each one.

        Streams that are already closed are skipped, so a closed log file does
        not break console output or the interpreter's final flush at exit.

        Returns:
            ``len(data)``, matching the ``io.TextIOBase.write`` contract.
        """
        for f in self.files:
            if getattr(f, "closed", False):
                continue
            f.write(data)
            # Flush eagerly so the log is up to date even if the run crashes.
            f.flush()
        return len(data)

    def flush(self):
        """Flush every wrapped stream that is still open."""
        for f in self.files:
            if not getattr(f, "closed", False):
                f.flush()

    def __getattr__(self, name):
        # Only called for attributes missing on the instance. Read ``files``
        # via __dict__ so a half-initialised instance cannot recurse forever.
        files = self.__dict__.get("files")
        if not files:
            raise AttributeError(name)
        return getattr(files[0], name)

# Route both console streams through Tee so everything printed from here on
# (including library output and tracebacks) also lands in the log file.
sys.stdout, sys.stderr = Tee(sys.stdout, log_file), Tee(sys.stderr, log_file)

print(f"[LOG] Writing output to {log_name}")

# ==============================================================
# Dataset configurations
# ==============================================================

DATA_DIR = "/path/to/dataset"
GT_DIR   = "/path/to/groundtruth"

# Per-dataset inputs. The main loop reads "bin_file", "gt_file", "n" (number
# of vectors) and "d" (vector dimension) from each entry.
# NOTE(review): "n" and "d" are not set below, so the main loop cannot
# memory-map the .bin file until they are filled in (the filename suggests
# d = 768 — confirm).
dataset_paths = {
    "landmark": {
        "bin_file": f"{DATA_DIR}/landmark-dino-768-cosine.bin",
        "gt_file": f"{GT_DIR}/landmark_Cosine_k_100_indices.npy",
    }
}

# ==============================================================
# Fixed experiment parameters
# ==============================================================

method = "approx_join"   # used in the CSV file name and the "Method" column
qProbe = 1
dist = "Cosine"          # distance name passed to FalconnLite.set_params
verbose = True
n_threads = 32           # thread count for both FalconnLite and NNDescent

# ✅ single seed (NNDescent random_state; also part of the CSV file name)
trial_seed = 100

# ✅ fixed k (neighbours per point) and top_m
k = 20
top_m = 25

# ✅ D sweep: n_proj passed to FalconnLite ("Number of random vector" column)
D_list = [256, 512, 1024]

# ✅ fixed n_repeats = 4
# NOTE(review): an earlier version of this comment claimed "8 tables", which
# does not match the value 4 — confirm how n_repeats relates to table count.
n_repeats_list = [4]

# ✅ fixed iProbe (no sweep)
iprobe_list = [3]

# NNDescent iteration budgets; one refinement run per value (convergence curve).
iters = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

# ==============================================================
# Main Loop
# ==============================================================

print(f"\n\n============ Running single seed {trial_seed} ============\n")

for dataset_name, data_info in dataset_paths.items():

    print(f"\n\n======== Dataset: {dataset_name} | Seed={trial_seed} ========\n")

    # Fail fast with an actionable message instead of a bare KeyError:
    # "n" and "d" are required to memory-map the raw .bin file.
    missing_keys = [
        key for key in ("bin_file", "gt_file", "n", "d") if key not in data_info
    ]
    if missing_keys:
        raise KeyError(
            f"dataset_paths[{dataset_name!r}] is missing required key(s): "
            f"{', '.join(missing_keys)}"
        )

    bin_file = data_info["bin_file"]
    n = data_info["n"]
    d = data_info["d"]
    gt_file = data_info["gt_file"]

    # ---------------- Load dataset ----------------
    X = utils.mmap_bin(bin_file, n, d)

    # ---------------- Normalize ----------------
    # L2-normalise each row; zero-norm rows stay zero instead of becoming NaN.
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    norms[norms == 0] = 1
    X = X / norms

    # ---------------- Output directory ----------------
    output_dir = f"{dataset_name}_result/Sensitivity"
    os.makedirs(output_dir, exist_ok=True)

    csv_file = f"{output_dir}/{method}_k20_D_sweep_seed{trial_seed}.csv"

    # Write the CSV header only for a new file; reruns append to existing rows.
    if not os.path.isfile(csv_file):
        with open(csv_file, mode="w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                "Dataset",
                "Method",
                "Init Time(ms)",
                "Construction distance cost per point",
                "NNDescent Time(ms)",
                "Distance computation",
                "Total Time(ms)",
                "NNDescent Iterations",
                "NNDescent Accuracy",
                "Number of random vector",
                "TopK",
                "iProbe",
                "top_m",
                "qProbe",
                "distance",
                "centering"
            ])

    # Ground-truth neighbour ids, columns 1..k. Column 0 is skipped, presumably
    # because it holds the query point itself — TODO confirm with the GT source.
    exact_kNN = np.load(gt_file)[:, 1:k+1]

    # ==========================================================
    # Parameter sweep
    # ==========================================================

    for n_repeats in n_repeats_list:
        for D in D_list:
            for iProbe in iprobe_list:

                print(f"\n>> n_repeats={n_repeats}, D={D}, iProbe={iProbe}")

                index = FalconnLite.FalconnLite(n, d)
                # NOTE(review): seed=-1 here while NNDescent below uses
                # trial_seed; if -1 means "random", the init phase is not
                # reproducible under the single-seed setup — confirm.
                index.set_params(
                    n_proj=D,
                    iProbe=iProbe,
                    top_m=top_m,
                    qProbe=qProbe,
                    distance=dist,
                    verbose=verbose,
                    n_threads=n_threads,
                    seed=-1
                )
                index.centering = True

                # ---------------- Init phase ----------------
                t_init_start = timeit.default_timer()
                indices, _ = index.approx_join(
                    X,
                    topK=k,
                    n_repeats=n_repeats
                )
                t_init_end = timeit.default_timer()
                init_time_ms = (t_init_end - t_init_start) * 1000

                # ---------------- NNDescent Refinement ----------------
                # Each budget t rebuilds NNDescent from the same init graph.
                for t in iters:
                    # NOTE(review): the timed region includes NNDescent
                    # construction; pynndescent is numba-based, so the very
                    # first run may also include JIT compilation time.
                    t_nn_start = timeit.default_timer()

                    index1 = NNDescent(
                        X,
                        n_neighbors=k,
                        random_state=trial_seed,
                        tree_init=False,
                        init_graph=indices,
                        metric="cosine",
                        n_iters=t,
                        n_jobs=n_threads
                    )
                    knn_indices, _ = index1.neighbor_graph

                    t_nn_end = timeit.default_timer()
                    nn_time_ms = (t_nn_end - t_nn_start) * 1000
                    total_time_ms = init_time_ms + nn_time_ms

                    acc = utils.getAcc(knn_indices, exact_kNN)

                    # Append one row at a time so partial results survive a crash.
                    with open(csv_file, mode="a", newline="") as f:
                        writer = csv.writer(f)
                        writer.writerow([
                            dataset_name,
                            f"{n_repeats} {method}",
                            f"{init_time_ms:.2f}",
                            0,
                            f"{nn_time_ms:.2f}",
                            # Falls back to 0 when index1 has no "counter".
                            getattr(index1, "counter", 0) / n,
                            f"{total_time_ms:.2f}",
                            t,
                            f"{acc:.4f}",
                            D,
                            k,
                            iProbe,
                            top_m,
                            qProbe,
                            dist,
                            index.centering
                        ])

    print(f"\n---- Finished {dataset_name} ----\n")

print("\n[LOG] Experiment finished.")

# Unwrap the Tee streams before closing the log file. Otherwise any later
# write, and the interpreter's final flush of sys.stdout at exit, would hit
# the closed file, raise ValueError, and make the process exit with status 120.
sys.stdout = getattr(sys.stdout, "files", (sys.stdout,))[0]
sys.stderr = getattr(sys.stderr, "files", (sys.stderr,))[0]
log_file.close()
